2323
2424from ..model_metadata import save_model_metadata
2525from . import metrics
26- from .diag import check_and_load_estimator
2726from .get_tf_version import tf_is_version2
2827from .input_fn import input_fn
29- from .keras_with_feature_column_input import WrappedKerasModel
28+ from .keras_with_feature_column_input import init_model_with_feature_column
3029from .pai_distributed import (dump_into_tf_config ,
3130 make_distributed_info_without_evaluator )
3231from .train_estimator import estimator_train_compiled
@@ -47,13 +46,6 @@ def keras_train_and_save(estimator, model_params, save, is_pai, FLAGS,
4746 loss = model_params ["loss" ]
4847 del model_params ["loss" ]
4948
50- signature = inspect .signature (estimator )
51- has_feature_columns_arg = False
52- for p in signature .parameters :
53- if signature .parameters [p ].name == "feature_columns" :
54- has_feature_columns_arg = True
55- break
56-
5749 classifier_pkg = sys .modules [estimator .__module__ ]
5850 # setting training metrics
5951 model_metrics = []
@@ -101,13 +93,8 @@ def keras_train_and_save(estimator, model_params, save, is_pai, FLAGS,
10193 else :
10294 validate_dataset = None
10395
104- if not has_feature_columns_arg and not has_none_optimizer :
105- feature_columns = model_params ["feature_columns" ]
106- del model_params ["feature_columns" ]
107- classifier = WrappedKerasModel (estimator , model_params ,
108- feature_columns )
109- else :
110- classifier = check_and_load_estimator (estimator , model_params )
96+ classifier = init_model_with_feature_column (
97+ estimator , model_params , has_none_optimizer = has_none_optimizer )
11198
11299 # FIXME(sneaxiy): some models defined by other framework (not TensorFlow or XGBoost)
113100 # may return None optimizer.
0 commit comments