Skip to content

Commit 5c412b4

Browse files
authored
refine feature and model python code by flake8 (#2761)
1 parent b539fdd commit 5c412b4

8 files changed

Lines changed: 60 additions & 47 deletions

File tree

python/runtime/feature/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14-
from runtime.feature.derivation import infer_feature_columns
14+
from runtime.feature.derivation import infer_feature_columns # noqa: F401

python/runtime/feature/column.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -179,12 +179,13 @@ class CrossColumn(CategoryColumn):
179179
CrossColumn represents a crossed feature column.
180180
181181
Args:
182-
keys (str|NumericColumn): the underlying feature column name or NumericColumn object.
182+
keys (str|NumericColumn): the underlying feature column name or
183+
NumericColumn object.
183184
hash_bucket_size (int): the bucket size for hashing.
184185
"""
185186
def __init__(self, keys, hash_bucket_size):
186187
for k in keys:
187-
assert isinstance(k, six.string_types) or isinstance(k, NumericColumn), \
188+
assert isinstance(k, (six.string_types, NumericColumn)), \
188189
"keys of CROSS must be of either string or numeric type"
189190

190191
self.keys = keys
@@ -217,12 +218,12 @@ class EmbeddingColumn(FeatureColumn):
217218
Args:
218219
category_column (CategoryColumn): the underlying CategoryColumn object.
219220
dimension (int): the dimension of the embedding.
220-
combiner (str): how to reduce if there are multiple entries in a single row.
221-
Currently 'mean', 'sqrtn' and 'sum' are supported.
221+
combiner (str): how to reduce if there are multiple entries in a single
222+
row. Currently 'mean', 'sqrtn' and 'sum' are supported.
222223
initializer (str): the initializer of the embedding table.
223224
name (str): only used when category_column=None. In this case, the
224-
category_column would be filled automaticaly in the feature derivation
225-
stage.
225+
category_column would be filled automatically in the feature
226+
derivation stage.
226227
"""
227228
def __init__(self,
228229
category_column=None,
@@ -266,8 +267,8 @@ class IndicatorColumn(FeatureColumn):
266267
Args:
267268
category_column (CategoryColumn): the underlying CategoryColumn object.
268269
name (str): only used when category_column=None. In this case, the
269-
category_column would be filled automaticaly in the feature derivation
270-
stage.
270+
category_column would be filled automatically in the feature
271+
derivation stage.
271272
"""
272273
def __init__(self, category_column=None, name=""):
273274
if category_column is not None:

python/runtime/feature/derivation.py

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
import numpy as np
2121
import six
2222
from runtime.feature.column import (CategoryIDColumn, EmbeddingColumn,
23-
FeatureColumn, IndicatorColumn,
24-
NumericColumn)
23+
IndicatorColumn, NumericColumn)
2524
from runtime.feature.field_desc import DataFormat, DataType, FieldDesc
2625
from runtime.verifier import fetch_samples
2726

@@ -38,8 +37,8 @@ def init_column_map(target_fc_map, fc):
3837
Returns:
3938
None.
4039
"""
41-
if isinstance(fc, (EmbeddingColumn, IndicatorColumn)) and \
42-
len(fc.get_field_desc()) == 0:
40+
if isinstance(fc, (EmbeddingColumn, IndicatorColumn)) \
41+
and len(fc.get_field_desc()) == 0:
4342
if fc.name not in target_fc_map:
4443
target_fc_map[fc.name] = []
4544

@@ -129,7 +128,7 @@ def new_default_field_desc(name):
129128
BLANK_PATTERN = re.compile("\\s+")
130129

131130
# The Python 2/3 int64 type
132-
INT64_TYPE = long if six.PY2 else int
131+
INT64_TYPE = long if six.PY2 else int # noqa: F821
133132

134133

135134
def infer_string_data_format(str_data):
@@ -165,7 +164,8 @@ def fill_csv_field_desc(cell, field_desc):
165164
"""
166165
values = cell.split(",")
167166
if field_desc.is_sparse:
168-
assert field_desc.shape is not None, "the shape of CSV format data must be given"
167+
assert field_desc.shape is not None, \
168+
"the shape of CSV format data must be given"
169169
else:
170170
if field_desc.shape is None:
171171
field_desc.shape = [len(values)]
@@ -174,8 +174,8 @@ def fill_csv_field_desc(cell, field_desc):
174174
if np.prod(field_desc.shape) != len(values):
175175
if size > 1:
176176
raise ValueError(
177-
"column %s should be csv format dense tensor of %d element(s), but got %d element(s)"
178-
%
177+
"column %s should be csv format dense tensor "
178+
"of %d element(s), but got %d element(s)" %
179179
(field_desc.name, np.prod(field_desc.shape), len(values)))
180180

181181
field_desc.shape = [len(values)]
@@ -356,10 +356,13 @@ def update_feature_column(fc, fd_map):
356356
raise ValueError("column not found or inferred: %s" % fc.name)
357357

358358
# FIXME(typhoonzero): when to use sequence_category_id_column?
359-
# if column fieldDesc is SPARSE, the sparse shape should be in cs.Shape[0]
359+
# if column fieldDesc is SPARSE, the sparse shape should
360+
# be in cs.Shape[0]
360361
bucket_size = field_desc.shape[0]
361362
if not field_desc.is_sparse:
362-
assert field_desc.max_id > 0, "use dense column on embedding column but did not got a correct MaxID"
363+
assert field_desc.max_id > 0, \
364+
"use dense column on embedding column " \
365+
"but did not got a correct MaxID"
363366
bucket_size = field_desc.max_id + 1
364367

365368
fc.category_column = CategoryIDColumn(field_desc, bucket_size)
@@ -370,8 +373,10 @@ def update_feature_column(fc, fd_map):
370373
if field_desc is None:
371374
raise ValueError("column not found or inferred: %s" % fc.name)
372375

373-
assert field_desc.is_sparse, "cannot use sparse column with indicator column"
374-
assert field_desc.max_id > 0, "use indicator column but did not got a correct MaxID"
376+
assert field_desc.is_sparse, \
377+
"cannot use sparse column with indicator column"
378+
assert field_desc.max_id > 0, \
379+
"use indicator column but did not got a correct MaxID"
375380
bucket_size = field_desc.max_id + 1
376381
fc.category_column = CategoryIDColumn(field_desc, bucket_size)
377382

@@ -392,7 +397,8 @@ def new_feature_column(field_desc):
392397
else:
393398
category_column = CategoryIDColumn(field_desc,
394399
len(field_desc.vocabulary))
395-
# NOTE(typhoonzero): a default embedding size of 128 is enough for most cases.
400+
# NOTE(typhoonzero): a default embedding size of 128 is enough
401+
# for most cases.
396402
embedding = EmbeddingColumn(category_column=category_column,
397403
dimension=128,
398404
combiner="sum")
@@ -406,7 +412,8 @@ def derive_feature_columns(targets, fc_map, fd_map, selected_field_names,
406412
Derive the FeatureColumn.
407413
408414
Args:
409-
targets (list[str]): the feature column targets, e.g. "feature_columns".
415+
targets (list[str]): the feature column targets,
416+
e.g. "feature_columns".
410417
fc_map (dict[str -> dict[str -> list[FeatureColumn]]]): a FeatureColumn
411418
map, where the key of the outer dict is the target name, e.g.
412419
"feature_columns", and the key of the inner dict is the field name.
@@ -439,7 +446,8 @@ def derive_feature_columns(targets, fc_map, fd_map, selected_field_names,
439446
match_field_name = None
440447
for selected_field_name in selected_field_names:
441448
if field_pattern.fullmatch(selected_field_name):
442-
assert match_field_name is None, "%s matches duplicate fields" % field_name
449+
assert match_field_name is None, \
450+
"%s matches duplicate fields" % field_name
443451
match_field_name = selected_field_name
444452

445453
if match_field_name is None:
@@ -464,8 +472,8 @@ def derive_feature_columns(targets, fc_map, fd_map, selected_field_names,
464472
update_feature_column(fc, fd_map)
465473
else:
466474
if len(fc_map) > 1:
467-
# if column clause have more than one target, each target should specify the
468-
# full list of the columns to use.
475+
# if column clause has more than one target, each target
476+
# should specify the full list of the columns to use.
469477
continue
470478

471479
field_desc = fd_map[selected_field_name]
@@ -479,13 +487,13 @@ def derive_feature_columns(targets, fc_map, fd_map, selected_field_names,
479487
fc_target_map.update(new_fc_target_map)
480488

481489

482-
def update_ir_feature_column_map_by_derived_feature_column_map(
483-
features, fc_map, selected_field_names, label_name):
490+
def update_ir_feature_columns(features, fc_map, selected_field_names,
491+
label_name):
484492
"""
485493
Update the IR FeatureColumn map `features` by the derived FeatureColumn map
486-
`fc_map` . If any FeatureColumn inside `fc_map` does not exist in `features`,
487-
it would be added to `features` . Notice that `features` is not updated
488-
in-place, and we would return a new updated IR FeatureColumn map in
494+
`fc_map` . If any FeatureColumn inside `fc_map` does not exist in
495+
`features`, it would be added to `features` . Notice that `features` is not
496+
updated in-place, and we would return a new updated IR FeatureColumn map in
489497
this method.
490498
491499
Args:
@@ -542,9 +550,8 @@ def update_ir_feature_column_map_by_derived_feature_column_map(
542550
break
543551

544552
if not found:
545-
raise ValueError(
546-
"some feature column is missing in the derivation stage"
547-
)
553+
raise ValueError("some feature column is missing in the "
554+
"derivation stage")
548555

549556
sorted_pos = sorted(range(len(indices)), key=lambda k: indices[k])
550557
multi_fd_fcs = [multi_fd_fcs[i] for i in sorted_pos]
@@ -572,9 +579,11 @@ def derive_label(label, fd_map):
572579
return # NOTE: clustering model may not specify Label
573580

574581
label_field_desc = fd_map[label_name]
575-
assert label_field_desc is not None, "deriveLabel: LABEL COLUMN '%s' not found" % label_name
582+
assert label_field_desc is not None, \
583+
"deriveLabel: LABEL COLUMN '%s' not found" % label_name
576584

577-
# use shape [] if label shape is [1] for Tensorflow scalar label shape should be [].
585+
# use shape [] if label shape is [1] for Tensorflow scalar label
586+
# shape should be [].
578587
shape = label_field_desc.shape
579588
if shape is None or (len(shape) == 1 and shape[0] == 1):
580589
label_field_desc.shape = []
@@ -626,7 +635,7 @@ def infer_feature_columns(conn, select, features, label, n=1000):
626635

627636
derive_feature_columns(targets, fc_map, fd_map, selected_field_names,
628637
label_name)
629-
features = update_ir_feature_column_map_by_derived_feature_column_map(
630-
features, fc_map, selected_field_names, label_name)
638+
features = update_ir_feature_columns(features, fc_map,
639+
selected_field_names, label_name)
631640
label = derive_label(label, fd_map)
632641
return features, label

python/runtime/feature/derivation_test.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616
import runtime.feature.derivation as fd
1717
import runtime.testing as testing
1818
from runtime.feature.column import (CategoryIDColumn, CrossColumn,
19-
EmbeddingColumn, IndicatorColumn,
20-
NumericColumn)
19+
EmbeddingColumn, NumericColumn)
2120
from runtime.feature.field_desc import DataFormat, DataType, FieldDesc
2221

2322

@@ -103,7 +102,8 @@ def test_without_cross(self):
103102
label = NumericColumn(
104103
FieldDesc(name="class", dtype=DataType.INT, shape=[1]))
105104

106-
select = "select c1, c2, c3, c4, c5, c6, class from feature_derivation_case.train"
105+
select = "select c1, c2, c3, c4, c5, c6, class " \
106+
"from feature_derivation_case.train"
107107
conn = testing.get_singleton_db_connection()
108108
features, label = fd.infer_feature_columns(conn, select, features,
109109
label)
@@ -218,7 +218,8 @@ def test_with_cross(self):
218218

219219
label = NumericColumn(
220220
FieldDesc(name='class', dtype=DataType.INT, shape=[1]))
221-
select = "select c1, c2, c3, c4, c5, class from feature_derivation_case.train"
221+
select = "select c1, c2, c3, c4, c5, class " \
222+
"from feature_derivation_case.train"
222223

223224
conn = testing.get_singleton_db_connection()
224225
features, label = fd.infer_feature_columns(conn, select, features,

python/runtime/feature/field_desc.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,10 @@ class FieldDesc(object):
5252
PLAIN, CSV, KV. Default PLAIN.
5353
shape (list[int]): the shape of the field data. Default None.
5454
is_sparse (bool): whether the field data is sparse. Default False.
55-
vocabulary (list[str]): the vocabulary used for categorical feature column. Default None.
56-
max_id (int): the maximum id number of the field data. Used in CategoryIDColumn. Default 0.
55+
vocabulary (list[str]): the vocabulary used for categorical
56+
feature column. Default None.
57+
max_id (int): the maximum id number of the field data. Used in
58+
CategoryIDColumn. Default 0.
5759
"""
5860
def __init__(self,
5961
name="",

python/runtime/model/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14-
from runtime.model.model import EstimatorType, Model, load
14+
from runtime.model.model import EstimatorType, Model, load # noqa: F401

python/runtime/model/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ class EstimatorType(Enum):
3636
# To stay compatible with old models, we start at 0
3737
TENSORFLOW = 0
3838
XGBOOST = 1
39-
# PAIML is the model type that trained by PAI machine learning algorithm toolkit
39+
# PAIML is the model type that is trained by the PAI machine learning algorithm
40+
# toolkit
4041
PAIML = 2
4142

4243

python/runtime/model/tar_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# limitations under the License.
1313

1414
import os
15-
import shutil
1615
import tempfile
1716
import unittest
1817

0 commit comments

Comments
 (0)