Skip to content

Commit b539fdd

Browse files
authored
refine codes (#2760)
1 parent c59d660 commit b539fdd

19 files changed

Lines changed: 77 additions & 131 deletions

python/runtime/import_model.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@
2626
from tensorflow.estimator import LinearRegressor # noqa: F401
2727

2828

29-
def import_model_module(model, namespace):
29+
def import_model_package(model, namespace):
3030
"""
31-
Import the model module into namespace. For example,
32-
If model = "my_model_module.my_model", "my_model_module"
31+
Import the model package into namespace. For example,
32+
If model = "my_model_package.my_model", "my_model_package"
3333
would be imported into namespace.
3434
3535
Args:
@@ -43,12 +43,12 @@ def import_model_module(model, namespace):
4343
# format: my_model_package.MyModel
4444
model_name_parts = model.split(".")
4545
if len(model_name_parts) == 2:
46-
module = model_name_parts[0]
47-
if module and module.lower() not in ['xgboost', 'sqlflow_models']:
46+
package = model_name_parts[0]
47+
if package and package.lower() not in ['xgboost', 'sqlflow_models']:
4848
try:
49-
namespace[module] = __import__(module)
49+
namespace[package] = __import__(package)
5050
except Exception as e:
51-
print("failed to import %s: %s" % (module, e))
51+
print("failed to import %s: %s" % (package, e))
5252

5353

5454
def import_model(model):
@@ -61,5 +61,5 @@ def import_model(model):
6161
Returns:
6262
An imported model class or function.
6363
"""
64-
import_model_module(model, globals())
64+
import_model_package(model, globals())
6565
return eval(model)

python/runtime/tensorflow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14-
from runtime.tensorflow.get_tf_model_type import is_tf_estimator
14+
from runtime.tensorflow.get_tf_model_type import is_tf_estimator # noqa: F401

python/runtime/tensorflow/estimator_example.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@
8080
}
8181

8282
if __name__ == "__main__":
83-
# tf.python.training.basic_session_run_hooks.LoggingTensorHook = runtime.tensorflow.train.PrintTensorsHook
83+
# tf.python.training.basic_session_run_hooks.LoggingTensorHook
84+
# = runtime.tensorflow.train.PrintTensorsHook
8485
train(datasource=datasource,
8586
estimator_string="DNNClassifier",
8687
select=select,

python/runtime/tensorflow/evaluate.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,8 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14-
import os
1514
import sys
16-
import types
1715

18-
import runtime
19-
import tensorflow as tf
2016
from runtime.db import buffered_db_writer, connect_with_data_source
2117
from runtime.import_model import import_model
2218
from runtime.tensorflow import metrics
@@ -26,11 +22,6 @@
2622
init_model_with_feature_column
2723
from runtime.tensorflow.set_log_level import set_log_level
2824

29-
try:
30-
import sqlflow_models
31-
except:
32-
pass
33-
3425

3526
def evaluate(datasource,
3627
estimator_string,
@@ -100,11 +91,12 @@ def estimator_evaluate(estimator, eval_dataset, validation_metrics):
10091
if val:
10192
result_metrics[m] = val
10293
else:
103-
# NOTE: estimator automatically append metrics for the current evaluation job,
104-
# if user specified metrics not appear in estimator's result dict, fill None.
94+
# NOTE: estimator automatically append metrics for the current
95+
# evaluation job, if user specified metrics not appear in
96+
# estimator's result dict, fill None.
10597
print(
106-
"specified metric %s not calculated by estimator, fill empty value."
107-
% m)
98+
"specified metric %s not calculated by estimator, fill empty "
99+
"value." % m)
108100
result_metrics[m] = None
109101

110102
return result_metrics
@@ -128,23 +120,20 @@ def keras_evaluate(keras_model, eval_dataset_fn, save, keras_model_pkg,
128120
# default
129121
keras_metrics = metrics.get_keras_metrics(["Accuracy"])
130122

131-
# compile the model with default arguments only for evaluation (run forward only).
123+
# compile the model with default arguments only for evaluation (run forward
124+
# only).
132125
keras_model.compile(loss=keras_model_pkg.loss, metrics=keras_metrics)
133126

134127
eval_dataset = eval_dataset_fn()
135128

136129
def get_features(sample, label):
137130
return sample
138131

139-
def get_label(sample, label):
140-
return label
141-
142132
eval_dataset_x = eval_dataset.map(get_features)
143-
eval_dataset_y = eval_dataset.map(get_label)
144133

145134
one_batch = next(iter(eval_dataset_x))
146135
# NOTE: must run predict one batch to initialize parameters
147-
# see: https://www.tensorflow.org/alpha/guide/keras/saving_and_serializing#saving_subclassed_models
136+
# see: https://www.tensorflow.org/alpha/guide/keras/saving_and_serializing#saving_subclassed_models # noqa: E501
148137
keras_model.predict_on_batch(one_batch)
149138
keras_model.load_weights(save)
150139
result = keras_model.evaluate(eval_dataset)

python/runtime/tensorflow/evaluate_example.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,8 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14-
import shutil
15-
16-
import runtime
17-
import sqlflow_models
18-
import tensorflow as tf
1914
from estimator_example import (datasource, feature_column_names,
20-
feature_columns, feature_metas, label_meta,
21-
select_binary, validate_select_binary)
15+
feature_columns, feature_metas, label_meta)
2216
from runtime.tensorflow.evaluate import evaluate
2317
from runtime.tensorflow.train import train
2418

python/runtime/tensorflow/explain.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,9 @@
1313

1414
import os
1515

16-
import matplotlib
1716
import matplotlib.pyplot as plt
1817
import numpy as np
1918
import pandas as pd
20-
import runtime
2119
import seaborn as sns
2220
import shap
2321
import tensorflow as tf
@@ -35,11 +33,6 @@
3533
# Use non-interactive background
3634
plt.switch_backend('agg')
3735

38-
try:
39-
import sqlflow_models
40-
except:
41-
pass
42-
4336
# Disable Tensorflow INFO and WARNING logs
4437
if tf_is_version2():
4538
import logging

python/runtime/tensorflow/explain_example.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
import shutil
1515

1616
from estimator_example import (datasource, feature_column_names,
17-
feature_columns, feature_metas, label_meta,
18-
select_binary, validate_select_binary)
17+
feature_columns, feature_metas, label_meta)
1918
from runtime.tensorflow.explain import explain
2019
from runtime.tensorflow.train import train
2120

