@@ -66,7 +66,7 @@ type requirementsFiller struct {
6666const tfImportsText = `
6767import tensorflow as tf
6868from runtime.tensorflow import is_tf_estimator
69- from tensorflow.estimator import DNNClassifier, DNNRegressor, LinearClassifier, LinearRegressor, BoostedTreesClassifier, BoostedTreesRegressor, DNNLinearCombinedClassifier, DNNLinearCombinedRegressor
69+ from runtime.import_model import import_model
7070try:
7171 from runtime import oss
7272 from runtime.pai.pai_distributed import define_tf_flags, set_oss_environs
@@ -79,7 +79,7 @@ const tfLoadModelTmplText = tfImportsText + `
7979FLAGS = define_tf_flags()
8080set_oss_environs(FLAGS)
8181
82- estimator = {{.Estimator}}
82+ estimator = import_model(''' {{.Estimator}}''')
8383is_estimator = is_tf_estimator(estimator)
8484
8585# Keras single node is using h5 format to save the model, no need to deal with export model format.
9595const tfSaveModelTmplText = tfImportsText + `
9696import types
9797
98- estimator = {{.Estimator}}
98+ estimator = import_model(''' {{.Estimator}}''')
9999is_estimator = is_tf_estimator(estimator)
100100
101101# Keras single node is using h5 format to save the model, no need to deal with export model format.
@@ -173,7 +173,7 @@ feature_columns = eval(feature_columns_code)
173173# NOTE(typhoonzero): No need to eval model_params["optimizer"] and model_params["loss"]
174174# because predicting do not need these parameters.
175175
176- is_estimator = is_tf_estimator(eval (estimator))
176+ is_estimator = is_tf_estimator(import_model (estimator))
177177
178178# Keras single node is using h5 format to save the model, no need to deal with export model format.
179179# Keras distributed mode will use estimator, so this is also needed.
@@ -233,7 +233,7 @@ feature_columns = eval(feature_columns_code)
233233# NOTE(typhoonzero): No need to eval model_params["optimizer"] and model_params["loss"]
234234# because predicting do not need these parameters.
235235
236- is_estimator = is_tf_estimator(eval (estimator))
236+ is_estimator = is_tf_estimator(import_model (estimator))
237237
238238# Keras single node is using h5 format to save the model, no need to deal with export model format.
239239# Keras distributed mode will use estimator, so this is also needed.
@@ -296,7 +296,7 @@ feature_columns = eval(feature_columns_code)
296296# NOTE(typhoonzero): No need to eval model_params["optimizer"] and model_params["loss"]
297297# because predicting do not need these parameters.
298298
299- is_estimator = is_tf_estimator(eval (estimator))
299+ is_estimator = is_tf_estimator(import_model (estimator))
300300
301301# Keras single node is using h5 format to save the model, no need to deal with export model format.
302302# Keras distributed mode will use estimator, so this is also needed.
0 commit comments