Skip to content

Commit 9a34bd4

Browse files
authored
Add token validation check to math programming (#2748)
* Add error diagnostics to optimize model generation * Fix unit test bug * Polish * Address review comments
1 parent f4b4c44 commit 9a34bd4

2 files changed

Lines changed: 138 additions & 3 deletions

File tree

python/runtime/optimize/model_generation.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,56 @@
1212
# limitations under the License.
1313

1414
import copy
15+
import re
1516

1617
__all__ = [
1718
'generate_unique_result_value_name',
1819
'generate_objective_and_constraint_expression',
1920
]
2021

22+
# Raw string: "\w" in a plain string literal is an invalid escape sequence
# (SyntaxWarning on modern Python).
IDENTIFIER_REGEX = re.compile(r"[_a-zA-Z]\w*")


def assert_are_valid_tokens(columns, tokens, result_value_name, group_by=None):
    """
    Check whether the tokens are valid. A token is valid when it matches
    one of the columns or result_value_name (case-insensitive), when it
    is not an identifier at all (operators, literals, parentheses), or
    when it is an identifier immediately followed by "(" (a function
    call). Otherwise, raise AssertionError.

    Args:
        columns (list[str]): the column names of the source table.
        tokens (list[str]): the token list.
        result_value_name (str): the result value name to be optimized.
        group_by (str): the column name to be grouped.

    Returns:
        None

    Raises:
        AssertionError: if tokens is empty, if group_by is not one of
            columns, or if any token is invalid.
    """
    # Set for O(1) membership tests; all comparisons are case-insensitive.
    valid_columns = {c.lower() for c in columns}

    # GROUP BY must name a real source column, so it is checked before
    # result_value_name is added to the valid set.
    if group_by:
        assert group_by.lower() in valid_columns, \
            "GROUP BY column %s not found" % group_by

    assert tokens, "tokens should not be empty"

    valid_columns.add(result_value_name.lower())

    for i, token in enumerate(tokens):
        if token.lower() in valid_columns:
            continue

        # Non-identifier tokens (operators, numbers, parentheses) cannot
        # be misspelled column names, so they are accepted as-is.
        if IDENTIFIER_REGEX.fullmatch(token) is None:
            continue

        # An identifier that is not a known column must be a function
        # call, i.e. followed (ignoring blank tokens) by "(".
        assert find_next_non_blank_token(tokens, i + 1) == "(", \
            "invalid token %s" % token
64+
2165

2266
def generate_unique_result_value_name(columns, result_value_name, variables):
2367
"""
@@ -184,6 +228,30 @@ def generate_group_by_range_and_index_str(group_by, data_str, value_str,
184228
return outer_range_str, inner_range_str, [value_str, index_str]
185229

186230

231+
def find_next_non_blank_token(tokens, i):
    """
    Find the next non-blank token at or after index i.

    Args:
        tokens (list[str]): a string token list.
        i (int): the position to start searching from.

    Returns:
        The first token at index >= i that is not blank, or None when
        i is negative or no such token exists.
    """
    if i < 0:
        return None

    # Scan the tail of the list lazily; default to None when every
    # remaining token is blank (or the slice is empty).
    return next((token for token in tokens[i:] if token.strip()), None)
253+
254+
187255
def find_prev_non_blank_token(tokens, i):
188256
"""
189257
Find previous non-blank token before index i (including i).
@@ -585,6 +653,9 @@ def generate_objective_and_constraint_expression(columns,
585653
constraint_exprs = []
586654

587655
if objective:
656+
assert_are_valid_tokens(columns=columns,
657+
tokens=objective,
658+
result_value_name=result_value_name)
588659
obj_expr, for_range, iter_vars = generate_objective_or_constraint_expression(
589660
columns=columns,
590661
tokens=objective,
@@ -603,6 +674,10 @@ def generate_objective_and_constraint_expression(columns,
603674
tokens = c.get("tokens")
604675
group_by = c.get("group_by")
605676

677+
assert_are_valid_tokens(columns=columns,
678+
tokens=tokens,
679+
result_value_name=result_value_name,
680+
group_by=group_by)
606681
expr, for_range, iter_vars = generate_objective_or_constraint_expression(
607682
columns=columns,
608683
tokens=tokens,

python/runtime/optimize/model_generation_test.py

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,68 @@
1818
import pandas as pd
1919
import pyomo.environ as pyomo_env
2020
from runtime.optimize.local import generate_model_with_data_frame, solve_model
21-
from runtime.optimize.model_generation import \
22-
generate_objective_and_constraint_expression
21+
from runtime.optimize.model_generation import (
22+
IDENTIFIER_REGEX, assert_are_valid_tokens,
23+
generate_objective_and_constraint_expression)
24+
25+
26+
class TestAssertValidTokens(unittest.TestCase):
    """Tests for IDENTIFIER_REGEX and assert_are_valid_tokens."""

    def is_identifier(self, token):
        """Return True when token fully matches IDENTIFIER_REGEX."""
        return IDENTIFIER_REGEX.fullmatch(token) is not None

    def test_is_identifier(self):
        for token in ('a', '_', 'a123', '__', '_123'):
            self.assertTrue(self.is_identifier(token))

        for token in ('1', '123_', '3def'):
            self.assertFalse(self.is_identifier(token))

    def test_assert_valid_tokens(self):
        columns = ['finishing', 'product']
        ok_tokens = [
            'SUM', '(', 'finishing', '*', 'product', ')', '<=', '100'
        ]

        # A well-formed expression passes without raising.
        assert_are_valid_tokens(columns=columns,
                                tokens=ok_tokens,
                                result_value_name='product')

        # An unknown GROUP BY column is rejected.
        with self.assertRaises(AssertionError):
            assert_are_valid_tokens(columns=columns,
                                    tokens=ok_tokens,
                                    result_value_name='product',
                                    group_by='invalid_group_by')

        # Both None and an empty list are rejected as token lists.
        for bad_tokens in (None, []):
            with self.assertRaises(AssertionError):
                assert_are_valid_tokens(columns=columns,
                                        tokens=bad_tokens,
                                        result_value_name='product')

        # An identifier that is neither a column nor a call is rejected.
        bad_tokens = [
            'SUM', '(', 'finishing', '*', 'invalid_token', ')', '<=', '100'
        ]
        with self.assertRaises(AssertionError):
            assert_are_valid_tokens(columns=columns,
                                    tokens=bad_tokens,
                                    result_value_name='product')

        # Column / result-value-name matching ignores case.
        mixed_case_tokens = [
            'SUM', '(', 'FinisHing', '*', 'pRoducT_VaLue', ')', '<=', '100'
        ]
        assert_are_valid_tokens(columns=columns,
                                tokens=mixed_case_tokens,
                                result_value_name='product_value')
2383

2484

2585
class TestModelGenerationBase(unittest.TestCase):
@@ -72,7 +132,7 @@ def replace_objective_token(self, objective, old, new):
72132

73133
def replace_constraint_token(self, constraint, old, new):
74134
def replace_one_constraint(c):
75-
c = copy.copy(c)
135+
c = copy.deepcopy(c)
76136
for i, token in enumerate(c["tokens"]):
77137
if token == old:
78138
c["tokens"][i] = new

0 commit comments

Comments
 (0)