Skip to content

Commit 7d75ab4

Browse files
committed
Better diagnostics for failed DSL compilation/execution
1 parent 70aec29 commit 7d75ab4

4 files changed

Lines changed: 61 additions & 40 deletions

File tree

examples/ndarray/malformed_dsl.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from blosc2.dsl_kernel import DSLSyntaxError
1616

1717

18-
# --- 1) Malformed DSL syntax: validate_dsl() and lazyudf() diagnostics ---
18+
# --- 1) Malformed DSL syntax: validate_dsl() diagnostics ---
1919
@blosc2.dsl_kernel
2020
def kernel_bad_ternary(x):
2121
return 1 if x else 0
@@ -25,14 +25,15 @@ def kernel_bad_ternary(x):
2525
print("validate_dsl valid:", report["valid"])
2626
print("validate_dsl error:\n", report["error"])
2727

28+
# --- 2) Proper error is raised when trying to compute as well ---
29+
x = blosc2.ones((8, 8), dtype=np.float32)
2830
try:
29-
x = blosc2.ones((8, 8), dtype=np.float32)
30-
_ = blosc2.lazyudf(kernel_bad_ternary, (x,), dtype=np.int32)
31+
res = blosc2.lazyudf(kernel_bad_ternary, (x,), dtype=np.int32)[:]
3132
except DSLSyntaxError as e:
3233
print("\nlazyudf rejected malformed DSL kernel as expected:\n", e)
3334

3435

35-
# --- 2) Force miniexpr backend failure to show enriched RuntimeError message ---
36+
# --- 3) Force miniexpr backend failure to show enriched RuntimeError message ---
3637
@blosc2.dsl_kernel
3738
def kernel_ok(x, y):
3839
return x + y

src/blosc2/dsl_kernel.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -490,16 +490,20 @@ def __init__(self, func):
490490
try:
491491
dsl_source, input_names = self._extract_dsl(func)
492492
except DSLSyntaxError as e:
493-
dsl_source = None
494-
input_names = None
493+
# Preserve extracted source/signature for diagnostics even when DSL validation fails.
494+
try:
495+
dsl_source, input_names = self._extract_dsl(func, validate=False)
496+
except Exception:
497+
dsl_source = None
498+
input_names = None
495499
self.dsl_error = e
496500
except Exception:
497501
dsl_source = None
498502
input_names = None
499503
self.dsl_source = dsl_source
500504
self.input_names = input_names
501505

502-
def _extract_dsl(self, func):
506+
def _extract_dsl(self, func, validate: bool = True):
503507
source = inspect.getsource(func)
504508
source = textwrap.dedent(source)
505509
tree = ast.parse(source)
@@ -521,6 +525,8 @@ def _extract_dsl(self, func):
521525
dsl_func = next((node for node in dsl_tree.body if isinstance(node, ast.FunctionDef)), None)
522526
if dsl_func is None:
523527
raise ValueError("No function definition found in sliced DSL source")
528+
if validate:
529+
_DSLValidator(dsl_source).validate(dsl_func)
524530
input_names = self._input_names_from_signature(dsl_func)
525531
if _PRINT_DSL_KERNEL:
526532
func_name = getattr(func, "__name__", "<dsl_kernel>")

src/blosc2/lazyexpr.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1302,9 +1302,9 @@ def _format_dsl_parse_error_hint(expr_text: str, backend_msg: str):
13021302

13031303

13041304
def _dsl_miniexpr_required_message(reason: str | None = None) -> str:
1305-
message = "DSL kernel requires miniexpr."
1305+
message = ""
13061306
if reason:
1307-
message = f"{message} {reason}"
1307+
message = f"{message}{reason}"
13081308
return message
13091309

13101310

