1414import inspect
1515import sys
1616import warnings
17+ from os import path
1718
1819import six
1920import tensorflow as tf
2021from sqlflow_submitter .pai import model
2122from sqlflow_submitter .seeding import get_tf_random_seed
2223
24+ from ..model_metadata import save_model_metadata
2325from . import metrics
2426from .diag import check_and_load_estimator
2527from .get_tf_version import tf_is_version2
3335def keras_train_and_save (estimator , model_params , save , is_pai , FLAGS ,
3436 train_dataset_fn , val_dataset_fn , label_meta , epochs ,
3537 verbose , metric_names , validation_steps ,
36- load_pretrained_model ):
38+ load_pretrained_model , model_meta ):
3739 print ("Start training using keras model..." )
3840 # remove optimizer param from model_params and use it when call "compile()"
3941 optimizer = None
@@ -171,10 +173,15 @@ def keras_train_and_save(estimator, model_params, save, is_pai, FLAGS,
171173 tf .feature_column .make_parse_example_spec (all_feature_columns ))
172174 export_path = keras_estimator .export_saved_model (
173175 save , serving_input_fn )
176+
174177 # write the path under current directory
178+ export_path_str = str (export_path .decode ("utf-8" ))
175179 with open ("exported_path" , "w" ) as fn :
176- fn .write (str (export_path .decode ("utf-8" )))
177- print ("Done training, model exported to: %s" % export_path )
180+ fn .write (export_path_str )
181+ # write model metadata to model_meta.json
182+ save_model_metadata (path .join (export_path_str , "model_meta.json" ),
183+ model_meta )
184+ print ("Done training, model exported to: %s" % export_path_str )
178185 return
179186
180187 if hasattr (classifier , 'sqlflow_train_loop' ):
@@ -200,25 +207,29 @@ def keras_train_and_save(estimator, model_params, save, is_pai, FLAGS,
200207 epochs = epochs if epochs else
201208 classifier .default_training_epochs (),
202209 verbose = verbose )
203- train_keys = []
204- val_keys = []
210+ train_metrics = dict ()
211+ val_metrics = dict ()
205212 for k in history .history .keys ():
206213 if k .startswith ("val_" ):
207- val_keys . append ( k )
214+ val_metrics [ k ] = float ( history . history [ k ][ - 1 ] )
208215 else :
209- train_keys . append ( k )
216+ train_metrics [ k ] = float ( history . history [ k ][ - 1 ] )
210217 print ("====== Result for training set: ======" )
211- for k in train_keys :
212- print ("%s: %s" % (k , history . history [ k ][ - 1 ] ))
218+ for k , v in train_metrics . items () :
219+ print ("%s: %s" % (k , v ))
213220 print ("====== Result for validation set: ======" )
214- for k in val_keys :
215- print ("%s: %s" % (k , history .history [k ][- 1 ]))
221+ for k , v in val_metrics .items ():
222+ print ("%s: %s" % (k , v ))
223+ model_meta ["evaluation" ] = val_metrics
216224
217225 try :
218226 classifier .save_weights (save , save_format = "h5" )
227+ # write model metadata to model_meta.json
228+ save_model_metadata ("model_meta.json" , model_meta )
219229 if is_pai :
220230 print ("saving keras model to: %s" % FLAGS .sqlflow_oss_modeldir )
221231 model .save_file (FLAGS .sqlflow_oss_modeldir , save )
232+ model .save_file (FLAGS .sqlflow_oss_modeldir , "model_meta.json" )
222233 except :
223234 if has_none_optimizer :
224235 warnings .warn ("Saving model with None optimizer fails" )
0 commit comments