1111# See the License for the specific language governing permissions and
1212# limitations under the License.
1313
14- import os
1514import sys
16- import types
1715
18- import runtime
19- import tensorflow as tf
2016from runtime .db import buffered_db_writer , connect_with_data_source
2117from runtime .import_model import import_model
2218from runtime .tensorflow import metrics
2622 init_model_with_feature_column
2723from runtime .tensorflow .set_log_level import set_log_level
2824
29- try :
30- import sqlflow_models
31- except :
32- pass
33-
3425
3526def evaluate (datasource ,
3627 estimator_string ,
@@ -100,11 +91,12 @@ def estimator_evaluate(estimator, eval_dataset, validation_metrics):
10091 if val :
10192 result_metrics [m ] = val
10293 else :
103- # NOTE: estimator automatically append metrics for the current evaluation job,
104- # if user specified metrics not appear in estimator's result dict, fill None.
94+ # NOTE: estimator automatically append metrics for the current
95+ # evaluation job, if user specified metrics not appear in
96+ # estimator's result dict, fill None.
10597 print (
106- "specified metric %s not calculated by estimator, fill empty value. "
107- % m )
98+ "specified metric %s not calculated by estimator, fill empty "
99+ "value." % m )
108100 result_metrics [m ] = None
109101
110102 return result_metrics
@@ -128,23 +120,20 @@ def keras_evaluate(keras_model, eval_dataset_fn, save, keras_model_pkg,
128120 # default
129121 keras_metrics = metrics .get_keras_metrics (["Accuracy" ])
130122
131- # compile the model with default arguments only for evaluation (run forward only).
123+ # compile the model with default arguments only for evaluation (run forward
124+ # only).
132125 keras_model .compile (loss = keras_model_pkg .loss , metrics = keras_metrics )
133126
134127 eval_dataset = eval_dataset_fn ()
135128
136129 def get_features (sample , label ):
137130 return sample
138131
139- def get_label (sample , label ):
140- return label
141-
142132 eval_dataset_x = eval_dataset .map (get_features )
143- eval_dataset_y = eval_dataset .map (get_label )
144133
145134 one_batch = next (iter (eval_dataset_x ))
146135 # NOTE: must run predict one batch to initialize parameters
147- # see: https://www.tensorflow.org/alpha/guide/keras/saving_and_serializing#saving_subclassed_models
136+ # see: https://www.tensorflow.org/alpha/guide/keras/saving_and_serializing#saving_subclassed_models # noqa: E501
148137 keras_model .predict_on_batch (one_batch )
149138 keras_model .load_weights (save )
150139 result = keras_model .evaluate (eval_dataset )
0 commit comments