Skip to content

Commit 05a5b97

Browse files
authored
Remove is pai for evaluate (#2749)
* remove is_pai for evaluate * update
1 parent 9a34bd4 commit 05a5b97

4 files changed

Lines changed: 75 additions & 29 deletions

File tree

go/codegen/pai/template_tf.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ if os.environ.get('DISPLAY', '') == '':
273273
import json
274274
import types
275275
import sys
276-
from runtime.tensorflow import evaluate
276+
from runtime.pai.tensorflow import evaluate
277277
278278
try:
279279
tf.enable_eager_execution()
@@ -307,7 +307,7 @@ if is_estimator:
307307
else:
308308
oss.load_file("{{.OSSModelDir}}", "model_save")
309309
310-
evaluate.evaluate(datasource="{{.DataSource}}",
310+
evaluate._evaluate(datasource="{{.DataSource}}",
311311
estimator_string=estimator,
312312
select="""{{.Select}}""",
313313
result_table="{{.ResultTable}}",
@@ -321,6 +321,5 @@ evaluate.evaluate(datasource="{{.DataSource}}",
321321
batch_size=1,
322322
validation_steps=None,
323323
verbose=0,
324-
is_pai="{{.IsPAI}}" == "true",
325324
pai_table="{{.PAITable}}")
326325
`

go/codegen/tensorflow/template_evaluate.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,5 @@ evaluate(datasource="{{.DataSource}}",
104104
hdfs_namenode_addr="{{.HDFSNameNodeAddr}}",
105105
hive_location="{{.HiveLocation}}",
106106
hdfs_user="{{.HDFSUser}}",
107-
hdfs_pass="{{.HDFSPass}}",
108-
is_pai="{{.IsPAI}}" == "true",
109-
pai_table="{{.PAIEvaluateTable}}")
107+
hdfs_pass="{{.HDFSPass}}")
110108
`

python/runtime/pai/tensorflow/evaluate.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,12 @@
2121
from runtime import oss
2222
from runtime.pai.pai_distributed import define_tf_flags
2323
from runtime.tensorflow import is_tf_estimator
24-
from runtime.tensorflow.evaluate import evaluate as _evaluate
24+
from runtime.tensorflow.evaluate import (estimator_evaluate, keras_evaluate,
25+
write_result_metrics)
26+
from runtime.tensorflow.input_fn import get_dataset_fn
27+
from runtime.tensorflow.keras_with_feature_column_input import \
28+
init_model_with_feature_column
29+
from runtime.tensorflow.set_log_level import set_log_level
2530
from tensorflow.estimator import (BoostedTreesClassifier,
2631
BoostedTreesRegressor, DNNClassifier,
2732
DNNLinearCombinedClassifier,
@@ -88,3 +93,57 @@ def evaluate(datasource, select, data_table, result_table, oss_model_path,
8893
verbose=0,
8994
is_pai=True,
9095
pai_table=data_table)
96+
97+
98+
def _evaluate(datasource,
99+
estimator_string,
100+
select,
101+
result_table,
102+
feature_columns,
103+
feature_column_names,
104+
feature_metas={},
105+
label_meta={},
106+
model_params={},
107+
validation_metrics=["Accuracy"],
108+
save="",
109+
batch_size=1,
110+
validation_steps=None,
111+
verbose=0,
112+
pai_table=""):
113+
runtime.import_model_def(estimator_string, globals())
114+
estimator_cls = eval(estimator_string)
115+
is_estimator = is_tf_estimator(estimator_cls)
116+
set_log_level(verbose, is_estimator)
117+
eval_dataset = get_dataset_fn(select,
118+
datasource,
119+
feature_column_names,
120+
feature_metas,
121+
label_meta,
122+
is_pai=True,
123+
pai_table=pai_table,
124+
batch_size=batch_size)
125+
126+
model_params.update(feature_columns)
127+
if is_estimator:
128+
FLAGS = tf.app.flags.FLAGS
129+
model_params["model_dir"] = FLAGS.checkpointDir
130+
estimator = estimator_cls(**model_params)
131+
result_metrics = estimator_evaluate(estimator, eval_dataset,
132+
validation_metrics)
133+
else:
134+
keras_model = init_model_with_feature_column(estimator, model_params)
135+
keras_model_pkg = sys.modules[estimator_cls.__module__]
136+
result_metrics = keras_evaluate(keras_model, eval_dataset, save,
137+
keras_model_pkg, validation_metrics)
138+
139+
if result_table:
140+
metric_name_list = ["loss"] + validation_metrics
141+
write_result_metrics(result_metrics,
142+
metric_name_list,
143+
result_table,
144+
"pai_maxcompute",
145+
None,
146+
hdfs_namenode_addr="",
147+
hive_location="",
148+
hdfs_user="",
149+
hdfs_pass="")

python/runtime/tensorflow/evaluate.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -53,28 +53,23 @@ def evaluate(datasource,
5353
hdfs_namenode_addr="",
5454
hive_location="",
5555
hdfs_user="",
56-
hdfs_pass="",
57-
is_pai=False,
58-
pai_table=""):
56+
hdfs_pass=""):
5957
runtime.import_model_def(estimator_string, globals())
6058
estimator_cls = eval(estimator_string)
61-
6259
is_estimator = is_tf_estimator(estimator_cls)
63-
6460
set_log_level(verbose, is_estimator)
65-
66-
eval_dataset = get_dataset_fn(select, datasource, feature_column_names,
67-
feature_metas, label_meta, is_pai, pai_table,
68-
batch_size)
61+
eval_dataset = get_dataset_fn(select,
62+
datasource,
63+
feature_column_names,
64+
feature_metas,
65+
label_meta,
66+
is_pai=False,
67+
pai_table="",
68+
batch_size=batch_size)
6969

7070
model_params.update(feature_columns)
7171
if is_estimator:
72-
if is_pai:
73-
FLAGS = tf.app.flags.FLAGS
74-
model_params["model_dir"] = FLAGS.checkpointDir
75-
else:
76-
model_params["model_dir"] = save
77-
# tf estimator always have feature_column argument
72+
model_params["model_dir"] = save
7873
estimator = estimator_cls(**model_params)
7974
result_metrics = estimator_evaluate(estimator, eval_dataset,
8075
validation_metrics)
@@ -85,13 +80,8 @@ def evaluate(datasource,
8580
keras_model_pkg, validation_metrics)
8681

8782
# write result metrics to a table
88-
if is_pai:
89-
driver = "pai_maxcompute"
90-
conn = None
91-
else:
92-
conn = connect_with_data_source(datasource)
93-
driver = conn.driver
94-
83+
conn = connect_with_data_source(datasource)
84+
driver = conn.driver
9585
if result_table:
9686
metric_name_list = ["loss"] + validation_metrics
9787
write_result_metrics(result_metrics,

0 commit comments

Comments
 (0)