python/runtime/tensorflow/input_fn.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,8 @@ def parse_pai_dataset(feature_column_names, label_meta, feature_metas, *row):
152152
features[name] = tf.SparseTensor(*f) if spec["is_sparse"] else f
153153
label = row[-1] if label_meta["feature_name"] else -1
154154
if label_meta and label_meta["delimiter"] != "":
155-
# FIXME(typhoonzero): the label in the yielded row may not be the last item, should get
156-
# label index.
155+
# FIXME(typhoonzero): the label in the yielded row may not be the last
156+
# item, should get label index.
157157
tmp = tf.strings.split(label,
158158
sep=label_meta["delimiter"],
159159
result_type='RaggedTensor')
@@ -171,7 +171,6 @@ def pai_dataset(table,
171171
feature_metas,
172172
slice_id=0,
173173
slice_count=1):
174-
record_defaults = []
175174
selected_cols = copy.copy(feature_column_names)
176175
dtypes = [
177176
"string"

python/runtime/tensorflow/keras_example.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,7 @@
1212
# limitations under the License.
1313

1414
import os
15-
# NOTE: this file is used by train_predict_test.py, do **NOT** delete!
16-
import shutil
1715

18-
import sqlflow_models
1916
# TODO(yancey1989): this import line would conflict with isort pre-commit stage
2017
# yapf: disable
2118
from runtime.tensorflow.estimator_example import (datasource,
@@ -27,6 +24,9 @@
2724
from runtime.tensorflow.predict import pred
2825
from runtime.tensorflow.train import train
2926

27+
# NOTE: this file is used by train_predict_test.py, do **NOT** delete!
28+
29+
3030
if __name__ == "__main__":
3131
train(datasource=datasource,
3232
estimator_string="sqlflow_models.DNNClassifier",

python/runtime/tensorflow/keras_example_reg.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,14 @@
1212
# limitations under the License.
1313

1414
import os
15-
# NOTE: this file is used by train_predict_test.py, do **NOT** delete!
16-
import shutil
1715

18-
import sqlflow_models
1916
import tensorflow as tf
2017
from runtime.tensorflow.estimator_example import datasource
2118
from runtime.tensorflow.predict import pred
2219
from runtime.tensorflow.train import train
2320

21+
# NOTE: this file is used by train_predict_test.py, do **NOT** delete!
22+
2423
select = "select * from housing.train"
2524
validation_select = "select * from housing.test"
2625

@@ -79,13 +78,6 @@
7978
"shape": [1],
8079
"is_sparse": "false" == "true"
8180
},
82-
"f6": {
83-
"feature_name": "f6",
84-
"dtype": "float32",
85-
"delimiter": "",
86-
"shape": [1],
87-
"is_sparse": "false" == "true"
88-
},
8981
"f7": {
9082
"feature_name": "f7",
9183
"dtype": "float32",

0 commit comments

Comments
 (0)