Skip to content

Commit d386291

Browse files
committed
dsl_kernel wrapper does not reduce kernels to simple expressions anymore
1 parent 7952e1d commit d386291

2 files changed

Lines changed: 45 additions & 28 deletions

File tree

src/blosc2/dsl_kernel.py

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -141,31 +141,13 @@ def _extract_dsl(self, func):
141141
if func_node is None:
142142
raise ValueError("No function definition found for DSL extraction")
143143

144-
dsl_source_full = None
145-
if _PRINT_DSL_KERNEL:
146-
try:
147-
dsl_source_full = _DSLBuilder().build(func_node)
148-
func_name = getattr(func, "__name__", "<dsl_kernel>")
149-
print(f"[DSLKernel:{func_name}] dsl_source (full):")
150-
print(dsl_source_full[0])
151-
except Exception as exc:
152-
func_name = getattr(func, "__name__", "<dsl_kernel>")
153-
print(f"[DSLKernel:{func_name}] dsl_source (full) failed: {exc}")
154-
155-
reducer = _DSLReducer()
156-
reduced = reducer.reduce(func_node)
157-
if reduced is not None:
158-
if _PRINT_DSL_KERNEL:
159-
func_name = getattr(func, "__name__", "<dsl_kernel>")
160-
print(f"[DSLKernel:{func_name}] reduced_expr:")
161-
print(reduced[0])
162-
return reduced
163-
164-
if dsl_source_full is not None:
165-
return dsl_source_full
166-
167144
builder = _DSLBuilder()
168-
return builder.build(func_node)
145+
dsl_source, input_names = builder.build(func_node)
146+
if _PRINT_DSL_KERNEL:
147+
func_name = getattr(func, "__name__", "<dsl_kernel>")
148+
print(f"[DSLKernel:{func_name}] dsl_source (full):")
149+
print(dsl_source)
150+
return dsl_source, input_names
169151

170152
def __call__(self, inputs_tuple, output, offset=None):
171153
if self._legacy_udf_signature:

tests/ndarray/test_dsl_kernels.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,14 @@ def kernel_fallback_tuple_assign(x, y):
112112
return lhs + rhs
113113

114114

115-
def test_dsl_kernel_reduced_expr():
115+
@blosc2.dsl_kernel
116+
def kernel_index_ramp(x):
117+
return _i0 * _n1 + _i1 # noqa: F821 # DSL index/shape symbols resolved by miniexpr
118+
119+
120+
def test_dsl_kernel_loop_kept_as_full_dsl_function():
116121
assert kernel_loop.dsl_source is not None
117-
assert "def " not in kernel_loop.dsl_source
122+
assert "def kernel_loop(x, y):" in kernel_loop.dsl_source
118123
assert kernel_loop.input_names == ["x", "y"]
119124

120125
a, b, a2, b2 = _make_arrays()
@@ -125,9 +130,9 @@ def test_dsl_kernel_reduced_expr():
125130
np.testing.assert_allclose(res[...], expected, rtol=1e-5, atol=1e-6)
126131

127132

128-
def test_dsl_kernel_integer_ops_reduced_expr():
133+
def test_dsl_kernel_integer_ops_kept_as_full_dsl_function():
129134
assert kernel_integer_ops.dsl_source is not None
130-
assert "def " not in kernel_integer_ops.dsl_source
135+
assert "def kernel_integer_ops(x, y):" in kernel_integer_ops.dsl_source
131136
assert kernel_integer_ops.input_names == ["x", "y"]
132137

133138
a, b, a2, b2 = _make_int_arrays()
@@ -144,6 +149,36 @@ def test_dsl_kernel_integer_ops_reduced_expr():
144149
np.testing.assert_equal(res[...], expected)
145150

146151

152+
def test_dsl_kernel_index_symbols_keep_full_kernel(monkeypatch):
153+
if blosc2.IS_WASM:
154+
pytest.skip("miniexpr fast path is not available on WASM")
155+
156+
assert kernel_index_ramp.dsl_source is not None
157+
assert "def kernel_index_ramp(x):" in kernel_index_ramp.dsl_source
158+
159+
original_set_pref_expr = blosc2.NDArray._set_pref_expr
160+
captured = {"calls": 0, "expr": None}
161+
162+
def wrapped_set_pref_expr(self, expression, inputs, fp_accuracy, aux_reduc=None, jit=None):
163+
captured["calls"] += 1
164+
captured["expr"] = expression.decode("utf-8") if isinstance(expression, bytes) else expression
165+
return original_set_pref_expr(self, expression, inputs, fp_accuracy, aux_reduc, jit=jit)
166+
167+
monkeypatch.setattr(blosc2.NDArray, "_set_pref_expr", wrapped_set_pref_expr)
168+
169+
shape = (10, 10)
170+
x2 = blosc2.zeros(shape, dtype=np.float32)
171+
expr = blosc2.lazyudf(kernel_index_ramp, (x2,), dtype=np.float32)
172+
res = expr[:]
173+
174+
assert captured["calls"] >= 1
175+
assert "def kernel_index_ramp(x):" in captured["expr"]
176+
assert "_i0" in captured["expr"]
177+
assert "_n1" in captured["expr"]
178+
assert "_i1" in captured["expr"]
179+
assert res.shape == shape
180+
181+
147182
def test_dsl_kernel_full_control_flow_kept_as_dsl_function():
148183
assert kernel_control_flow_full.dsl_source is not None
149184
assert "def kernel_control_flow_full(x, y):" in kernel_control_flow_full.dsl_source

0 commit comments

Comments
 (0)