Skip to content

Commit befcd74

Browse files
authored
Polish keras wrapped model init with feature columns (#2487)
* polish keras wrapped model init with feature columns * update * update * update
1 parent 82ad9cb commit befcd74

7 files changed

Lines changed: 46 additions & 37 deletions

File tree

python/sqlflow_submitter/tensorflow/diag.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,16 @@ class SQLFlowDiagnostic(Exception):
2222
pass
2323

2424

25-
def check_and_load_estimator(estimator, model_params, warm_start_from=None):
25+
def load_pretrained_model_estimator(estimator,
26+
model_params,
27+
warm_start_from=None):
2628
if warm_start_from is not None:
2729
estimator_func = estimator.__init__ if inspect.isclass(
2830
estimator) else estimator
2931
estimator_spec = inspect.getargspec(estimator_func)
3032
# The constructor of Estimator contains named parameter "warm_start_from"
3133
warm_start_from_key = "warm_start_from"
3234
if warm_start_from_key in estimator_spec.args:
33-
model_params = copy.copy(model_params)
3435
warm_start_from = os.path.abspath(warm_start_from)
3536

3637
if is_tf_estimator(estimator):
@@ -48,6 +49,8 @@ def check_and_load_estimator(estimator, model_params, warm_start_from=None):
4849
"Incremental training is not supported in {}".format(
4950
estimator))
5051

52+
53+
def init_model(estimator, model_params):
5154
# load estimator class and diagnose the type error
5255
try:
5356
return estimator(**model_params)

python/sqlflow_submitter/tensorflow/evaluate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
from . import metrics
2929
from .input_fn import get_dataset_fn
30+
from .keras_with_feature_column_input import init_model_with_feature_column
3031
from .pai_distributed import define_tf_flags
3132
from .set_log_level import set_log_level
3233

