Skip to content

Commit 8bb4131

Browse files
committed
Do a better job showing bad DSL kernels
1 parent 1c91b2f commit 8bb4131

2 files changed

Lines changed: 43 additions & 9 deletions

File tree

src/blosc2/dsl_kernel.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,6 @@ def _extract_dsl(self, func):
521521
dsl_func = next((node for node in dsl_tree.body if isinstance(node, ast.FunctionDef)), None)
522522
if dsl_func is None:
523523
raise ValueError("No function definition found in sliced DSL source")
524-
_DSLValidator(dsl_source).validate(dsl_func)
525524
input_names = self._input_names_from_signature(dsl_func)
526525
if _PRINT_DSL_KERNEL:
527526
func_name = getattr(func, "__name__", "<dsl_kernel>")

src/blosc2/lazyexpr.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242

4343
import blosc2
4444

45-
from .dsl_kernel import DSLKernel, DSLSyntaxError, specialize_miniexpr_inputs, validate_dsl
45+
from .dsl_kernel import DSLKernel, DSLSyntaxError, _DSLValidator, specialize_miniexpr_inputs
4646

4747
if blosc2._HAS_NUMBA:
4848
import numba
@@ -1278,6 +1278,29 @@ def _is_dsl_kernel_expression(expression) -> bool:
12781278
return isinstance(expression, DSLKernel) and expression.dsl_source is not None
12791279

12801280

1281+
def _format_dsl_parse_error_hint(expr_text: str, backend_msg: str):
1282+
marker = "parse_error_pos="
1283+
pos0 = backend_msg.find(marker)
1284+
if pos0 < 0:
1285+
return None
1286+
pos0 += len(marker)
1287+
pos1 = pos0
1288+
while pos1 < len(backend_msg) and backend_msg[pos1].isdigit():
1289+
pos1 += 1
1290+
if pos1 == pos0:
1291+
return None
1292+
err_pos = int(backend_msg[pos0:pos1])
1293+
if err_pos < 0:
1294+
return None
1295+
if err_pos > len(expr_text):
1296+
err_pos = len(expr_text)
1297+
line_no = expr_text.count("\n", 0, err_pos) + 1
1298+
line_start = expr_text.rfind("\n", 0, err_pos) + 1
1299+
col_no = err_pos - line_start + 1
1300+
dump = _DSLValidator(expr_text)._format_source_with_pointer(line_no, col_no)
1301+
return f"Parse error location (line {line_no}, col {col_no}, offset {err_pos}):\n{dump}"
1302+
1303+
12811304
def _dsl_miniexpr_required_message(reason: str | None = None) -> str:
12821305
message = "DSL kernel requires miniexpr."
12831306
if reason:
@@ -1495,12 +1518,13 @@ def fast_eval( # noqa: C901
14951518
use_miniexpr = False
14961519
if is_dsl:
14971520
reason = "miniexpr compilation or execution failed for this DSL kernel."
1498-
if isinstance(expression, DSLKernel):
1499-
report = validate_dsl(expression)
1500-
if not report["valid"] and report["error"]:
1501-
reason = report["error"]
1502-
else:
1503-
reason = f"{reason}\nBackend error: {e}"
1521+
backend_error = str(e)
1522+
parse_hint = None
1523+
if isinstance(expr_string_miniexpr, str):
1524+
parse_hint = _format_dsl_parse_error_hint(expr_string_miniexpr, backend_error)
1525+
reason = f"{reason}\nBackend error: {backend_error}"
1526+
if parse_hint is not None:
1527+
reason = f"{reason}\n{parse_hint}"
15041528
raise RuntimeError(_dsl_miniexpr_required_message(reason)) from e
15051529
if strict_miniexpr:
15061530
raise RuntimeError("miniexpr evaluation failed while strict_miniexpr=True") from e
@@ -2223,6 +2247,7 @@ def reduce_slices( # noqa: C901
22232247
# For other operations, zeros should be safe
22242248
aux_reduc = np.zeros(nblocks, dtype=dtype)
22252249
prefilter_set = False
2250+
expression_miniexpr = None
22262251
try:
22272252
if where is not None:
22282253
expression_miniexpr = f"{reduce_op_str}(where({expression}, _where_x, _where_y))"
@@ -2238,8 +2263,18 @@ def reduce_slices( # noqa: C901
22382263
# Exercise prefilter for each chunk
22392264
for nchunk in range(res_eval.schunk.nchunks):
22402265
res_eval.schunk._prefilter_data(nchunk, data, chunk_data)
2241-
except Exception:
2266+
except Exception as e:
22422267
use_miniexpr = False
2268+
if callable(expression) and _is_dsl_kernel_expression(expression):
2269+
reason = "miniexpr compilation or execution failed for this DSL kernel."
2270+
backend_error = str(e)
2271+
parse_hint = None
2272+
if isinstance(expression_miniexpr, str):
2273+
parse_hint = _format_dsl_parse_error_hint(expression_miniexpr, backend_error)
2274+
reason = f"{reason}\nBackend error: {backend_error}"
2275+
if parse_hint is not None:
2276+
reason = f"{reason}\n{parse_hint}"
2277+
raise RuntimeError(_dsl_miniexpr_required_message(reason)) from e
22432278
finally:
22442279
if prefilter_set:
22452280
res_eval.schunk.remove_prefilter("miniexpr")

0 commit comments

Comments
 (0)