Skip to content

Commit 203b4d4

Browse files
committed
Math functions should be called without the `np.` prefix in DSL kernels
1 parent d489fb2 commit 203b4d4

3 files changed

Lines changed: 93 additions & 27 deletions

File tree

bench/b2nd/jit-dsl.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@
1818
import blosc2
1919
import numpy as np
2020

21+
where = np.where
22+
sin = np.sin
23+
cos = np.cos
24+
exp = np.exp
25+
log = np.log
26+
2127

2228
@blosc2.dsl_kernel
2329
def k_dsl(x, y):
@@ -27,7 +33,7 @@ def k_dsl(x, y):
2733
if i == 0:
2834
acc = acc + y
2935
else:
30-
acc = np.where(acc < y, acc + i, acc - i)
36+
acc = where(acc < y, acc + i, acc - i)
3137
i = i + 1
3238
return acc
3339

@@ -37,15 +43,15 @@ def k_heavy_dsl(x, y, niter):
3743
acc = x
3844
i = 0
3945
while i < niter:
40-
t = np.sin(acc * 1.001 + y * 0.123)
41-
u = np.cos(acc * 0.777 - y * 0.211)
42-
v = np.exp(t * 0.25) - np.log(np.abs(u) + 1.0)
43-
p = np.sin(v * 0.731 + acc * 0.071)
44-
q = np.cos(v * 0.379 - y * 0.053)
45-
r = np.exp((p - q) * 0.17) - np.log(np.abs(p + q) + 1.0)
46-
w = np.sin((r + v) * 0.11) + np.cos((r - v) * 0.07)
46+
t = sin(acc * 1.001 + y * 0.123)
47+
u = cos(acc * 0.777 - y * 0.211)
48+
v = exp(t * 0.25) - log(abs(u) + 1.0)
49+
p = sin(v * 0.731 + acc * 0.071)
50+
q = cos(v * 0.379 - y * 0.053)
51+
r = exp((p - q) * 0.17) - log(abs(p + q) + 1.0)
52+
w = sin((r + v) * 0.11) + cos((r - v) * 0.07)
4753
delta = v + r + w
48-
acc = np.where((acc < y), (acc + delta), (acc - delta))
54+
acc = where((acc < y), (acc + delta), (acc - delta))
4955
i = i + 1
5056
return acc
5157

bench/ndarray/dsl-kernel-bench.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,17 @@
1414
import importlib
1515

1616
lazyexpr_mod = importlib.import_module("blosc2.lazyexpr")
17+
where = np.where
1718

1819

1920
@blosc2.dsl_kernel
2021
def kernel_loop1(x, y):
2122
acc = 0.0
2223
for i in range(1):
2324
if i % 2 == 0:
24-
tmp = np.where(x < y, y + i, x - i)
25+
tmp = where(x < y, y + i, x - i)
2526
else:
26-
tmp = np.where(x > y, x + i, y - i)
27+
tmp = where(x > y, x + i, y - i)
2728
acc = acc + tmp * (i + 1)
2829
return acc
2930

@@ -33,9 +34,9 @@ def kernel_loop2(x, y):
3334
acc = 0.0
3435
for i in range(2):
3536
if i % 2 == 0:
36-
tmp = np.where(x < y, y + i, x - i)
37+
tmp = where(x < y, y + i, x - i)
3738
else:
38-
tmp = np.where(x > y, x + i, y - i)
39+
tmp = where(x > y, x + i, y - i)
3940
acc = acc + tmp * (i + 1)
4041
return acc
4142

@@ -45,9 +46,9 @@ def kernel_loop4(x, y):
4546
acc = 0.0
4647
for i in range(4):
4748
if i % 2 == 0:
48-
tmp = np.where(x < y, y + i, x - i)
49+
tmp = where(x < y, y + i, x - i)
4950
else:
50-
tmp = np.where(x > y, x + i, y - i)
51+
tmp = where(x > y, x + i, y - i)
5152
acc = acc + tmp * (i + 1)
5253
return acc
5354

