Skip to content

Commit 67b5833

Browse files
authored
refine codes in runtime (#2764)
1 parent c7228e1 commit 67b5833

8 files changed

Lines changed: 88 additions & 70 deletions

File tree

python/runtime/db.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@
2121

2222
def parseMySQLDSN(dsn):
2323
# [username[:password]@][protocol[(address)]]/dbname[?param1=value1&...&paramN=valueN]
24-
user, passwd, host, port, database, config_str = re.findall(
25-
"^(\w*):(\w*)@tcp\(([.a-zA-Z0-9\-]*):([0-9]*)\)/(\w*)(\?.*)?$", dsn)[0]
24+
pattern = "^(\w*):(\w*)@tcp\(([.a-zA-Z0-9\-]*):([0-9]*)\)/(\w*)(\?.*)?$" # noqa: W605, E501
25+
found_result = re.findall(pattern, dsn)
26+
user, passwd, host, port, database, config_str = found_result[0]
2627
config = {}
2728
if len(config_str) > 1:
2829
for c in config_str[1:].split("&"):
@@ -34,7 +35,7 @@ def parseMySQLDSN(dsn):
3435
def parseHiveDSN(dsn):
3536
# usr:pswd@hiveserver:10000/mydb?auth=PLAIN&session.mapreduce_job_queuename=mr
3637
user_passwd, address_database, config_str = re.findall(
37-
"^(.*)@([.a-zA-Z0-9/:_]*)(\?.*)?", dsn)[0]
38+
"^(.*)@([.a-zA-Z0-9/:_]*)(\?.*)?", dsn)[0] # noqa: W605
3839
user, passwd = user_passwd.split(":")
3940
if len(address_database.split("/")) > 1:
4041
address, database = address_database.split("/")
@@ -60,7 +61,7 @@ def parseHiveDSN(dsn):
6061
def parseMaxComputeDSN(dsn):
6162
# access_id:access_key@service.com/api?curr_project=test_ci&scheme=http
6263
user_passwd, address, config_str = re.findall(
63-
"^(.*)@([-.a-zA-Z0-9/]*)(\?.*)?", dsn)[0]
64+
"^(.*)@([-.a-zA-Z0-9/]*)(\?.*)?", dsn)[0] # noqa: W605
6465
user, passwd = user_passwd.split(":")
6566
config = {}
6667
if len(config_str) > 1:
@@ -146,6 +147,9 @@ def connect(driver,
146147
return conn
147148

148149

150+
INT64_TYPE = long if six.PY2 else int # noqa: F821
151+
152+
149153
def read_feature(raw_val, feature_spec, feature_name):
150154
# FIXME(typhoonzero): Should use correct dtype here.
151155
if feature_spec["is_sparse"]:
@@ -181,7 +185,7 @@ def read_feature(raw_val, feature_spec, feature_name):
181185
elif feature_spec["dtype"] == "float32":
182186
return float(raw_val),
183187
elif feature_spec["dtype"] == "int64":
184-
int_raw_val = long(raw_val) if six.PY2 else int(raw_val)
188+
int_raw_val = INT64_TYPE(raw_val)
185189
return int_raw_val,
186190
elif feature_spec["dtype"] == "string":
187191
return str(raw_val),
@@ -202,8 +206,8 @@ def limit_select(select, n):
202206
n (int): the limited row number to query.
203207
204208
Returns:
205-
If n >= 0, return a new SQL statement which would query n row(s) at most.
206-
If n < 0, return the original SQL statement.
209+
If n >= 0, return a new SQL statement which would query n row(s)
210+
at most. If n < 0, return the original SQL statement.
207211
"""
208212
if n < 0:
209213
return select
@@ -224,7 +228,8 @@ def replace_limit_num(matched_limit):
224228

225229
try:
226230
import MySQLdb.constants.FIELD_TYPE as MYSQL_FIELD_TYPE
227-
# Refer to http://mysql-python.sourceforge.net/MySQLdb-1.2.2/public/MySQLdb.constants.FIELD_TYPE-module.html
231+
# Refer to
232+
# http://mysql-python.sourceforge.net/MySQLdb-1.2.2/public/MySQLdb.constants.FIELD_TYPE-module.html # noqa: E501
228233
MYSQL_FIELD_TYPE_DICT = {
229234
MYSQL_FIELD_TYPE.TINY: "TINYINT", # 1
230235
MYSQL_FIELD_TYPE.LONG: "INT", # 3
@@ -236,7 +241,7 @@ def replace_limit_num(matched_limit):
236241
MYSQL_FIELD_TYPE.VAR_STRING: "VARCHAR", # 253
237242
MYSQL_FIELD_TYPE.STRING: "CHAR", # 254
238243
}
239-
except:
244+
except: # noqa: E722
240245
MYSQL_FIELD_TYPE_DICT = {}
241246

242247

@@ -369,7 +374,8 @@ def reader():
369374
label_idx = reader.field_names.index(
370375
label_meta["feature_name"])
371376
except ValueError:
372-
# NOTE(typhoonzero): For clustering model, label_column_name may not in reader.field_names when predicting.
377+
# NOTE(typhoonzero): For clustering model, label_column_name
378+
# may not in reader.field_names when predicting.
373379
label_idx = None
374380
else:
375381
label_idx = None
@@ -378,12 +384,14 @@ def reader():
378384
rows = cursor.fetchmany(size=fetch_size)
379385
if not rows:
380386
break
381-
# NOTE: keep the connection while training or connection will lost if no activities appear.
387+
# NOTE: keep the connection while training or connection will lost
388+
# if no activities appear.
382389
if driver == "mysql":
383390
conn.ping(True)
384391
for row in rows:
385-
# NOTE: If there is no label clause in the extended SQL, the default label value would
386-
# be -1, the Model implementation can determine use it or not.
392+
# NOTE: If there is no label clause in the extended SQL, the
393+
# default label value would be -1, the Model implementation
394+
# can determine use it or not.
387395
label = row[label_idx] if label_idx is not None else -1
388396
if label_meta and label_meta["delimiter"] != "":
389397
if label_meta["dtype"] == "float32":
@@ -428,7 +436,7 @@ def reader():
428436
while True:
429437
try:
430438
row = pai_reader.read(num_records=1)[0]
431-
except:
439+
except: # noqa: E722
432440
pai_reader.close()
433441
break
434442

@@ -509,7 +517,7 @@ def execute(conn, sql_stmt):
509517
Args:
510518
conn: a database connection, this function will leave it open
511519
sql_stmt: the sql statement to execute
512-
520+
513521
Returns:
514522
True on success and False on failure
515523
"""
@@ -522,7 +530,7 @@ def execute(conn, sql_stmt):
522530
cur.execute(sql_stmt)
523531
conn.commit()
524532
return True
525-
except:
533+
except: # noqa: E722
526534
return False
527535
finally:
528536
cur.close()

python/runtime/db_test.py

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import numpy as np
1919
import runtime.testing as testing
20-
from odps import ODPS, tunnel
20+
from odps import tunnel
2121
from runtime.db import (MYSQL_FIELD_TYPE_DICT, buffered_db_writer, connect,
2222
connect_with_data_source, db_generator,
2323
get_table_schema, limit_select, parseHiveDSN,
@@ -53,7 +53,7 @@ def execute(driver, conn, statement):
5353
try:
5454
rows = cursor.fetchall()
5555
field_columns = list(map(list, zip(*rows))) if len(rows) > 0 else None
56-
except:
56+
except: # noqa: E722
5757
field_columns = None
5858

5959
return field_names, field_columns
@@ -62,7 +62,9 @@ def execute(driver, conn, statement):
6262
class TestDB(TestCase):
6363

6464
create_statement = "create table test_db (features text, label int)"
65-
hive_create_statement = 'create table test_db (features string, label int) ROW FORMAT DELIMITED FIELDS TERMINATED BY "\001"'
65+
hive_create_statement = 'create table test_db (features string, ' \
66+
'label int) ROW FORMAT DELIMITED FIELDS ' \
67+
'TERMINATED BY "\001"'
6668
select_statement = "select * from test_db"
6769
drop_statement = "drop table if exists test_db"
6870

@@ -114,7 +116,8 @@ def _do_test_hive_specified_db(self,
114116
hdfs_namenode_addr="",
115117
hive_location=""):
116118
create_db = '''create database if not exists test_db'''
117-
create_tbl = '''create table test_db.tbl (features string, label int) ROW FORMAT DELIMITED FIELDS TERMINATED BY "\001"'''
119+
create_tbl = '''create table test_db.tbl (features string, label int)
120+
ROW FORMAT DELIMITED FIELDS TERMINATED BY "\001"'''
118121
drop_tbl = '''drop table if exists test_db.tbl'''
119122
select_tbl = '''select * from test_db.tbl'''
120123
table_schema = ["label", "features"]
@@ -173,9 +176,11 @@ def _do_test(self, driver, conn, hdfs_namenode_addr="", hive_location=""):
173176

174177

175178
class TestGenerator(TestCase):
176-
create_statement = "create table test_table_float_fea (features float, label int)"
179+
create_statement = "create table test_table_float_fea " \
180+
"(features float, label int)"
177181
drop_statement = "drop table if exists test_table_float_fea"
178-
insert_statement = "insert into test_table_float_fea (features,label) values(1.0, 0), (2.0, 1)"
182+
insert_statement = "insert into test_table_float_fea (features,label)" \
183+
" values(1.0, 0), (2.0, 1)"
179184

180185
@unittest.skipUnless(testing.get_driver() == "mysql",
181186
"skip non mysql tests")
@@ -241,29 +246,26 @@ def test_parse_hive_dsn(self):
241246
("usr", "pswd", "hiveserver", "1000", "mydb", "PLAIN", {
242247
"mapreduce_job_quenename": "mr"
243248
}),
244-
parseHiveDSN(
245-
"usr:pswd@hiveserver:1000/mydb?auth=PLAIN&session.mapreduce_job_quenename=mr"
246-
))
249+
parseHiveDSN("usr:pswd@hiveserver:1000/mydb?auth=PLAIN&"
250+
"session.mapreduce_job_quenename=mr"))
247251
self.assertEqual(
248252
("usr", "pswd", "hiveserver", "1000", "my_db", "PLAIN", {
249253
"mapreduce_job_quenename": "mr"
250254
}),
251-
parseHiveDSN(
252-
"usr:pswd@hiveserver:1000/my_db?auth=PLAIN&session.mapreduce_job_quenename=mr"
253-
))
255+
parseHiveDSN("usr:pswd@hiveserver:1000/my_db?auth=PLAIN&"
256+
"session.mapreduce_job_quenename=mr"))
254257
self.assertEqual(
255258
("root", "root", "127.0.0.1", None, "mnist", "PLAIN", {}),
256259
parseHiveDSN("root:root@127.0.0.1/mnist?auth=PLAIN"))
257260
self.assertEqual(("root", "root", "127.0.0.1", None, None, "", {}),
258261
parseHiveDSN("root:root@127.0.0.1"))
259262

