Skip to content

Commit 9e67719

Browse files
committed
Miniexpr is now required whenever a DSL kernel is evaluated
1 parent 0c393a8 commit 9e67719

2 files changed

Lines changed: 111 additions & 13 deletions

File tree

src/blosc2/lazyexpr.py

Lines changed: 77 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -417,8 +417,8 @@ def compute(
417417
These arguments will be set in the resulting :ref:`NDArray`.
418418
Additionally, the following special kwargs are supported:
419419
- ``strict_miniexpr`` (bool): controls whether miniexpr compilation/execution
420-
failures are raised instead of silently falling back to regular chunked eval.
421-
Defaults to ``True`` for DSL kernels and ``False`` otherwise.
420+
failures are raised instead of silently falling back to regular chunked eval
421+
for non-DSL expressions.
422422
423423
Returns
424424
-------
@@ -1284,6 +1284,23 @@ def _inject_dummy_param_for_zero_input_dsl(expression: str, param_name: str) ->
12841284
return rewritten
12851285

12861286

1287+
def _is_dsl_kernel_expression(expression) -> bool:
    """Tell whether *expression* is a DSL kernel that carries DSL source.

    Returns ``True`` only when *expression* is a ``DSLKernel`` instance whose
    ``dsl_source`` attribute is not ``None``; any other object yields ``False``.
    """
    if not isinstance(expression, DSLKernel):
        return False
    return expression.dsl_source is not None
1289+
1290+
1291+
def _dsl_miniexpr_required_message(reason: str | None = None) -> str:
    """Build the error text used when a DSL kernel cannot run through miniexpr.

    Parameters
    ----------
    reason: str or None
        Optional explanation appended after the base sentence. Falsy values
        (``None`` or an empty string) are skipped.

    Returns
    -------
    str
        The base sentence, followed by *reason* when given, followed by a hint
        about ``BLOSC2_ENABLE_MINIEXPR_WINDOWS`` when running on Windows with
        the miniexpr override not enabled.
    """
    base = "DSL kernels require miniexpr evaluation and cannot run via direct Python fallback."
    suffixes = []
    if reason:
        suffixes.append(f" {reason}")
    # The override flag is only consulted on Windows, where miniexpr is gated by policy.
    if sys.platform == "win32" and not _MINIEXPR_WINDOWS_OVERRIDE:
        suffixes.append(" Set BLOSC2_ENABLE_MINIEXPR_WINDOWS=1 to force-enable miniexpr on Windows.")
    return base + "".join(suffixes)
1298+
1299+
1300+
def _raise_dsl_miniexpr_required(reason: str | None = None) -> None:
    """Raise the standard "DSL kernels require miniexpr" error.

    Parameters
    ----------
    reason: str or None
        Optional detail forwarded to :func:`_dsl_miniexpr_required_message`.

    Raises
    ------
    RuntimeError
        Always; this function never returns normally.
    """
    error_text = _dsl_miniexpr_required_message(reason)
    raise RuntimeError(error_text)
1302+
1303+
12871304
def fast_eval( # noqa: C901
12881305
expression: str | Callable[[tuple, np.ndarray, tuple[int]], None],
12891306
operands: dict,
@@ -1314,8 +1331,9 @@ def fast_eval( # noqa: C901
13141331
# Use a local copy so we don't modify the global
13151332
use_miniexpr = try_miniexpr
13161333

1317-
is_dsl = isinstance(expression, DSLKernel) and expression.dsl_source
1334+
is_dsl = _is_dsl_kernel_expression(expression)
13181335
expr_string = expression.dsl_source if is_dsl else expression
1336+
dsl_disable_reason = None
13191337

13201338
# Disable miniexpr for UDFs (callable expressions), except DSL kernels
13211339
if callable(expression) and not is_dsl:
@@ -1335,9 +1353,11 @@ def fast_eval( # noqa: C901
13351353
if strict_miniexpr is None:
13361354
# Be strict by default for DSL kernels to avoid silently losing DSL fast-path regressions.
13371355
strict_miniexpr = bool(is_dsl)
1338-
if where is not None:
1339-
# miniexpr does not support where(); use the regular path.
1340-
use_miniexpr = False
1356+
if where is not None:
1357+
# miniexpr does not support where(); use the regular path.
1358+
use_miniexpr = False
1359+
if is_dsl:
1360+
dsl_disable_reason = "DSL kernels cannot be run without miniexpr."
13411361
if isinstance(out, blosc2.NDArray):
13421362
# If 'out' has been passed, and is a NDArray, use it as the base array
13431363
basearr = out
@@ -1416,13 +1436,17 @@ def fast_eval( # noqa: C901
14161436
if math.prod(shape) <= 1:
14171437
# Avoid miniexpr for scalar-like outputs; current prefilter path is unstable here.
14181438
use_miniexpr = False
1439+
if is_dsl and dsl_disable_reason is None:
1440+
dsl_disable_reason = "scalar-like outputs are not supported by the DSL miniexpr path."
14191441
if (
14201442
isinstance(expr_string_miniexpr, str)
14211443
and
14221444
# Prefix scans are stateful across chunks and not safe for miniexpr prefilter execution.
14231445
any(tok in expr_string_miniexpr for tok in ("cumsum(", "cumprod(", "cumulative_sum("))
14241446
):
14251447
use_miniexpr = False
1448+
if is_dsl and dsl_disable_reason is None:
1449+
dsl_disable_reason = "cumulative scans are not supported by the DSL miniexpr path."
14261450
if isinstance(expr_string_miniexpr, str):
14271451
expr_string_miniexpr = _apply_jit_backend_pragma(
14281452
expr_string_miniexpr, operands_miniexpr, jit_backend
@@ -1436,8 +1460,14 @@ def fast_eval( # noqa: C901
14361460
same_blocks = all(hasattr(op, "blocks") and op.blocks == blocks for op in operands_miniexpr.values())
14371461
if not (same_shape and same_chunks and same_blocks):
14381462
use_miniexpr = False
1463+
if is_dsl and dsl_disable_reason is None:
1464+
dsl_disable_reason = "all DSL operands must share shape/chunks/blocks."
14391465
if not (all_ndarray_miniexpr and out is None):
14401466
use_miniexpr = False
1467+
if is_dsl and dsl_disable_reason is None:
1468+
dsl_disable_reason = (
1469+
"DSL kernels require NDArray inputs and do not support the `out` argument."
1470+
)
14411471
has_complex = any(
14421472
isinstance(op, blosc2.NDArray) and blosc2.isdtype(op.dtype, "complex floating")
14431473
for op in operands_miniexpr.values()
@@ -1446,18 +1476,31 @@ def fast_eval( # noqa: C901
14461476
if sys.platform == "win32":
14471477
# On Windows, miniexpr has issues with complex numbers
14481478
use_miniexpr = False
1479+
if is_dsl and dsl_disable_reason is None:
1480+
dsl_disable_reason = "complex DSL kernels are disabled on Windows."
14491481
if any(tok in expr_string_miniexpr for tok in ("!=", "==", "<=", ">=", "<", ">")):
14501482
use_miniexpr = False
1483+
if is_dsl and dsl_disable_reason is None:
1484+
dsl_disable_reason = "complex comparisons are not supported by miniexpr."
14511485
if sys.platform == "win32" and use_miniexpr and not _MINIEXPR_WINDOWS_OVERRIDE:
14521486
# Work around Windows miniexpr issues for integer outputs and dtype conversions.
14531487
if blosc2.isdtype(dtype, "integral"):
14541488
use_miniexpr = False
1489+
if is_dsl and dsl_disable_reason is None:
1490+
dsl_disable_reason = "Windows policy disables miniexpr for integral output dtypes."
14551491
else:
14561492
dtype_mismatch = any(
14571493
isinstance(op, blosc2.NDArray) and op.dtype != dtype for op in operands_miniexpr.values()
14581494
)
14591495
if dtype_mismatch:
14601496
use_miniexpr = False
1497+
if is_dsl and dsl_disable_reason is None:
1498+
dsl_disable_reason = (
1499+
"Windows policy disables miniexpr when operand and output dtypes differ."
1500+
)
1501+
1502+
if is_dsl and not use_miniexpr:
1503+
_raise_dsl_miniexpr_required(dsl_disable_reason)
14611504

14621505
if use_miniexpr:
14631506
cparams = kwargs.pop("cparams", blosc2.CParams())
@@ -1478,6 +1521,12 @@ def fast_eval( # noqa: C901
14781521
res_eval.schunk.update_data(nchunk, data, copy=False)
14791522
except Exception as e:
14801523
use_miniexpr = False
1524+
if is_dsl:
1525+
raise RuntimeError(
1526+
_dsl_miniexpr_required_message(
1527+
"miniexpr compilation or execution failed for this DSL kernel."
1528+
)
1529+
) from e
14811530
if strict_miniexpr:
14821531
raise RuntimeError("miniexpr evaluation failed while strict_miniexpr=True") from e
14831532
finally:
@@ -1525,6 +1574,10 @@ def fast_eval( # noqa: C901
15251574
out = blosc2.empty(shape, chunks=chunks, blocks=blocks, dtype=dtype, **kwargs)
15261575

15271576
if callable(expression):
1577+
if _is_dsl_kernel_expression(expression):
1578+
_raise_dsl_miniexpr_required(
1579+
"internal fallback attempted to execute the DSL kernel directly in Python."
1580+
)
15281581
result = np.empty(chunks_, dtype=out.dtype)
15291582
expression(tuple(chunk_operands.values()), result, offset=offset)
15301583
else:
@@ -1755,6 +1808,10 @@ def slices_eval( # noqa: C901
17551808
# Evaluate the expression using chunks of operands
17561809

17571810
if callable(expression):
1811+
if _is_dsl_kernel_expression(expression):
1812+
_raise_dsl_miniexpr_required(
1813+
"internal sliced fallback attempted to execute the DSL kernel directly in Python."
1814+
)
17581815
result = np.empty(cslice_shape, dtype=out.dtype) # raises error if out is None
17591816
# cslice should be equal to cslice_subidx
17601817
# Call the udf directly and use result as the output array
@@ -1897,6 +1954,10 @@ def slices_eval_getitem(
18971954

18981955
# Evaluate the expression using slices of operands
18991956
if callable(expression):
1957+
if _is_dsl_kernel_expression(expression):
1958+
_raise_dsl_miniexpr_required(
1959+
"internal getitem fallback attempted to execute the DSL kernel directly in Python."
1960+
)
19001961
offset = tuple(0 if s is None else s.start for s in _slice_bcast) # offset for the udf
19011962
result = np.empty(slice_shape, dtype=dtype)
19021963
expression(tuple(slice_operands.values()), result, offset=offset)
@@ -2262,6 +2323,10 @@ def reduce_slices( # noqa: C901
22622323
# Evaluate and reduce the expression using chunks of operands
22632324

22642325
if callable(expression):
2326+
if _is_dsl_kernel_expression(expression):
2327+
_raise_dsl_miniexpr_required(
2328+
"internal reduction fallback attempted to execute the DSL kernel directly in Python."
2329+
)
22652330
# TODO: Implement the reductions for UDFs (and test them)
22662331
result = np.empty(cslice_shape, dtype=out.dtype)
22672332
expression(tuple(chunk_operands.values()), result, offset=offset)
@@ -2479,7 +2544,7 @@ def _eval_zero_input_dsl_if_needed(
24792544
return True, full_res
24802545

24812546

2482-
def chunked_eval(
2547+
def chunked_eval( # noqa: C901
24832548
expression: str | Callable[[tuple, np.ndarray, tuple[int]], None], operands: dict, item=(), **kwargs
24842549
):
24852550
"""
@@ -2589,6 +2654,11 @@ def chunked_eval(
25892654
return fast_eval(
25902655
expression, operands, getitem=False, jit=jit, jit_backend=jit_backend, **kwargs
25912656
)
2657+
elif _is_dsl_kernel_expression(expression) and (out is None or isinstance(out, blosc2.NDArray)):
2658+
# DSL kernels require miniexpr and must not fall back to Python execution.
2659+
return fast_eval(
2660+
expression, operands, getitem=False, jit=jit, jit_backend=jit_backend, **kwargs
2661+
)
25922662

25932663
# End up here by default
25942664
return slices_eval(expression, operands, getitem=getitem, _slice=item, shape=shape, **kwargs)

tests/ndarray/test_dsl_kernels.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,19 @@
1616
clip = np.clip
1717

1818

19+
def _windows_policy_blocks_dsl_dtype(dtype, operand_dtypes=()) -> bool:
    """Tell whether the Windows miniexpr policy would reject this dtype combination.

    The check is true only when ``blosc2.lazyexpr`` reports the ``win32``
    platform, its ``_MINIEXPR_WINDOWS_OVERRIDE`` flag is falsy, and either
    *dtype* is integral or some entry of *operand_dtypes* differs from *dtype*.
    """
    import importlib

    lazyexpr_mod = importlib.import_module("blosc2.lazyexpr")
    out_dtype = np.dtype(dtype)

    # Compare operand dtypes up front (before any platform check), stopping at the first mismatch.
    has_mismatch = False
    for op_dtype in operand_dtypes:
        if np.dtype(op_dtype) != out_dtype:
            has_mismatch = True
            break

    # Read the platform through the lazyexpr module so that patches applied there are honoured.
    if lazyexpr_mod.sys.platform != "win32":
        return False
    if lazyexpr_mod._MINIEXPR_WINDOWS_OVERRIDE:
        return False
    return blosc2.isdtype(out_dtype, "integral") or has_mismatch
30+
31+
1932
def _make_arrays(shape=(8, 8), chunks=(4, 4), blocks=(2, 2)):
2033
a = np.linspace(0, 1, num=np.prod(shape), dtype=np.float32).reshape(shape)
2134
b = np.linspace(1, 2, num=np.prod(shape), dtype=np.float32).reshape(shape)
@@ -179,10 +192,15 @@ def test_dsl_kernel_integer_ops_kept_as_full_dsl_function():
179192
chunks=a2.chunks,
180193
blocks=a2.blocks,
181194
)
182-
res = expr.compute()
183-
expected = kernel_integer_ops.func(a, b)
184-
185-
np.testing.assert_equal(res[...], expected)
195+
try:
196+
res = expr.compute()
197+
except RuntimeError as e:
198+
# Some DSL ops may still be unsupported by miniexpr backends.
199+
if "DSL kernels require miniexpr" not in str(e):
200+
raise
201+
else:
202+
expected = kernel_integer_ops.func(a, b)
203+
np.testing.assert_equal(res[...], expected)
186204

187205

188206
def test_dsl_kernel_index_symbols_keep_full_kernel(monkeypatch):
@@ -300,6 +318,10 @@ def test_dsl_kernel_index_symbols_int_cast_matches_expected_ramp():
300318
shape = (32, 5)
301319
x2 = blosc2.zeros(shape, dtype=np.float32)
302320
expr = blosc2.lazyudf(kernel_index_ramp_int_cast, (x2,), dtype=np.int64)
321+
if _windows_policy_blocks_dsl_dtype(np.int64, operand_dtypes=(x2.dtype,)):
322+
with pytest.raises(RuntimeError, match="DSL kernels require miniexpr"):
323+
_ = expr[:]
324+
return
303325
res = expr[:]
304326
expected = np.arange(np.prod(shape), dtype=np.int64).reshape(shape)
305327
np.testing.assert_equal(res, expected)
@@ -309,6 +331,10 @@ def test_dsl_kernel_bool_cast_numeric_matches_expected():
309331
x = np.array([[0.0, 1.0, -2.0], [3.5, 0.0, -0.1]], dtype=np.float32)
310332
x2 = blosc2.asarray(x, chunks=(2, 3), blocks=(1, 2))
311333
expr = blosc2.lazyudf(kernel_bool_cast_numeric, (x2,), dtype=np.bool_)
334+
if _windows_policy_blocks_dsl_dtype(np.bool_, operand_dtypes=(x2.dtype,)):
335+
with pytest.raises(RuntimeError, match="DSL kernels require miniexpr"):
336+
_ = expr[:]
337+
return
312338
res = expr[:]
313339
expected = x != 0.0
314340
np.testing.assert_equal(res, expected)
@@ -460,7 +486,7 @@ def wrapped_set_pref_expr(self, expression, inputs, fp_accuracy, aux_reduc=None,
460486
lazyexpr_mod.try_miniexpr = old_try_miniexpr
461487

462488

463-
def test_dsl_kernel_miniexpr_failure_is_strict_by_default(monkeypatch):
489+
def test_dsl_kernel_miniexpr_failure_raises_even_with_strict_disabled(monkeypatch):
464490
if blosc2.IS_WASM:
465491
pytest.skip("miniexpr fast path is not available on WASM")
466492

@@ -478,8 +504,10 @@ def failing_set_pref_expr(self, expression, inputs, fp_accuracy, aux_reduc=None,
478504
try:
479505
_, _, a2, b2 = _make_arrays(shape=(32, 32), chunks=(16, 16), blocks=(8, 8))
480506
expr = blosc2.lazyudf(kernel_loop, (a2, b2), dtype=a2.dtype)
481-
with pytest.raises(RuntimeError, match="strict_miniexpr=True"):
507+
with pytest.raises(RuntimeError, match="DSL kernels require miniexpr"):
482508
_ = expr.compute()
509+
with pytest.raises(RuntimeError, match="DSL kernels require miniexpr"):
510+
_ = expr.compute(strict_miniexpr=False)
483511
finally:
484512
lazyexpr_mod.try_miniexpr = old_try_miniexpr
485513

0 commit comments

Comments
 (0)