@@ -57,9 +58,9 @@ def kernel_loop4_heavy(x, y):
5758
acc = 0.0
5859
for i in range(4):
5960
if i % 2 == 0:
60-
tmp = np.where(x < y, y + i, x - i)
61+
tmp = where(x < y, y + i, x - i)
6162
else:
62-
tmp = np.where(x > y, x + i, y - i)
63+
tmp = where(x > y, x + i, y - i)
6364
acc = acc + tmp * (i + 1) + (tmp * tmp) * 0.05
6465
return acc
6566

@@ -70,9 +71,9 @@ def kernel_nested2(x, y):
7071
for i in range(2):
7172
for j in range(2):
7273
if (i + j) % 2 == 0:
73-
tmp = np.where(x < y, y + i + j, x - i - j)
74+
tmp = where(x < y, y + i + j, x - i - j)
7475
else:
75-
tmp = np.where(x > y, x + i + j, y - i - j)
76+
tmp = where(x > y, x + i + j, y - i - j)
7677
acc = acc + tmp * (i + j + 1)
7778
return acc
7879

@@ -148,7 +149,7 @@ def bench_case(name, kernel, expr, a, b, dtype, gb):
148149
res_dsl = lazy_dsl.compute()
149150
dsl_time, _ = time_it(lambda: lazy_dsl.compute())
150151

151-
np.testing.assert_allclose(res_dsl[...], res_base[...], rtol=1e-5, atol=1e-6)
152+
np.testing.assert_allclose(res_dsl[...], res_base[...], rtol=1e-5, atol=2e-6)
152153