@@ -1517,7 +1517,9 @@ def fast_eval( # noqa: C901
15171517
except Exception as e:
15181518
use_miniexpr = False
15191519
if is_dsl:
1520-
reason = "miniexpr compilation or execution failed for this DSL kernel."
1520+
reason = (
1521+
f"miniexpr compilation or execution failed for this DSL kernel:\n{expression.dsl_source}"
1522+
)
15211523
backend_error = str(e)
15221524
parse_hint = None
15231525
if isinstance(expr_string_miniexpr, str):

tests/ndarray/test_dsl_kernels.py

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import pytest
1717

1818
import blosc2
19+
from blosc2.dsl_kernel import DSLSyntaxError
1920
from blosc2.lazyexpr import _apply_jit_backend_pragma
2021

2122
where = np.where
@@ -211,7 +212,7 @@ def test_dsl_kernel_integer_ops_kept_as_full_dsl_function():
211212
res = expr.compute()
212213
except RuntimeError as e:
213214
# Some DSL ops may still be unsupported by miniexpr backends.
214-
if "DSL kernel requires miniexpr" not in str(e):
215+
if "miniexpr compilation or execution failed for this DSL kernel" not in str(e):
215216
raise
216217
else:
217218
expected = kernel_integer_ops.func(a, b)
@@ -537,7 +538,9 @@ def failing_set_pref_expr(self, expression, inputs, fp_accuracy, aux_reduc=None,
537538
try:
538539
shape = (8, 8)
539540
expr = blosc2.lazyudf(kernel_scalar_only, (3,), dtype=np.float32, shape=shape)
540-
with pytest.raises(RuntimeError, match="DSL kernel requires miniexpr"):
541+
with pytest.raises(
542+
RuntimeError, match="miniexpr compilation or execution failed for this DSL kernel"
543+
):
541544
expr.compute()
542545
assert captured["calls"] >= 1
543546
assert captured["keys"] == ("__me_dummy0",)
@@ -671,9 +674,13 @@ def failing_set_pref_expr(self, expression, inputs, fp_accuracy, aux_reduc=None,
671674
try:
672675
_, _, a2, b2 = _make_arrays(shape=(32, 32), chunks=(16, 16), blocks=(8, 8))
673676
expr = blosc2.lazyudf(kernel_loop, (a2, b2), dtype=a2.dtype)
674-
with pytest.raises(RuntimeError, match="DSL kernel requires miniexpr"):
677+
with pytest.raises(
678+
RuntimeError, match="miniexpr compilation or execution failed for this DSL kernel"
679+
):
675680
_ = expr.compute()
676-
with pytest.raises(RuntimeError, match="DSL kernel requires miniexpr"):
681+
with pytest.raises(
682+
RuntimeError, match="miniexpr compilation or execution failed for this DSL kernel"
683+
):
677684
_ = expr.compute(strict_miniexpr=False)
678685
finally:
679686
lazyexpr_mod.try_miniexpr = old_try_miniexpr
@@ -694,7 +701,9 @@ def failing_set_pref_expr(self, expression, inputs, fp_accuracy, aux_reduc=None,
694701
try:
695702
_, _, a2, b2 = _make_arrays(shape=(32, 32), chunks=(16, 16), blocks=(8, 8))
696703
expr = blosc2.lazyudf(kernel_loop, (a2, b2), dtype=a2.dtype)
697-
with pytest.raises(RuntimeError, match="DSL kernel requires miniexpr") as excinfo:
704+
with pytest.raises(
705+
RuntimeError, match="miniexpr compilation or execution failed for this DSL kernel"
706+
) as excinfo:
698707
_ = expr.compute()
699708
msg = str(excinfo.value)
700709
assert "Backend error: forced miniexpr backend failure details" in msg
@@ -726,7 +735,9 @@ def fake_validate_dsl(_func):
726735
try:
727736
_, _, a2, b2 = _make_arrays(shape=(32, 32), chunks=(16, 16), blocks=(8, 8))
728737
expr = blosc2.lazyudf(kernel_loop, (a2, b2), dtype=a2.dtype)
729-
with pytest.raises(RuntimeError, match="DSL kernel requires miniexpr") as excinfo:
738+
with pytest.raises(
739+
RuntimeError, match="miniexpr compilation or execution failed for this DSL kernel"
740+
) as excinfo:
730741
_ = expr.compute()
731742
msg = str(excinfo.value)
732743
assert "Backend error: forced backend failure hidden by validate_dsl message" in msg
@@ -785,37 +796,35 @@ def test_jit_backend_pragma_wrapping_dsl_source():
785796
def test_dsl_kernel_flawed_syntax_detected_fallback_callable(kernel):
786797
assert kernel.dsl_source is not None
787798
assert kernel.input_names == ["x", "y"]
788-
assert kernel.dsl_error is None
799+
assert kernel.dsl_error is not None
789800

790801
a, b, a2, b2 = _make_arrays()
791-
expr = blosc2.lazyudf(
792-
kernel,
793-
(a2, b2),
794-
dtype=a2.dtype,
795-
chunks=a2.chunks,
796-
blocks=a2.blocks,
797-
)
798-
with pytest.raises(RuntimeError, match="DSL kernel requires miniexpr"):
799-
_ = expr.compute()
802+
with pytest.raises(DSLSyntaxError, match="Invalid DSL kernel"):
803+
_ = blosc2.lazyudf(
804+
kernel,
805+
(a2, b2),
806+
dtype=a2.dtype,
807+
chunks=a2.chunks,
808+
blocks=a2.blocks,
809+
)
800810

801811

802812
def test_dsl_kernel_ternary_rejected_with_actionable_error():
803813
assert kernel_fallback_ternary.dsl_source is not None
804-
assert kernel_fallback_ternary.dsl_error is None
814+
assert kernel_fallback_ternary.input_names == ["x"]
815+
assert kernel_fallback_ternary.dsl_error is not None
805816

806817
_, _, a2, _ = _make_arrays()
807-
expr = blosc2.lazyudf(
808-
kernel_fallback_ternary,
809-
(a2,),
810-
dtype=np.int32,
811-
chunks=a2.chunks,
812-
blocks=a2.blocks,
813-
)
814-
with pytest.raises(RuntimeError, match="DSL kernel requires miniexpr") as excinfo:
815-
_ = expr.compute()
818+
with pytest.raises(DSLSyntaxError, match="Invalid DSL kernel") as excinfo:
819+
_ = blosc2.lazyudf(
820+
kernel_fallback_ternary,
821+
(a2,),
822+
dtype=np.int32,
823+
chunks=a2.chunks,
824+
blocks=a2.blocks,
825+
)
816826
msg = str(excinfo.value)
817-
assert "Backend error:" in msg
818-
assert "Parse error location" in msg
827+
assert "Ternary expressions are not supported in DSL; use where(cond, a, b)" in msg
819828
assert "^" in msg
820829

821830

@@ -827,10 +836,13 @@ def test_validate_dsl_api_valid_and_invalid():
827836
assert valid_report["input_names"] == ["x", "y"]
828837

829838
unsupported_report = blosc2.validate_dsl(kernel_fallback_ternary)
830-
assert unsupported_report["valid"] is True
831-
assert unsupported_report["error"] is None
839+
assert unsupported_report["valid"] is False
840+
assert unsupported_report["error"] is not None
832841
assert "def kernel_fallback_ternary(x):" in unsupported_report["dsl_source"]
833842
assert unsupported_report["input_names"] == ["x"]
843+
assert (
844+
"Ternary expressions are not supported in DSL; use where(cond, a, b)" in unsupported_report["error"]
845+
)
834846

835847

836848
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)