1212from blosc2 .dsl_kernel import DSLSyntaxError
1313from blosc2 .lazyexpr import _apply_jit_backend_pragma
1414
15+ where = np .where
16+ clip = np .clip
17+
1518
1619def _make_arrays (shape = (8 , 8 ), chunks = (4 , 4 ), blocks = (2 , 2 )):
1720 a = np .linspace (0 , 1 , num = np .prod (shape ), dtype = np .float32 ).reshape (shape )
@@ -34,9 +37,9 @@ def kernel_loop(x, y):
3437 acc = 0.0
3538 for i in range (2 ):
3639 if i % 2 == 0 :
37- tmp = np . where (x < y , y + i , x - i )
40+ tmp = where (x < y , y + i , x - i )
3841 else :
39- tmp = np . where (x > y , x + i , y - i )
42+ tmp = where (x > y , x + i , y - i )
4043 acc = acc + tmp * (i + 1 )
4144 return acc
4245
@@ -68,7 +71,7 @@ def kernel_control_flow_full(x, y):
6871 if i == 1 :
6972 acc = acc - y
7073 else :
71- acc = np . where (acc < y , acc + i , acc - i )
74+ acc = where (acc < y , acc + i , acc - i )
7275 if i == 3 :
7376 break
7477 return acc
@@ -79,7 +82,7 @@ def kernel_while_full(x, y):
7982 acc = x
8083 i = 0
8184 while i < 3 :
82- acc = np . where (acc < y , acc + 1 , acc - 1 )
85+ acc = where (acc < y , acc + 1 , acc - 1 )
8386 i = i + 1
8487 return acc
8588
@@ -89,7 +92,7 @@ def kernel_loop_param(x, y, niter):
8992 acc = x
9093 # loop count comes from scalar niter
9194 for _i in range (niter ):
92- acc = np . where (acc < y , acc + 1 , acc - 1 )
95+ acc = where (acc < y , acc + 1 , acc - 1 )
9396 return acc
9497
9598
@@ -101,7 +104,7 @@ def kernel_scalar_float_cast(x, niter):
101104
102105@blosc2 .dsl_kernel
103106def kernel_fallback_kw_call (x , y ):
104- return np . clip (x + y , a_min = 0.5 , a_max = 2.5 )
107+ return clip (x + y , a_min = 0.5 , a_max = 2.5 )
105108
106109
107110@blosc2 .dsl_kernel
@@ -135,6 +138,16 @@ def kernel_index_ramp_float_cast(x):
135138 return float (_i0 ) * _n1 + _i1 # noqa: F821 # DSL index/shape symbols resolved by miniexpr
136139
137140
141+ @blosc2 .dsl_kernel
142+ def kernel_index_ramp_int_cast (x ):
143+ return int (_i0 * _n1 + _i1 ) # noqa: F821 # DSL index/shape symbols resolved by miniexpr
144+
145+
146+ @blosc2 .dsl_kernel
147+ def kernel_bool_cast_numeric (x ):
148+ return bool (x )
149+
150+
138151@blosc2 .dsl_kernel
139152def kernel_index_ramp_no_inputs ():
140153 return _i0 * _n1 + _i1 # noqa: F821 # DSL index/shape symbols resolved by miniexpr
@@ -255,14 +268,60 @@ def test_dsl_kernel_index_symbols_float_cast_matches_expected_ramp():
255268 np .testing .assert_allclose (res , expected , rtol = 0.0 , atol = 0.0 )
256269
257270
271+ def test_dsl_kernel_index_symbols_float_cast_uses_miniexpr_fast_path (monkeypatch ):
272+ if blosc2 .IS_WASM :
273+ pytest .skip ("miniexpr fast path is not available on WASM" )
274+
275+ original_set_pref_expr = blosc2 .NDArray ._set_pref_expr
276+ captured = {"calls" : 0 , "expr" : None }
277+
278+ def wrapped_set_pref_expr (self , expression , inputs , fp_accuracy , aux_reduc = None , jit = None ):
279+ captured ["calls" ] += 1
280+ captured ["expr" ] = expression .decode ("utf-8" ) if isinstance (expression , bytes ) else expression
281+ return original_set_pref_expr (self , expression , inputs , fp_accuracy , aux_reduc , jit = jit )
282+
283+ monkeypatch .setattr (blosc2 .NDArray , "_set_pref_expr" , wrapped_set_pref_expr )
284+
285+ shape = (16 , 9 )
286+ x2 = blosc2 .zeros (shape , dtype = np .float32 )
287+ expr = blosc2 .lazyudf (kernel_index_ramp_float_cast , (x2 ,), dtype = np .float32 )
288+ res = expr [:]
289+ expected = np .arange (np .prod (shape ), dtype = np .float32 ).reshape (shape )
290+
291+ np .testing .assert_allclose (res , expected , rtol = 0.0 , atol = 0.0 )
292+ assert captured ["calls" ] >= 1
293+ assert "def kernel_index_ramp_float_cast(x):" in captured ["expr" ]
294+ assert "float(_i0)" in captured ["expr" ]
295+ assert "_n1" in captured ["expr" ]
296+ assert "_i1" in captured ["expr" ]
297+
298+
299+ def test_dsl_kernel_index_symbols_int_cast_matches_expected_ramp ():
300+ shape = (32 , 5 )
301+ x2 = blosc2 .zeros (shape , dtype = np .float32 )
302+ expr = blosc2 .lazyudf (kernel_index_ramp_int_cast , (x2 ,), dtype = np .int64 )
303+ res = expr [:]
304+ expected = np .arange (np .prod (shape ), dtype = np .int64 ).reshape (shape )
305+ np .testing .assert_equal (res , expected )
306+
307+
308+ def test_dsl_kernel_bool_cast_numeric_matches_expected ():
309+ x = np .array ([[0.0 , 1.0 , - 2.0 ], [3.5 , 0.0 , - 0.1 ]], dtype = np .float32 )
310+ x2 = blosc2 .asarray (x , chunks = (2 , 3 ), blocks = (1 , 2 ))
311+ expr = blosc2 .lazyudf (kernel_bool_cast_numeric , (x2 ,), dtype = np .bool_ )
312+ res = expr [:]
313+ expected = x != 0.0
314+ np .testing .assert_equal (res , expected )
315+
316+
258317def test_dsl_kernel_full_control_flow_kept_as_dsl_function ():
259318 assert kernel_control_flow_full .dsl_source is not None
260319 assert "def kernel_control_flow_full(x, y):" in kernel_control_flow_full .dsl_source
261320 assert "for i in range(4):" in kernel_control_flow_full .dsl_source
262321 assert "if i == 1:" in kernel_control_flow_full .dsl_source
263322 assert "continue" in kernel_control_flow_full .dsl_source
264323 assert "break" in kernel_control_flow_full .dsl_source
265- assert "np. where(" in kernel_control_flow_full .dsl_source
324+ assert "where(" in kernel_control_flow_full .dsl_source
266325
267326 a , b , a2 , b2 = _make_arrays ()
268327 expr = blosc2 .lazyudf (
0 commit comments