@@ -85,11 +86,12 @@ def evaluate(datasource,
8586
model_params["model_dir"] = FLAGS.checkpointDir
8687
else:
8788
model_params["model_dir"] = save
89+
# TF estimators always take a feature_columns argument
8890
estimator = estimator_cls(**model_params)
8991
result_metrics = estimator_evaluate(estimator, eval_dataset,
9092
validation_metrics)
9193
else:
92-
keras_model = estimator_cls(**model_params)
94+
keras_model = init_model_with_feature_column(estimator, model_params)
9395
keras_model_pkg = sys.modules[estimator_cls.__module__]
9496
result_metrics = keras_evaluate(keras_model, eval_dataset, save,
9597
keras_model_pkg, validation_metrics)

python/sqlflow_submitter/tensorflow/explain.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
from .get_tf_version import tf_is_version2
3333
from .input_fn import input_fn
34+
from .keras_with_feature_column_input import init_model_with_feature_column
3435

3536
sns_colors = sns.color_palette('colorblind')
3637
# Disable Tensorflow INFO and WARNING logs
@@ -97,7 +98,9 @@ def _input_fn():
9798
return dataset.batch(1).cache()
9899

99100
model_params.update(feature_columns)
100-
estimator = estimator_cls(**model_params)
101+
102+
estimator = init_model_with_feature_column(estimator_cls, model_params)
103+
101104
if estimator_cls in (tf.estimator.BoostedTreesClassifier,
102105
tf.estimator.BoostedTreesRegressor):
103106
explain_boosted_trees(datasource, estimator, _input_fn, plot_type,

python/sqlflow_submitter/tensorflow/keras_with_feature_column_input.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,12 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14+
import inspect
15+
1416
import tensorflow as tf
1517

18+
from .diag import init_model
19+
1620

1721
class WrappedKerasModel(tf.keras.Model):
1822
def __init__(self, keras_model, model_params, feature_columns):
@@ -23,3 +27,23 @@ def __init__(self, keras_model, model_params, feature_columns):
2327
def __call__(self, inputs, training=True):
    """Run the wrapped model on inputs passed through the feature layer.

    `inputs` is first passed through `self.feature_layer`; the result is
    then handed to `self.sub_model` together with the `training` flag.
    """
    return self.sub_model.__call__(self.feature_layer(inputs),
                                   training=training)
30+
31+
32+
def init_model_with_feature_column(estimator,
                                   model_params,
                                   has_none_optimizer=False):
    """Initialize a model, wrapping it if it cannot accept feature columns.

    If the callable `estimator` declares no "feature_columns" parameter and
    `has_none_optimizer` is false, the "feature_columns" entry is taken out
    of `model_params` and the model is built as a `WrappedKerasModel`.
    Otherwise the model is built with `init_model(estimator, model_params)`.

    NOTE: any estimator that declares a "feature_columns" parameter takes
    the `init_model` path, so this function can be used to initialize such
    estimators as well.

    Args:
        estimator: the model class or function to instantiate.
        model_params: dict of keyword arguments for the model. It must
            contain "feature_columns" when wrapping is needed; in that case
            the key is removed from this dict in place.
        has_none_optimizer: when true, the model is never wrapped and is
            always built through `init_model`.

    Returns:
        The object returned by `WrappedKerasModel(...)` or `init_model(...)`.

    Raises:
        KeyError: if wrapping is needed but `model_params` has no
            "feature_columns" entry.
    """
    # inspect.signature replaces inspect.getargspec, which is deprecated,
    # rejects callables with keyword-only or annotated parameters, and no
    # longer exists in Python 3.11+.
    accepted_params = inspect.signature(estimator).parameters
    needs_wrapper = ("feature_columns" not in accepted_params
                     and not has_none_optimizer)
    if not needs_wrapper:
        return init_model(estimator, model_params)

    # The wrapper receives the feature columns separately, so they must not
    # stay among the keyword arguments passed on with model_params.
    feature_columns = model_params.pop("feature_columns")
    return WrappedKerasModel(estimator, model_params, feature_columns)

python/sqlflow_submitter/tensorflow/predict.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
from .get_tf_version import tf_is_version2
3434
from .input_fn import get_dtype, parse_sparse_feature_predict, tf_generator
35-
from .keras_with_feature_column_input import WrappedKerasModel
35+
from .keras_with_feature_column_input import init_model_with_feature_column
3636

3737
try:
3838
import sqlflow_models
@@ -57,19 +57,8 @@ def keras_predict(estimator, model_params, save, result_table, is_pai,
5757
pai_table, feature_column_names, feature_metas,
5858
result_col_name, datasource, select, hdfs_namenode_addr,
5959
hive_location, hdfs_user, hdfs_pass):
60-
signature = inspect.signature(estimator)
61-
has_feature_columns_arg = False
62-
for p in signature.parameters:
63-
if signature.parameters[p].name == "feature_columns":
64-
has_feature_columns_arg = True
65-
break
66-
if not has_feature_columns_arg:
67-
feature_columns = model_params["feature_columns"]
68-
del model_params["feature_columns"]
69-
classifier = WrappedKerasModel(estimator, model_params,
70-
feature_columns)
71-
else:
72-
classifier = estimator(**model_params)
60+
61+
classifier = init_model_with_feature_column(estimator, model_params)
7362
classifier_pkg = sys.modules[estimator.__module__]
7463
conn = None
7564
if is_pai:

python/sqlflow_submitter/tensorflow/train_estimator.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from ..model_metadata import save_model_metadata
1919
from . import metrics
20-
from .diag import check_and_load_estimator
20+
from .diag import init_model, load_pretrained_model_estimator
2121
from .get_tf_version import tf_is_version2
2222
from .input_fn import input_fn
2323
from .pai_distributed import make_estimator_distributed_runconfig
@@ -46,8 +46,9 @@ def estimator_train_and_save(estimator, model_params, save, is_pai, FLAGS,
4646
model_params["model_dir"] = save
4747

4848
warm_start_from = save if load_pretrained_model else None
49-
classifier = check_and_load_estimator(estimator, model_params,
50-
warm_start_from)
49+
if warm_start_from:
50+
load_pretrained_model_estimator(estimator, model_params)
51+
classifier = init_model(estimator, model_params)
5152

5253
# do not add default Accuracy metric when using estimator to train, it will fail
5354
# when the estimator is a regressor, and estimator seems automatically add some

python/sqlflow_submitter/tensorflow/train_keras.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,9 @@
2323

2424
from ..model_metadata import save_model_metadata
2525
from . import metrics
26-
from .diag import check_and_load_estimator
2726
from .get_tf_version import tf_is_version2
2827
from .input_fn import input_fn
29-
from .keras_with_feature_column_input import WrappedKerasModel
28+
from .keras_with_feature_column_input import init_model_with_feature_column
3029
from .pai_distributed import (dump_into_tf_config,
3130
make_distributed_info_without_evaluator)
3231
from .train_estimator import estimator_train_compiled
@@ -47,13 +46,6 @@ def keras_train_and_save(estimator, model_params, save, is_pai, FLAGS,
4746
loss = model_params["loss"]
4847
del model_params["loss"]
4948

50-
signature = inspect.signature(estimator)
51-
has_feature_columns_arg = False
52-
for p in signature.parameters:
53-
if signature.parameters[p].name == "feature_columns":
54-
has_feature_columns_arg = True
55-
break
56-
5749
classifier_pkg = sys.modules[estimator.__module__]
5850
# setting training metrics
5951
model_metrics = []
@@ -101,13 +93,8 @@ def keras_train_and_save(estimator, model_params, save, is_pai, FLAGS,
10193
else:
10294
validate_dataset = None
10395

104-
if not has_feature_columns_arg and not has_none_optimizer:
105-
feature_columns = model_params["feature_columns"]
106-
del model_params["feature_columns"]
107-
classifier = WrappedKerasModel(estimator, model_params,
108-
feature_columns)
109-
else:
110-
classifier = check_and_load_estimator(estimator, model_params)
96+
classifier = init_model_with_feature_column(
97+
estimator, model_params, has_none_optimizer=has_none_optimizer)
11198

11299
# FIXME(sneaxiy): some models defined by other framework (not TensorFlow or XGBoost)
113100
# may return None optimizer.

0 commit comments

Comments
 (0)