260263
def test_parse_maxcompute_dsn(self):
261-
self.assertEqual(
262-
("access_id", "access_key", "http://maxcompute-service.com/api",
263-
"test_ci"),
264-
parseMaxComputeDSN(
265-
"access_id:access_key@maxcompute-service.com/api?curr_project=test_ci&scheme=http"
266-
))
264+
self.assertEqual(("access_id", "access_key",
265+
"http://maxcompute-service.com/api", "test_ci"),
266+
parseMaxComputeDSN(
267+
"access_id:access_key@maxcompute-service.com/api?"
268+
"curr_project=test_ci&scheme=http"))
267269

268270
def test_kv_feature_column(self):
269271
feature_spec = {
@@ -300,8 +302,8 @@ def test_get_table_schema(self):
300302

301303
schema = selected_columns_and_types(
302304
conn,
303-
"SELECT sepal_length, petal_width * 2.3 new_petal_width, class FROM iris.train"
304-
)
305+
"SELECT sepal_length, petal_width * 2.3 new_petal_width, "
306+
"class FROM iris.train")
305307
expect = [
306308
("sepal_length", "FLOAT"),
307309
("new_petal_width", "DOUBLE"),
@@ -321,8 +323,8 @@ def test_get_table_schema(self):
321323

322324
schema = selected_columns_and_types(
323325
conn,
324-
"SELECT sepal_length, petal_width * 2.3 AS new_petal_width, class FROM iris.train"
325-
)
326+
"SELECT sepal_length, petal_width * 2.3 AS new_petal_width, "
327+
"class FROM iris.train")
326328
expect = [
327329
("sepal_length", "FLOAT"),
328330
("new_petal_width", "FLOAT"),
@@ -344,8 +346,8 @@ def test_get_table_schema(self):
344346

345347
schema = selected_columns_and_types(
346348
conn,
347-
"SELECT sepal_length, petal_width * 2.3 new_petal_width, class FROM %s"
348-
% table)
349+
"SELECT sepal_length, petal_width * 2.3 new_petal_width, "
350+
"class FROM %s" % table)
349351
expect = [
350352
("sepal_length", "DOUBLE"),
351353
("new_petal_width", "DOUBLE"),
@@ -367,7 +369,8 @@ def test_field_type(self):
367369

368370
table_name = "iris.test_mysql_field_type_table"
369371
drop_table_sql = "DROP TABLE IF EXISTS %s" % table_name
370-
create_table_sql = "CREATE TABLE IF NOT EXISTS " + table_name + "(a %s)"
372+
create_table_sql = "CREATE TABLE IF NOT EXISTS " + \
373+
table_name + "(a %s)"
371374
select_sql = "SELECT * FROM %s" % table_name
372375

373376
for int_type, str_type in MYSQL_FIELD_TYPE_DICT.items():

python/runtime/db_writer/pai_maxcompute.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,12 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14-
try:
15-
import paiio
16-
except:
17-
pass
18-
import tensorflow as tf
19-
from odps import ODPS, tunnel
20-
2114
from .base import BufferedDBWriter
2215

2316

2417
class PAIMaxComputeDBWriter(BufferedDBWriter):
2518
def __init__(self, table_name, table_schema, buff_size):
19+
import paiio
2620
super(PAIMaxComputeDBWriter, self).__init__(None, table_name,
2721
table_schema, buff_size)
2822
table_name_parts = table_name.split(".")

python/runtime/diagnostics.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ def load_pretrained_model_estimator(estimator,
2828
estimator_func = estimator.__init__ if inspect.isclass(
2929
estimator) else estimator
3030
estimator_spec = inspect.getargspec(estimator_func)
31-
# The constructor of Estimator contains named parameter "warm_start_from"
31+
# The constructor of Estimator contains named parameter
32+
# "warm_start_from"
3233
warm_start_from_key = "warm_start_from"
3334
if warm_start_from_key in estimator_spec.args:
3435
warm_start_from = os.path.abspath(warm_start_from)

python/runtime/explainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
import matplotlib.pyplot as plt
1919
from runtime.oss import copyfileobj
2020

21-
# TODO(shendiaomo): extract common code from tensorflow/explain.py and xgboost/explain.py
21+
# TODO(shendiaomo): extract common code from tensorflow/explain.py
22+
# and xgboost/explain.py
2223
# TODO(shendiaomo): add a unit test for this file later
2324

2425

python/runtime/maxcompute.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,17 +65,23 @@ def reader():
6565
label_idx = reader.field_names.index(
6666
label_meta["feature_name"])
6767
except ValueError:
68-
# NOTE(typhoonzero): For clustering model, label_column_name may not in reader.field_names when predicting.
68+
# NOTE(typhoonzero): For clustering model,
69+
# label_column_name may not in reader.field_names
70+
# when predicting.
6971
label_idx = None
7072
else:
7173
label_idx = None
7274

7375
i = 0
7476
while i < r.count:
75-
expected = r.count - i if r.count - i < fetch_size else fetch_size
77+
if r.count - i < fetch_size:
78+
expected = r.count - i
79+
else:
80+
expected = fetch_size
7681
for row in [[v[1] for v in rec] for rec in r[i:i + expected]]:
77-
# NOTE: If there is no label clause in the extended SQL, the default label value would
78-
# be -1, the Model implementation can determine use it or not.
82+
# NOTE: If there is no label clause in the extended SQL,
83+
# the default label value would be -1, the Model
84+
# implementation can determine use it or not.
7985
label = row[label_idx] if label_idx is not None else None
8086
if label_meta and label_meta["delimiter"] != "":
8187
if label_meta["dtype"] == "float32":

python/runtime/model_metadata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def collect_model_metadata(original_sql, select, validate_select, class_name,
2424
for (k, v) in attr_copy.items():
2525
try:
2626
json.dumps(v)
27-
except:
27+
except: # noqa: E722
2828
attr_copy[k] = str(v)
2929
metadata['attributes'] = attr_copy
3030
return metadata

0 commit comments

Comments
 (0)