Skip to content

Commit 82ad9cb

Browse files
authored
fix optimize bracket depth bug (#2490)
1 parent 6012501 commit 82ad9cb

2 files changed

Lines changed: 19 additions & 2 deletions

File tree

python/sqlflow_submitter/optimize/model_generation_test.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,24 @@ def setUp(self):
6969

7070
self.variables = ["product"]
7171

72-
def test_main(self):
72+
def test_multiple_brackets(self):
73+
constraint = {
74+
"expression": [
75+
'SUM', '(', 'finishing', '*', 'product', '+', 'SUM', '(',
76+
'product', ')', ')', '<=', '100'
77+
]
78+
}
79+
c0 = self.generate_constraint_func(constraint,
80+
result_value_name='product')
81+
c1 = self.generate_constraint_func(constraint,
82+
result_value_name="product_value")
83+
self.assertEqual(get_source(c0), get_source(c1))
84+
self.assertEqual(
85+
get_source(c0),
86+
"sum([DATA_FRAME.finishing[i_0]*model.x[i_0]+sum([model.x[i_1] for i_1 in model.x]) for i_0 in model.x])<=100"
87+
)
88+
89+
def test_model_generation(self):
7390
objective = [
7491
'SUM', '(', '(', 'price', '-', 'materials_cost', '-', 'other_cost',
7592
')', '*', 'product', ')'

python/sqlflow_submitter/optimize/optimize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def get_depth(idx):
197197
if idx < l or idx > r:
198198
continue
199199

200-
if max_depth_idx < 0 or bracket_indices[max_depth_idx] < d:
200+
if max_depth_idx < 0 or bracket_indices[max_depth_idx][2] < d:
201201
max_depth_idx = k
202202

203203
k += 1

0 commit comments

Comments
 (0)