153154
return {
154155
"case": name,

tests/ndarray/test_dsl_kernels.py

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
from blosc2.dsl_kernel import DSLSyntaxError
1313
from blosc2.lazyexpr import _apply_jit_backend_pragma
1414

15+
where = np.where
16+
clip = np.clip
17+
1518

1619
def _make_arrays(shape=(8, 8), chunks=(4, 4), blocks=(2, 2)):
1720
a = np.linspace(0, 1, num=np.prod(shape), dtype=np.float32).reshape(shape)
@@ -34,9 +37,9 @@ def kernel_loop(x, y):
3437
acc = 0.0
3538
for i in range(2):
3639
if i % 2 == 0:
37-
tmp = np.where(x < y, y + i, x - i)
40+
tmp = where(x < y, y + i, x - i)
3841
else:
39-
tmp = np.where(x > y, x + i, y - i)
42+
tmp = where(x > y, x + i, y - i)
4043
acc = acc + tmp * (i + 1)
4144
return acc
4245

@@ -68,7 +71,7 @@ def kernel_control_flow_full(x, y):
6871
if i == 1:
6972
acc = acc - y
7073
else:
71-
acc = np.where(acc < y, acc + i, acc - i)
74+
acc = where(acc < y, acc + i, acc - i)
7275
if i == 3:
7376
break
7477
return acc
@@ -79,7 +82,7 @@ def kernel_while_full(x, y):
7982
acc = x
8083
i = 0
8184
while i < 3:
82-
acc = np.where(acc < y, acc + 1, acc - 1)
85+
acc = where(acc < y, acc + 1, acc - 1)
8386
i = i + 1
8487
return acc
8588

@@ -89,7 +92,7 @@ def kernel_loop_param(x, y, niter):
8992
acc = x
9093
# loop count comes from scalar niter
9194
for _i in range(niter):
92-
acc = np.where(acc < y, acc + 1, acc - 1)
95+
acc = where(acc < y, acc + 1, acc - 1)
9396
return acc
9497

9598

@@ -101,7 +104,7 @@ def kernel_scalar_float_cast(x, niter):
101104

102105
@blosc2.dsl_kernel
103106
def kernel_fallback_kw_call(x, y):
104-
return np.clip(x + y, a_min=0.5, a_max=2.5)
107+
return clip(x + y, a_min=0.5, a_max=2.5)
105108

106109

107110
@blosc2.dsl_kernel
@@ -135,6 +138,16 @@ def kernel_index_ramp_float_cast(x):
135138
return float(_i0) * _n1 + _i1 # noqa: F821 # DSL index/shape symbols resolved by miniexpr
136139

137140

141+
@blosc2.dsl_kernel
142+
def kernel_index_ramp_int_cast(x):
143+
return int(_i0 * _n1 + _i1) # noqa: F821 # DSL index/shape symbols resolved by miniexpr
144+
145+
146+
@blosc2.dsl_kernel
147+
def kernel_bool_cast_numeric(x):
148+
return bool(x)
149+
150+
138151
@blosc2.dsl_kernel
139152
def kernel_index_ramp_no_inputs():
140153
return _i0 * _n1 + _i1 # noqa: F821 # DSL index/shape symbols resolved by miniexpr
@@ -255,14 +268,60 @@ def test_dsl_kernel_index_symbols_float_cast_matches_expected_ramp():
255268
np.testing.assert_allclose(res, expected, rtol=0.0, atol=0.0)
256269

257270

271+
def test_dsl_kernel_index_symbols_float_cast_uses_miniexpr_fast_path(monkeypatch):
272+
if blosc2.IS_WASM:
273+
pytest.skip("miniexpr fast path is not available on WASM")
274+
275+
original_set_pref_expr = blosc2.NDArray._set_pref_expr
276+
captured = {"calls": 0, "expr": None}
277+
278+
def wrapped_set_pref_expr(self, expression, inputs, fp_accuracy, aux_reduc=None, jit=None):
279+
captured["calls"] += 1
280+
captured["expr"] = expression.decode("utf-8") if isinstance(expression, bytes) else expression
281+
return original_set_pref_expr(self, expression, inputs, fp_accuracy, aux_reduc, jit=jit)
282+
283+
monkeypatch.setattr(blosc2.NDArray, "_set_pref_expr", wrapped_set_pref_expr)
284+
285+
shape = (16, 9)
286+
x2 = blosc2.zeros(shape, dtype=np.float32)
287+
expr = blosc2.lazyudf(kernel_index_ramp_float_cast, (x2,), dtype=np.float32)
288+
res = expr[:]
289+
expected = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
290+
291+
np.testing.assert_allclose(res, expected, rtol=0.0, atol=0.0)
292+
assert captured["calls"] >= 1
293+
assert "def kernel_index_ramp_float_cast(x):" in captured["expr"]
294+
assert "float(_i0)" in captured["expr"]
295+
assert "_n1" in captured["expr"]
296+
assert "_i1" in captured["expr"]
297+
298+
299+
def test_dsl_kernel_index_symbols_int_cast_matches_expected_ramp():
300+
shape = (32, 5)
301+
x2 = blosc2.zeros(shape, dtype=np.float32)
302+
expr = blosc2.lazyudf(kernel_index_ramp_int_cast, (x2,), dtype=np.int64)
303+
res = expr[:]
304+
expected = np.arange(np.prod(shape), dtype=np.int64).reshape(shape)
305+
np.testing.assert_equal(res, expected)
306+
307+
308+
def test_dsl_kernel_bool_cast_numeric_matches_expected():
309+
x = np.array([[0.0, 1.0, -2.0], [3.5, 0.0, -0.1]], dtype=np.float32)
310+
x2 = blosc2.asarray(x, chunks=(2, 3), blocks=(1, 2))
311+
expr = blosc2.lazyudf(kernel_bool_cast_numeric, (x2,), dtype=np.bool_)
312+
res = expr[:]
313+
expected = x != 0.0
314+
np.testing.assert_equal(res, expected)
315+
316+
258317
def test_dsl_kernel_full_control_flow_kept_as_dsl_function():
259318
assert kernel_control_flow_full.dsl_source is not None
260319
assert "def kernel_control_flow_full(x, y):" in kernel_control_flow_full.dsl_source
261320
assert "for i in range(4):" in kernel_control_flow_full.dsl_source
262321
assert "if i == 1:" in kernel_control_flow_full.dsl_source
263322
assert "continue" in kernel_control_flow_full.dsl_source
264323
assert "break" in kernel_control_flow_full.dsl_source
265-
assert "np.where(" in kernel_control_flow_full.dsl_source
324+
assert "where(" in kernel_control_flow_full.dsl_source
266325

267326
a, b, a2, b2 = _make_arrays()
268327
expr = blosc2.lazyudf(

0 commit comments

Comments
 (0)