Skip to content

Commit 6012501

Browse files
authored
Save model meta into 'model_meta.json' (#2476)
* merge files * Design doc for model metadata storage * refine wording * add field description to metadata * add metadata definition * Save model metadata into model_meta.json * sort import
1 parent c8b4432 commit 6012501

12 files changed

Lines changed: 160 additions & 40 deletions

File tree

pkg/codegen/tensorflow/codegen.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@ import (
1818
"encoding/json"
1919
"fmt"
2020
"os"
21-
"sqlflow.org/sqlflow/pkg/codegen"
2221
"strings"
2322
"text/template"
2423

24+
"sqlflow.org/sqlflow/pkg/codegen"
25+
2526
"sqlflow.org/sqlflow/pkg/attribute"
2627
"sqlflow.org/sqlflow/pkg/ir"
2728
pb "sqlflow.org/sqlflow/pkg/proto"
@@ -333,6 +334,7 @@ func Train(trainStmt *ir.TrainStmt, session *pb.Session) (string, error) {
333334
IsPAI: IsPAI(),
334335
PAITrainTable: paiTrainTable,
335336
PAIValidateTable: paiValidateTable,
337+
ModelRepoImage: trainStmt.ModelImage,
336338
}
337339
var program bytes.Buffer
338340
var trainTemplate = template.Must(template.New("Train").Funcs(template.FuncMap{

pkg/codegen/tensorflow/codegen_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,16 @@ func TestTrainCodegen(t *testing.T) {
5353
a.Equal(r.FindStringSubmatch(code)[1], "sqlflow_pass")
5454
}
5555

56+
func TestTrainWithModelRepoImage(t *testing.T) {
57+
a := assert.New(t)
58+
tir := ir.MockTrainStmt(false)
59+
tir.ModelImage = "myRepo/MyDNNClassifier:v1.0"
60+
code, err := Train(tir, mockSession())
61+
a.NoError(err)
62+
r, _ := regexp.Compile(`model_repo_image="(.*)"`)
63+
a.Equal(r.FindStringSubmatch(code)[1], tir.ModelImage)
64+
}
65+
5666
func TestTrainWithOptimizer(t *testing.T) {
5767
a := assert.New(t)
5868
tir := ir.MockTrainStmt(false)

pkg/codegen/tensorflow/template_train.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ type trainFiller struct {
3232
IsPAI bool
3333
PAITrainTable string
3434
PAIValidateTable string
35+
ModelRepoImage string
3536
}
3637

3738
const tfTrainTemplateText = `
@@ -138,5 +139,7 @@ train(datasource="{{.DataSource}}",
138139
load_pretrained_model="{{.LoadPreTrainedModel}}" == "true",
139140
is_pai="{{.IsPAI}}" == "true",
140141
pai_table="{{.PAITrainTable}}",
141-
pai_val_table="{{.PAIValidateTable}}")
142+
pai_val_table="{{.PAIValidateTable}}",
143+
feature_columns_code=feature_columns_code,
144+
model_repo_image="{{.ModelRepoImage}}")
142145
`

pkg/codegen/xgboost/codegen.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ import (
1818
"encoding/json"
1919
"fmt"
2020
"regexp"
21-
"sqlflow.org/sqlflow/pkg/codegen"
2221
"strings"
2322

23+
"sqlflow.org/sqlflow/pkg/codegen"
24+
2425
"sqlflow.org/sqlflow/pkg/attribute"
2526
tf "sqlflow.org/sqlflow/pkg/codegen/tensorflow"
2627
"sqlflow.org/sqlflow/pkg/ir"
@@ -336,7 +337,9 @@ func newTrainFiller(trainStmt *ir.TrainStmt, session *pb.Session, ossURIToSave,
336337
LoadPreTrainedModel: trainStmt.PreTrainedModel != "",
337338
IsPAI: tf.IsPAI(),
338339
PAITrainTable: paiTrainTable,
339-
PAIValidateTable: paiValidateTable}, nil
340+
PAIValidateTable: paiValidateTable,
341+
ModelRepoImage: trainStmt.ModelImage,
342+
}, nil
340343
}
341344

342345
// Pred generates a Python program for predict a xgboost model.

pkg/codegen/xgboost/codegen_test.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,21 @@ func TestResolveModelParams(t *testing.T) {
8282
a.Equal(objectiveName[i], tir.Attributes["objective"])
8383
}
8484
}
85+
86+
func TestTrainWithModelRepoImage(t *testing.T) {
87+
a := assert.New(t)
88+
tir := ir.MockTrainStmt(true)
89+
a.NoError(InitializeAttributes(tir))
90+
tir.ModelImage = "myRepo/MyXGBClassifier:v1.0"
91+
code, err := Train(tir, mockSession())
92+
a.NoError(err)
93+
r, _ := regexp.Compile(`model_repo_image="(.*)"`)
94+
a.Equal(r.FindStringSubmatch(code)[1], tir.ModelImage)
95+
96+
// dist train
97+
code, err = DistTrain(tir, mockSession(), 2, "", "")
98+
a.NoError(err)
99+
r, _ = regexp.Compile(`model_repo_image="(.*)"`)
100+
a.Equal(r.FindStringSubmatch(code)[1], tir.ModelImage)
101+
102+
}

pkg/codegen/xgboost/template_train.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ type trainFiller struct {
3434
IsPAI bool
3535
PAITrainTable string
3636
PAIValidateTable string
37+
ModelRepoImage string
3738
}
3839

3940
const trainTemplateText = `
@@ -81,7 +82,8 @@ train(datasource='''{{.DataSource}}''',
8182
pai_validate_table="{{.PAIValidateTable}}",
8283
oss_model_dir="{{.OSSModelDirToSave}}",
8384
transform_fn=transform_fn,
84-
feature_column_code='''{{.FeatureColumnCode}}''')
85+
feature_column_code='''{{.FeatureColumnCode}}''',
86+
model_repo_image="{{.ModelRepoImage}}")
8587
`
8688

8789
const distTrainTemplateText = `
@@ -128,7 +130,8 @@ dist_train(flags=FLAGS,
128130
pai_validate_table="{{.PAIValidateTable}}",
129131
oss_model_dir="{{.OSSModelDirToSave}}",
130132
transform_fn=transform_fn,
131-
feature_column_code='''{{.FeatureColumnCode}}''')
133+
feature_column_code='''{{.FeatureColumnCode}}''',
134+
model_repo_image="{{.ModelRepoImage}}")
132135
`
133136

134137
var trainTemplate = template.Must(template.New("Train").Parse(trainTemplateText))
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License
13+
14+
import copy
15+
import json
16+
17+
18+
def collect_model_metadata(select, validate_select, estimator, attributes,
                           feature_columns, field_descs, label, evaluation,
                           model_repo_image):
    """Collect the pieces of model metadata into a single JSON-safe dict.

    Args:
        select: SQL SELECT statement used to fetch the training data.
        validate_select: SQL SELECT statement used for the validation data.
        estimator: the estimator (model) identifier string.
        attributes: dict of training attributes; values that json cannot
            encode are replaced by their str() form in the result.
        feature_columns: the feature column definition (code or objects).
        field_descs: descriptions of the input data fields.
        label: the label field description.
        evaluation: evaluation result, or None if not available yet.
        model_repo_image: the model repository Docker image name.

    Returns:
        A dict keyed by the parameter names above, with 'attributes'
        replaced by a JSON-serializable deep copy.
    """
    # Build the dict explicitly instead of dict(locals()) so that adding a
    # temporary variable later cannot silently leak into the metadata.
    metadata = {
        "select": select,
        "validate_select": validate_select,
        "estimator": estimator,
        "attributes": attributes,
        "feature_columns": feature_columns,
        "field_descs": field_descs,
        "label": label,
        "evaluation": evaluation,
        "model_repo_image": model_repo_image,
    }
    attr_copy = copy.deepcopy(attributes)
    for k, v in attr_copy.items():
        try:
            json.dumps(v)
        # Catch only the errors json raises for unencodable values; the
        # original bare `except:` would also swallow KeyboardInterrupt etc.
        except (TypeError, ValueError):
            attr_copy[k] = str(v)
    metadata["attributes"] = attr_copy
    return metadata
31+
32+
33+
def save_model_metadata(path, metadata):
    """save_model_metadata serializes `metadata` as pretty-printed JSON
    (two-space indent) into the file at `path`, overwriting any content."""
    with open(path, mode="w") as meta_file:
        json.dump(metadata, meta_file, indent=2)
37+
38+
39+
def load_model_metadata(path):
    """load_model_metadata loads model metadata from the JSON file at
    `path` and returns it as a dict.

    Raises:
        OSError: if the file cannot be opened/read.
        ValueError: if the file content is not valid JSON.
    """
    with open(path, mode="r") as meta_file:
        # Bug fix: the original passed readlines()' *list* to json.loads,
        # which raised TypeError for any input. json.load reads the stream
        # directly and parses it as a single JSON document.
        return json.load(meta_file)

python/sqlflow_submitter/tensorflow/train.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
DNNLinearCombinedRegressor, DNNRegressor,
3232
LinearClassifier, LinearRegressor)
3333

34+
from ..model_metadata import collect_model_metadata
3435
from .get_tf_version import tf_is_version2
3536
from .input_fn import get_dataset_fn
3637
from .pai_distributed import define_tf_flags, set_oss_environs
@@ -70,7 +71,13 @@ def train(datasource,
7071
load_pretrained_model=False,
7172
is_pai=False,
7273
pai_table="",
73-
pai_val_table=""):
74+
pai_val_table="",
75+
feature_columns_code="",
76+
model_repo_image=""):
77+
model_meta = collect_model_metadata(select, validation_select,
78+
estimator_string, model_params,
79+
feature_columns_code, feature_metas,
80+
label_meta, None, model_repo_image)
7481
# import custom model package
7582
sqlflow_submitter.import_model_def(estimator_string, globals())
7683
estimator = eval(estimator_string)
@@ -123,13 +130,16 @@ def train(datasource,
123130
keras_train_and_save(estimator, model_params, save, is_pai, FLAGS,
124131
train_dataset_fn, val_dataset_fn, label_meta,
125132
epoch, verbose, validation_metrics,
126-
validation_steps, load_pretrained_model)
133+
validation_steps, load_pretrained_model,
134+
model_meta)
127135
else:
128-
estimator_train_and_save(
129-
estimator, model_params, save, is_pai, FLAGS, train_dataset_fn,
130-
val_dataset_fn, log_every_n_iter, max_steps,
131-
validation_start_delay_secs, validation_throttle_secs,
132-
save_checkpoints_steps, validation_metrics, load_pretrained_model)
136+
estimator_train_and_save(estimator, model_params, save, is_pai, FLAGS,
137+
train_dataset_fn, val_dataset_fn,
138+
log_every_n_iter, max_steps,
139+
validation_start_delay_secs,
140+
validation_throttle_secs,
141+
save_checkpoints_steps, validation_metrics,
142+
load_pretrained_model, model_meta)
133143

134144
# remove cache files
135145
any(map(os.remove, glob.glob('cache_train.*')))

python/sqlflow_submitter/tensorflow/train_estimator.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14+
from os import path
15+
1416
import tensorflow as tf
1517

18+
from ..model_metadata import save_model_metadata
1619
from . import metrics
1720
from .diag import check_and_load_estimator
1821
from .get_tf_version import tf_is_version2
@@ -25,7 +28,7 @@ def estimator_train_and_save(estimator, model_params, save, is_pai, FLAGS,
2528
log_every_n_iter, train_max_steps,
2629
eval_start_delay_secs, eval_throttle_secs,
2730
save_checkpoints_steps, metric_names,
28-
load_pretrained_model):
31+
load_pretrained_model, model_meta):
2932
print("Start training using estimator model...")
3033

3134
is_distributed = False
@@ -73,9 +76,13 @@ def estimator_train_and_save(estimator, model_params, save, is_pai, FLAGS,
7376
tf.feature_column.make_parse_example_spec(all_feature_columns))
7477
export_path = classifier.export_saved_model(save, serving_input_fn)
7578
# write the path under current directory
79+
export_path_str = str(export_path.decode("utf-8"))
7680
with open("exported_path", "w") as fn:
77-
fn.write(str(export_path.decode("utf-8")))
78-
print("Done training, model exported to: %s" % export_path)
81+
fn.write(export_path_str)
82+
# write model metadata to model_meta.json
83+
save_model_metadata(path.join(export_path_str, "model_meta.json"),
84+
model_meta)
85+
print("Done training, model exported to: %s" % export_path_str)
7986

8087

8188
def estimator_train_compiled(estimator, is_pai, FLAGS, train_dataset_fn,

python/sqlflow_submitter/tensorflow/train_keras.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414
import inspect
1515
import sys
1616
import warnings
17+
from os import path
1718

1819
import six
1920
import tensorflow as tf
2021
from sqlflow_submitter.pai import model
2122
from sqlflow_submitter.seeding import get_tf_random_seed
2223

24+
from ..model_metadata import save_model_metadata
2325
from . import metrics
2426
from .diag import check_and_load_estimator
2527
from .get_tf_version import tf_is_version2
@@ -33,7 +35,7 @@
3335
def keras_train_and_save(estimator, model_params, save, is_pai, FLAGS,
3436
train_dataset_fn, val_dataset_fn, label_meta, epochs,
3537
verbose, metric_names, validation_steps,
36-
load_pretrained_model):
38+
load_pretrained_model, model_meta):
3739
print("Start training using keras model...")
3840
# remove optimizer param from model_params and use it when call "compile()"
3941
optimizer = None
@@ -171,10 +173,15 @@ def keras_train_and_save(estimator, model_params, save, is_pai, FLAGS,
171173
tf.feature_column.make_parse_example_spec(all_feature_columns))
172174
export_path = keras_estimator.export_saved_model(
173175
save, serving_input_fn)
176+
174177
# write the path under current directory
178+
export_path_str = str(export_path.decode("utf-8"))
175179
with open("exported_path", "w") as fn:
176-
fn.write(str(export_path.decode("utf-8")))
177-
print("Done training, model exported to: %s" % export_path)
180+
fn.write(export_path_str)
181+
# write model metadata to model_meta.json
182+
save_model_metadata(path.join(export_path_str, "model_meta.json"),
183+
model_meta)
184+
print("Done training, model exported to: %s" % export_path_str)
178185
return
179186

180187
if hasattr(classifier, 'sqlflow_train_loop'):
@@ -200,25 +207,29 @@ def keras_train_and_save(estimator, model_params, save, is_pai, FLAGS,
200207
epochs=epochs if epochs else
201208
classifier.default_training_epochs(),
202209
verbose=verbose)
203-
train_keys = []
204-
val_keys = []
210+
train_metrics = dict()
211+
val_metrics = dict()
205212
for k in history.history.keys():
206213
if k.startswith("val_"):
207-
val_keys.append(k)
214+
val_metrics[k] = float(history.history[k][-1])
208215
else:
209-
train_keys.append(k)
216+
train_metrics[k] = float(history.history[k][-1])
210217
print("====== Result for training set: ======")
211-
for k in train_keys:
212-
print("%s: %s" % (k, history.history[k][-1]))
218+
for k, v in train_metrics.items():
219+
print("%s: %s" % (k, v))
213220
print("====== Result for validation set: ======")
214-
for k in val_keys:
215-
print("%s: %s" % (k, history.history[k][-1]))
221+
for k, v in val_metrics.items():
222+
print("%s: %s" % (k, v))
223+
model_meta["evaluation"] = val_metrics
216224

217225
try:
218226
classifier.save_weights(save, save_format="h5")
227+
# write model metadata to model_meta.json
228+
save_model_metadata("model_meta.json", model_meta)
219229
if is_pai:
220230
print("saving keras model to: %s" % FLAGS.sqlflow_oss_modeldir)
221231
model.save_file(FLAGS.sqlflow_oss_modeldir, save)
232+
model.save_file(FLAGS.sqlflow_oss_modeldir, "model_meta.json")
222233
except:
223234
if has_none_optimizer:
224235
warnings.warn("Saving model with None optimizer fails")

0 commit comments

Comments
 (0)