Skip to content

Commit c5300f1

Browse files
authored
Unify eval estimator code (#2757)
* unify eval estimator code * update * update * polish import_model_module and doc * polish doc * fix ut
1 parent 572739d commit c5300f1

15 files changed

Lines changed: 109 additions & 119 deletions

go/codegen/pai/template_tf.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ type requirementsFiller struct {
6666
const tfImportsText = `
6767
import tensorflow as tf
6868
from runtime.tensorflow import is_tf_estimator
69-
from tensorflow.estimator import DNNClassifier, DNNRegressor, LinearClassifier, LinearRegressor, BoostedTreesClassifier, BoostedTreesRegressor, DNNLinearCombinedClassifier, DNNLinearCombinedRegressor
69+
from runtime.import_model import import_model
7070
try:
7171
from runtime import oss
7272
from runtime.pai.pai_distributed import define_tf_flags, set_oss_environs
@@ -79,7 +79,7 @@ const tfLoadModelTmplText = tfImportsText + `
7979
FLAGS = define_tf_flags()
8080
set_oss_environs(FLAGS)
8181
82-
estimator = {{.Estimator}}
82+
estimator = import_model('''{{.Estimator}}''')
8383
is_estimator = is_tf_estimator(estimator)
8484
8585
# Keras single node is using h5 format to save the model, no need to deal with export model format.
@@ -95,7 +95,7 @@ else:
9595
const tfSaveModelTmplText = tfImportsText + `
9696
import types
9797
98-
estimator = {{.Estimator}}
98+
estimator = import_model('''{{.Estimator}}''')
9999
is_estimator = is_tf_estimator(estimator)
100100
101101
# Keras single node is using h5 format to save the model, no need to deal with export model format.
@@ -173,7 +173,7 @@ feature_columns = eval(feature_columns_code)
173173
# NOTE(typhoonzero): No need to eval model_params["optimizer"] and model_params["loss"]
174174
# because predicting do not need these parameters.
175175
176-
is_estimator = is_tf_estimator(eval(estimator))
176+
is_estimator = is_tf_estimator(import_model(estimator))
177177
178178
# Keras single node is using h5 format to save the model, no need to deal with export model format.
179179
# Keras distributed mode will use estimator, so this is also needed.
@@ -233,7 +233,7 @@ feature_columns = eval(feature_columns_code)
233233
# NOTE(typhoonzero): No need to eval model_params["optimizer"] and model_params["loss"]
234234
# because predicting do not need these parameters.
235235
236-
is_estimator = is_tf_estimator(eval(estimator))
236+
is_estimator = is_tf_estimator(import_model(estimator))
237237
238238
# Keras single node is using h5 format to save the model, no need to deal with export model format.
239239
# Keras distributed mode will use estimator, so this is also needed.
@@ -296,7 +296,7 @@ feature_columns = eval(feature_columns_code)
296296
# NOTE(typhoonzero): No need to eval model_params["optimizer"] and model_params["loss"]
297297
# because predicting do not need these parameters.
298298
299-
is_estimator = is_tf_estimator(eval(estimator))
299+
is_estimator = is_tf_estimator(import_model(estimator))
300300
301301
# Keras single node is using h5 format to save the model, no need to deal with export model format.
302302
# Keras distributed mode will use estimator, so this is also needed.

python/runtime/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,3 @@
1010
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
13-
14-
from runtime.import_custom_models import import_model_def

python/runtime/import_custom_models.py

Lines changed: 0 additions & 26 deletions
This file was deleted.

python/runtime/import_model.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
try:
15+
import sqlflow_models # noqa: F401
16+
except: # noqa: E722
17+
pass
18+
19+
from tensorflow.estimator import BoostedTreesClassifier # noqa: F401
20+
from tensorflow.estimator import BoostedTreesRegressor # noqa: F401
21+
from tensorflow.estimator import DNNClassifier # noqa: F401
22+
from tensorflow.estimator import DNNLinearCombinedClassifier # noqa: F401
23+
from tensorflow.estimator import DNNLinearCombinedRegressor # noqa: F401
24+
from tensorflow.estimator import DNNRegressor # noqa: F401
25+
from tensorflow.estimator import LinearClassifier # noqa: F401
26+
from tensorflow.estimator import LinearRegressor # noqa: F401
27+
28+
29+
def import_model_module(model, namespace):
    """
    Import the module that defines a model and bind it into namespace.

    The module path is everything before the last "." of the model name,
    and the top-level package of that path is what gets bound. For example,
    If model = "my_model_module.my_model", "my_model_module" is imported
    and bound as namespace["my_model_module"]. If model =
    "my_pkg.sub.MyModel", "my_pkg.sub" is imported and "my_pkg" is bound,
    so that the full dotted name can be resolved from namespace.

    Nothing is imported when:
      - the model name has no module part (e.g. "DNNClassifier");
      - the top-level name is "xgboost" or "sqlflow_models"
        (case-insensitive);
      - the module path is nested and its top-level name is already
        bound in namespace (e.g. an alias such as "tf").

    Import failures are reported with print and otherwise ignored, so the
    call is best-effort and does not raise on a missing package.

    Args:
        model (str): the model name.
        namespace (dict): the namespace to be imported into.

    Returns:
        None.
    """
    # Split "pkg.sub.Model" into the module path "pkg.sub" and the
    # attribute name; a name without any dot has no module to import.
    module_path, _, _ = model.rpartition(".")
    if not module_path:
        return

    top_level = module_path.split(".", 1)[0]
    if not top_level or top_level.lower() in ('xgboost', 'sqlflow_models'):
        return

    # For nested paths, keep whatever the caller already bound under the
    # top-level name instead of trying to import an alias as a package.
    if "." in module_path and top_level in namespace:
        return

    try:
        # __import__ with a dotted name imports the whole path and
        # returns the top-level package, which is what must be bound.
        namespace[top_level] = __import__(module_path)
    except Exception as e:
        print("failed to import %s: %s" % (module_path, e))
52+
53+
54+
def import_model(model):
    """
    Resolve a model name to the model class or function it refers to.

    The module part of the name, if any, is first imported into this
    module's globals through import_model_module; the name is then
    evaluated as a Python expression.

    NOTE(review): the name is passed to eval unchanged, so it is
    presumably expected to come from a trusted source - confirm against
    the callers.

    Args:
        model (str): the model name.

    Returns:
        An imported model class or function.
    """
    import_model_module(model, globals())
    resolved = eval(model)
    return resolved

python/runtime/pai/tensorflow/evaluate.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import runtime
2020
import tensorflow as tf
2121
from runtime import oss
22+
from runtime.import_model import import_model
2223
from runtime.pai.pai_distributed import define_tf_flags
2324
from runtime.tensorflow import is_tf_estimator
2425
from runtime.tensorflow.evaluate import (estimator_evaluate, keras_evaluate,
@@ -27,11 +28,6 @@
2728
from runtime.tensorflow.keras_with_feature_column_input import \
2829
init_model_with_feature_column
2930
from runtime.tensorflow.set_log_level import set_log_level
30-
from tensorflow.estimator import (BoostedTreesClassifier,
31-
BoostedTreesRegressor, DNNClassifier,
32-
DNNLinearCombinedClassifier,
33-
DNNLinearCombinedRegressor, DNNRegressor,
34-
LinearClassifier, LinearRegressor)
3531

3632
try:
3733
tf.enable_eager_execution()
@@ -66,7 +62,7 @@ def evaluate(datasource, select, data_table, result_table, oss_model_path,
6662
# NOTE(typhoonzero): No need to eval model_params["optimizer"] and model_params["loss"]
6763
# because predicting do not need these parameters.
6864

69-
is_estimator = is_tf_estimator(eval(estimator))
65+
is_estimator = is_tf_estimator(import_model(estimator))
7066

7167
# Keras single node is using h5 format to save the model, no need to deal with export model format.
7268
# Keras distributed mode will use estimator, so this is also needed.
@@ -110,8 +106,7 @@ def _evaluate(datasource,
110106
validation_steps=None,
111107
verbose=0,
112108
pai_table=""):
113-
runtime.import_model_def(estimator_string, globals())
114-
estimator_cls = eval(estimator_string)
109+
estimator_cls = import_model(estimator_string)
115110
is_estimator = is_tf_estimator(estimator_cls)
116111
set_log_level(verbose, is_estimator)
117112
eval_dataset = get_dataset_fn(select,

python/runtime/pai/tensorflow/explain.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,12 @@
2121
import runtime
2222
import tensorflow as tf
2323
from runtime import oss
24+
from runtime.import_model import import_model
2425
from runtime.tensorflow import is_tf_estimator
2526
from runtime.tensorflow.explain import explain_boosted_trees, explain_dnns
2627
from runtime.tensorflow.input_fn import input_fn
2728
from runtime.tensorflow.keras_with_feature_column_input import \
2829
init_model_with_feature_column
29-
from tensorflow.estimator import (BoostedTreesClassifier,
30-
BoostedTreesRegressor, DNNClassifier,
31-
DNNLinearCombinedClassifier,
32-
DNNLinearCombinedRegressor, DNNRegressor,
33-
LinearClassifier, LinearRegressor)
3430

3531
try:
3632
from runtime.pai.pai_distributed import define_tf_flags, set_oss_environs
@@ -59,7 +55,7 @@ def explain(datasource, select, data_table, result_table, label_column,
5955
# NOTE(typhoonzero): No need to eval model_params["optimizer"] and model_params["loss"]
6056
# because predicting do not need these parameters.
6157

62-
is_estimator = is_tf_estimator(eval(estimator))
58+
is_estimator = is_tf_estimator(import_model(estimator))
6359

6460
# Keras single node is using h5 format to save the model, no need to deal with export model format.
6561
# Keras distributed mode will use estimator, so this is also needed.
@@ -106,8 +102,7 @@ def _explain(datasource,
106102
oss_sk=None,
107103
oss_endpoint=None,
108104
oss_bucket_name=None):
109-
runtime.import_model_def(estimator_string, globals())
110-
estimator_cls = eval(estimator_string)
105+
estimator_cls = import_model(estimator_string)
111106
FLAGS = tf.app.flags.FLAGS
112107
model_params["model_dir"] = FLAGS.checkpointDir
113108
model_params.update(feature_columns)

python/runtime/pai/tensorflow/predict.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,10 @@
1919
import tensorflow as tf
2020
from runtime import db, oss
2121
from runtime.diagnostics import SQLFlowDiagnostic
22+
from runtime.import_model import import_model
2223
from runtime.pai.pai_distributed import define_tf_flags
2324
from runtime.tensorflow import is_tf_estimator
2425
from runtime.tensorflow.predict import estimator_predict, keras_predict
25-
from tensorflow.estimator import (BoostedTreesClassifier,
26-
BoostedTreesRegressor, DNNClassifier,
27-
DNNLinearCombinedClassifier,
28-
DNNLinearCombinedRegressor, DNNRegressor,
29-
LinearClassifier, LinearRegressor)
3026

3127
try:
3228
import sqlflow_models
@@ -65,7 +61,7 @@ def predict(datasource, select, data_table, result_table, label_column,
6561
# NOTE(typhoonzero): No need to eval model_params["optimizer"] and model_params["loss"]
6662
# because predicting do not need these parameters.
6763

68-
is_estimator = is_tf_estimator(eval(estimator))
64+
is_estimator = is_tf_estimator(import_model(estimator))
6965

7066
# Keras single node is using h5 format to save the model, no need to deal with export model format.
7167
# Keras distributed mode will use estimator, so this is also needed.
@@ -106,8 +102,7 @@ def _predict(datasource,
106102
save="",
107103
batch_size=1,
108104
pai_table=""):
109-
runtime.import_model_def(estimator_string, globals())
110-
estimator = eval(estimator_string)
105+
estimator = import_model(estimator_string)
111106
model_params.update(feature_columns)
112107
is_estimator = is_tf_estimator(estimator)
113108

python/runtime/pai/tensorflow/train.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from runtime import oss
2626
from runtime.db import (connect_with_data_source, db_generator,
2727
parseMaxComputeDSN)
28+
from runtime.import_model import import_model
2829
from runtime.model_metadata import collect_model_metadata
2930
from runtime.pai.pai_distributed import define_tf_flags, set_oss_environs
3031
from runtime.pai.tensorflow.train_estimator import estimator_train_and_save
@@ -33,11 +34,6 @@
3334
from runtime.tensorflow.get_tf_version import tf_is_version2
3435
from runtime.tensorflow.input_fn import get_dataset_fn
3536
from runtime.tensorflow.set_log_level import set_log_level
36-
from tensorflow.estimator import (BoostedTreesClassifier,
37-
BoostedTreesRegressor, DNNClassifier,
38-
DNNLinearCombinedClassifier,
39-
DNNLinearCombinedRegressor, DNNRegressor,
40-
LinearClassifier, LinearRegressor)
4137

4238
try:
4339
import sqlflow_models
@@ -81,8 +77,7 @@ def train(datasource,
8177
model_params, feature_columns_code,
8278
feature_metas, label_meta, None,
8379
model_repo_image)
84-
runtime.import_model_def(estimator_string, globals())
85-
estimator = eval(estimator_string)
80+
estimator = import_model(estimator_string)
8681
is_estimator = is_tf_estimator(estimator)
8782

8883
if verbose < 1: # always use verbose == 1 when using PAI to get more logs

python/runtime/tensorflow/estimator_example.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@
8282
if __name__ == "__main__":
8383
# tf.python.training.basic_session_run_hooks.LoggingTensorHook = runtime.tensorflow.train.PrintTensorsHook
8484
train(datasource=datasource,
85-
estimator_string="tf.estimator.DNNClassifier",
85+
estimator_string="DNNClassifier",
8686
select=select,
8787
validation_select=validate_select,
8888
feature_columns=feature_columns,
@@ -98,7 +98,7 @@
9898
epoch=3,
9999
verbose=0)
100100
train(datasource=datasource,
101-
estimator_string="tf.estimator.DNNClassifier",
101+
estimator_string="DNNClassifier",
102102
select=select_binary,
103103
validation_select=validate_select_binary,
104104
feature_columns=feature_columns,
@@ -114,7 +114,7 @@
114114
epoch=3,
115115
verbose=1)
116116
pred(datasource=datasource,
117-
estimator_string="tf.estimator.DNNClassifier",
117+
estimator_string="DNNClassifier",
118118
select=select,
119119
result_table="iris.predict",
120120
feature_columns=feature_columns,

python/runtime/tensorflow/evaluate.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,13 @@
1818
import runtime
1919
import tensorflow as tf
2020
from runtime.db import buffered_db_writer, connect_with_data_source
21+
from runtime.import_model import import_model
2122
from runtime.tensorflow import metrics
2223
from runtime.tensorflow.get_tf_model_type import is_tf_estimator
2324
from runtime.tensorflow.input_fn import get_dataset_fn
2425
from runtime.tensorflow.keras_with_feature_column_input import \
2526
init_model_with_feature_column
2627
from runtime.tensorflow.set_log_level import set_log_level
27-
from tensorflow.estimator import (BoostedTreesClassifier,
28-
BoostedTreesRegressor, DNNClassifier,
29-
DNNLinearCombinedClassifier,
30-
DNNLinearCombinedRegressor, DNNRegressor,
31-
LinearClassifier, LinearRegressor)
3228

3329
try:
3430
import sqlflow_models
@@ -54,8 +50,7 @@ def evaluate(datasource,
5450
hive_location="",
5551
hdfs_user="",
5652
hdfs_pass=""):
57-
runtime.import_model_def(estimator_string, globals())
58-
estimator_cls = eval(estimator_string)
53+
estimator_cls = import_model(estimator_string)
5954
is_estimator = is_tf_estimator(estimator_cls)
6055
set_log_level(verbose, is_estimator)
6156
eval_dataset = get_dataset_fn(select,

0 commit comments

Comments
 (0)