|
21 | 21 | from runtime import oss |
22 | 22 | from runtime.pai.pai_distributed import define_tf_flags |
23 | 23 | from runtime.tensorflow import is_tf_estimator |
24 | | -from runtime.tensorflow.evaluate import evaluate as _evaluate |
| 24 | +from runtime.tensorflow.evaluate import (estimator_evaluate, keras_evaluate, |
| 25 | + write_result_metrics) |
| 26 | +from runtime.tensorflow.input_fn import get_dataset_fn |
| 27 | +from runtime.tensorflow.keras_with_feature_column_input import \ |
| 28 | + init_model_with_feature_column |
| 29 | +from runtime.tensorflow.set_log_level import set_log_level |
25 | 30 | from tensorflow.estimator import (BoostedTreesClassifier, |
26 | 31 | BoostedTreesRegressor, DNNClassifier, |
27 | 32 | DNNLinearCombinedClassifier, |
@@ -88,3 +93,57 @@ def evaluate(datasource, select, data_table, result_table, oss_model_path, |
88 | 93 | verbose=0, |
89 | 94 | is_pai=True, |
90 | 95 | pai_table=data_table) |
| 96 | + |
| 97 | + |
| 98 | +def _evaluate(datasource, |
| 99 | + estimator_string, |
| 100 | + select, |
| 101 | + result_table, |
| 102 | + feature_columns, |
| 103 | + feature_column_names, |
| 104 | + feature_metas={}, |
| 105 | + label_meta={}, |
| 106 | + model_params={}, |
| 107 | + validation_metrics=["Accuracy"], |
| 108 | + save="", |
| 109 | + batch_size=1, |
| 110 | + validation_steps=None, |
| 111 | + verbose=0, |
| 112 | + pai_table=""): |
| 113 | + runtime.import_model_def(estimator_string, globals()) |
| 114 | + estimator_cls = eval(estimator_string) |
| 115 | + is_estimator = is_tf_estimator(estimator_cls) |
| 116 | + set_log_level(verbose, is_estimator) |
| 117 | + eval_dataset = get_dataset_fn(select, |
| 118 | + datasource, |
| 119 | + feature_column_names, |
| 120 | + feature_metas, |
| 121 | + label_meta, |
| 122 | + is_pai=True, |
| 123 | + pai_table=pai_table, |
| 124 | + batch_size=batch_size) |
| 125 | + |
| 126 | + model_params.update(feature_columns) |
| 127 | + if is_estimator: |
| 128 | + FLAGS = tf.app.flags.FLAGS |
| 129 | + model_params["model_dir"] = FLAGS.checkpointDir |
| 130 | + estimator = estimator_cls(**model_params) |
| 131 | + result_metrics = estimator_evaluate(estimator, eval_dataset, |
| 132 | + validation_metrics) |
| 133 | + else: |
| 134 | + keras_model = init_model_with_feature_column(estimator, model_params) |
| 135 | + keras_model_pkg = sys.modules[estimator_cls.__module__] |
| 136 | + result_metrics = keras_evaluate(keras_model, eval_dataset, save, |
| 137 | + keras_model_pkg, validation_metrics) |
| 138 | + |
| 139 | + if result_table: |
| 140 | + metric_name_list = ["loss"] + validation_metrics |
| 141 | + write_result_metrics(result_metrics, |
| 142 | + metric_name_list, |
| 143 | + result_table, |
| 144 | + "pai_maxcompute", |
| 145 | + None, |
| 146 | + hdfs_namenode_addr="", |
| 147 | + hive_location="", |
| 148 | + hdfs_user="", |
| 149 | + hdfs_pass="") |
0 commit comments