Skip to content

Commit dda315d

Browse files
Merge pull request #574 from Blosc/dsl-kernel
This allows using DSL kernels in miniexpr directly from `blosc2.lazyudf`. As miniexpr supports JIT compilation of DSLs, these can perform extraordinarily well inside prefilters in Blosc2 pipelines. For a demonstration, see e.g. `examples/ndarray/mandelbrot-dsl.ipynb`.
2 parents ff5ea66 + dc06083 commit dda315d

12 files changed

Lines changed: 2594 additions & 33 deletions

File tree

CMakeLists.txt

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,16 @@ set(MINIEXPR_BUILD_TESTS OFF CACHE BOOL "Build miniexpr tests" FORCE)
5656
set(MINIEXPR_BUILD_EXAMPLES OFF CACHE BOOL "Build miniexpr examples" FORCE)
5757
set(MINIEXPR_BUILD_BENCH OFF CACHE BOOL "Build miniexpr benchmarks" FORCE)
5858

59+
if(EMSCRIPTEN)
  # JIT in miniexpr for wasm32 exists already, but we need to do some work before we can use it
  # See plans/external-js-glue.md for details
  # NOTE: FORCE matches the convention used for the other MINIEXPR_* cache
  # options above — the subproject's defaults must not win over ours.
  set(MINIEXPR_ENABLE_TCC_JIT OFF CACHE BOOL "TCC JIT unavailable in Emscripten side-module builds" FORCE)
endif()

# Fetch miniexpr pinned to an exact commit so builds are reproducible.
FetchContent_Declare(miniexpr
  GIT_REPOSITORY https://github.com/Blosc/miniexpr.git
  GIT_TAG 4bf12c8d5eac4c2022db8567d7b3cee44a963c9c
  # Development convenience: point at a local checkout instead of the pinned tag.
  # SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../miniexpr
)
FetchContent_MakeAvailable(miniexpr)
6471

@@ -144,3 +151,33 @@ install(
144151
TARGETS blosc2_ext
145152
LIBRARY DESTINATION ${SKBUILD_PLATLIB_DIR}/blosc2
146153
)
154+
155+
# Install bundled libtcc next to the Python package (separate LGPL artifact).
# Every install below is OPTIONAL: if miniexpr was built without producing a
# given file (or a config subdirectory does not exist), it is silently skipped.
if(MINIEXPR_ENABLE_TCC_JIT)
  if(APPLE)
    install(
      FILES "${miniexpr_BINARY_DIR}/libtcc.dylib"
      DESTINATION ${SKBUILD_PLATLIB_DIR}/blosc2/lib
      OPTIONAL
    )
  elseif(WIN32)
    # Multi-config generators (e.g. Visual Studio) place tcc.dll in a
    # per-config subdirectory; single-config generators put it at the top
    # level. List every candidate and let OPTIONAL drop the absent ones.
    install(
      FILES
        "${miniexpr_BINARY_DIR}/tcc.dll"
        "${miniexpr_BINARY_DIR}/Debug/tcc.dll"
        "${miniexpr_BINARY_DIR}/Release/tcc.dll"
        "${miniexpr_BINARY_DIR}/RelWithDebInfo/tcc.dll"
        "${miniexpr_BINARY_DIR}/MinSizeRel/tcc.dll"
      DESTINATION ${SKBUILD_PLATLIB_DIR}/blosc2/lib
      OPTIONAL
    )
  else()
    # ELF platforms — presumably the plain name plus the versioned SONAME
    # file the runtime loader resolves; confirm against miniexpr's build.
    install(
      FILES
        "${miniexpr_BINARY_DIR}/libtcc.so"
        "${miniexpr_BINARY_DIR}/libtcc.so.1"
      DESTINATION ${SKBUILD_PLATLIB_DIR}/blosc2/lib
      OPTIONAL
    )
  endif()
endif()

bench/b2nd/jit-dsl.py

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
#######################################################################
2+
# Copyright (c) 2019-present, Blosc Development Team <blosc@blosc.org>
3+
# All rights reserved.
4+
#
5+
# SPDX-License-Identifier: BSD-3-Clause
6+
#######################################################################
7+
8+
from __future__ import annotations
9+
10+
import argparse
11+
import contextlib
12+
import os
13+
import shutil
14+
import statistics
15+
import tempfile
16+
import time
17+
18+
import blosc2
19+
import numpy as np
20+
21+
22+
# Small DSL kernel: two fixed iterations — first adds y, second applies a
# conditional np.where update. NOTE(review): the decorator suggests this
# source is consumed by the miniexpr DSL JIT, so the body is kept to simple
# constructs (while loop, if/else, np.where); comments only — no docstring,
# since a docstring would appear as an extra AST node to a source-level
# compiler. Confirm against blosc2.dsl_kernel docs before restructuring.
@blosc2.dsl_kernel
def k_dsl(x, y):
    acc = x
    i = 0
    while i < 2:
        if i == 0:
            acc = acc + y
        else:
            # Elementwise select: bump where acc < y, otherwise decrement.
            acc = np.where(acc < y, acc + i, acc - i)
        i = i + 1
    return acc
33+
34+
35+
# Transcendental-heavy DSL kernel: `niter` iterations of sin/cos/exp/log
# chains folded into acc via an elementwise np.where select. Designed to
# make per-element compute dominate so JIT gains show clearly.
# NOTE(review): body intentionally uses only DSL-friendly constructs
# (while loop + numpy ufuncs); no docstring — see k_dsl.
@blosc2.dsl_kernel
def k_heavy_dsl(x, y, niter):
    acc = x
    i = 0
    while i < niter:
        t = np.sin(acc * 1.001 + y * 0.123)
        u = np.cos(acc * 0.777 - y * 0.211)
        # +1.0 inside log keeps the argument strictly positive.
        v = np.exp(t * 0.25) - np.log(np.abs(u) + 1.0)
        p = np.sin(v * 0.731 + acc * 0.071)
        q = np.cos(v * 0.379 - y * 0.053)
        r = np.exp((p - q) * 0.17) - np.log(np.abs(p + q) + 1.0)
        w = np.sin((r + v) * 0.11) + np.cos((r - v) * 0.07)
        delta = v + r + w
        # Elementwise branch on the comparison with y.
        acc = np.where((acc < y), (acc + delta), (acc - delta))
        i = i + 1
    return acc
51+
52+
53+
# Arithmetic-only DSL kernel: `niter` iterations of a pure mul/add
# recurrence (no transcendentals, no branches) to stress loop code
# generation rather than math-library calls.
# NOTE(review): DSL-consumed source — comments only, no docstring.
@blosc2.dsl_kernel
def k_arith_loop_dsl(x, y, niter):
    acc = x
    i = 0
    while i < niter:
        # Arithmetic-only recurrence intended to stress loop codegen.
        a1 = acc * 0.913 + y * 0.087
        a2 = a1 * 0.731 + acc * 0.269
        a3 = a2 * 0.619 + a1 * 0.381
        a4 = a3 * 0.541 + a2 * 0.459
        a5 = a4 * 0.503 + a3 * 0.497
        # Heavy damping (0.97/0.03) keeps acc bounded; the tiny i-term
        # makes each iteration's result distinct.
        acc = (acc * 0.97) + (a5 * 0.03) + (i * 0.0000001)
        i = i + 1
    return acc
67+
68+
69+
# Mandelbrot-style DSL kernel: iterates z <- z^2 + c for max_iter steps over
# the complex plane given as separate real (cr) and imaginary (ci) parts,
# returning |z|^2 as a magnitude proxy (no escape-time early exit).
# NOTE(review): DSL-consumed source — comments only, no docstring.
@blosc2.dsl_kernel
def mandelbrot_dsl(cr, ci, max_iter):
    # z starts at 0, shaped like the inputs (x * 0.0 keeps array shape).
    zr = cr * 0.0
    zi = ci * 0.0
    i = 0
    while i < max_iter:
        # (zr + zi*j)^2 + (cr + ci*j), expanded into real/imag parts.
        zr2 = ((zr * zr) - (zi * zi)) + cr
        zi2 = (((zr * zi) * 2.0) + ci)
        zr = zr2
        zi = zi2
        i = i + 1
    # Mandelbrot-like iterate z <- z^2 + c (returns final magnitude proxy).
    return ((zr * zr) + (zi * zi))
82+
83+
84+
def _bench_cold_warm(fn, reps: int, warmup: int) -> tuple[float, float, float]:
85+
# First invocation: captures JIT compile/runtime setup cost when present.
86+
t0 = time.perf_counter()
87+
fn()
88+
cold = time.perf_counter() - t0
89+
90+
# Optional warmup happens after first call, so "cold" remains representative.
91+
for _ in range(warmup):
92+
fn()
93+
94+
times = []
95+
for _ in range(reps):
96+
t0 = time.perf_counter()
97+
fn()
98+
times.append(time.perf_counter() - t0)
99+
return cold, statistics.median(times), min(times)
100+
101+
102+
def _fmt(v: float) -> str:
103+
return f"{v:.6f}"
104+
105+
106+
@contextlib.contextmanager
107+
def _fresh_tmpdir(enabled: bool):
108+
if not enabled:
109+
yield
110+
return
111+
old_tmpdir = os.environ.get("TMPDIR")
112+
tmpdir = tempfile.mkdtemp(prefix="me-jit-bench-")
113+
os.environ["TMPDIR"] = tmpdir
114+
try:
115+
yield
116+
finally:
117+
if old_tmpdir is None:
118+
os.environ.pop("TMPDIR", None)
119+
else:
120+
os.environ["TMPDIR"] = old_tmpdir
121+
shutil.rmtree(tmpdir, ignore_errors=True)
122+
123+
124+
def main():
    """Benchmark JIT modes (auto/on/off) across expression, reduction and DSL workloads.

    Builds compressed linspace inputs, then times each workload under each
    JIT mode, reporting cold (first-call) time, warm median/best, and
    speedups relative to the jit=off baseline.
    """
    parser = argparse.ArgumentParser(description="Benchmark JIT modes for expressions, reductions and DSL kernels.")
    parser.add_argument("--n", type=int, default=100_000, help="Array length.")
    parser.add_argument("--reps", type=int, default=2, help="Measured repetitions per workload/mode.")
    parser.add_argument("--warmup", type=int, default=1, help="Warmup runs per workload/mode.")
    parser.add_argument("--dtype", default="float64", choices=("float32", "float64"), help="Input dtype.")
    parser.add_argument("--clevel", type=int, default=1, help="Compression level for input arrays.")
    parser.add_argument("--heavy-iters", type=int, default=16, help="Iterations for the heavy DSL kernel.")
    parser.add_argument("--arith-iters", type=int, default=512, help="Iterations for the arithmetic loop DSL kernel.")
    parser.add_argument("--mandelbrot-iters", type=int, default=50, help="Iterations for Mandelbrot DSL kernel.")
    parser.add_argument(
        "--compiler",
        default="auto",
        choices=("auto", "tcc", "cc"),
        help="JIT backend override: auto (default), tcc, or cc.",
    )
    parser.add_argument(
        "--fresh-cache",
        action="store_true",
        help="Use a fresh TMPDIR per workload/mode row so cold_s includes actual JIT build cost.",
    )
    parser.add_argument("--trace", action="store_true", help="Print reminder for ME_DSL_TRACE usage.")
    args = parser.parse_args()

    if args.trace:
        print("Tip: run with ME_DSL_TRACE=1 for backend/JIT diagnostics.")

    dtype = np.dtype(args.dtype)
    # None means "let blosc2 pick"; otherwise force the requested backend.
    jit_backend = None if args.compiler == "auto" else args.compiler
    cparams = blosc2.CParams(clevel=args.clevel, codec=blosc2.Codec.LZ4)

    print(f"Building inputs: n={args.n:,}, dtype={dtype}, clevel={args.clevel}")
    # Fix: `a` previously omitted cparams, so it silently ignored --clevel
    # while b/cr/ci honored it; all inputs now use the requested params,
    # matching the --clevel help text ("input arrays") and the line above.
    a = blosc2.linspace(0.0, 1.0, args.n, dtype=dtype, cparams=cparams)
    b = blosc2.linspace(1.0, 2.0, args.n, dtype=dtype, cparams=cparams)
    # cr/ci span the classic Mandelbrot window of the complex plane.
    cr = blosc2.linspace(-2.0, 1.0, args.n, dtype=dtype, cparams=cparams)
    ci = blosc2.linspace(-1.5, 1.5, args.n, dtype=dtype, cparams=cparams)

    modes = [("auto", None), ("on", True), ("off", False)]
    rows = []

    for mode_name, jit in modes:
        # Workloads are rebuilt inside the loop so each lambda closes over
        # this iteration's jit flag; every lambda is invoked eagerly below,
        # so there is no late-binding hazard.
        workloads = [
            ("compute_expr", lambda: blosc2.sin(a + 0.5).compute(jit=jit, jit_backend=jit_backend)),
            ("reduce_sum", lambda: blosc2.sin(a + 0.5).sum(jit=jit, jit_backend=jit_backend)),
            ("lazyudf_dsl", lambda: blosc2.lazyudf(k_dsl, (a, b), dtype=dtype, jit=jit, jit_backend=jit_backend).compute()),
            (
                "lazyudf_heavy",
                lambda: blosc2.lazyudf(
                    k_heavy_dsl, (a, b, args.heavy_iters), dtype=dtype, jit=jit, jit_backend=jit_backend
                ).compute(),
            ),
            (
                "udf_arith",
                lambda: blosc2.lazyudf(
                    k_arith_loop_dsl, (a, b, args.arith_iters), dtype=dtype, jit=jit, jit_backend=jit_backend
                ).compute(),
            ),
            (
                "mandelbrot_dsl",
                lambda: blosc2.lazyudf(
                    mandelbrot_dsl, (cr, ci, args.mandelbrot_iters), dtype=dtype, jit=jit, jit_backend=jit_backend
                ).compute(),
            ),
        ]
        for workload, fn in workloads:
            # Optionally isolate each row in a fresh TMPDIR so cold_s
            # reflects a true from-scratch JIT build.
            with _fresh_tmpdir(args.fresh_cache):
                cold, med, best = _bench_cold_warm(fn, args.reps, args.warmup)
                rows.append((workload, mode_name, cold, med, best))

    # The jit=off rows provide the baselines for the speedup columns.
    warm_baseline = {}
    cold_baseline = {}
    for workload, mode_name, cold, med, _best in rows:
        if mode_name == "off":
            warm_baseline[workload] = med
            cold_baseline[workload] = cold

    print(f"\nbackend: {args.compiler}")
    print("workload mode cold_s warm_med_s best_s warm_speedup cold_speedup")
    print("-----------------------------------------------------------------------------------")
    for workload, mode_name, cold, med, best in rows:
        warm_base = warm_baseline.get(workload)
        cold_base = cold_baseline.get(workload)
        # Falsy (missing or zero) baselines report a neutral 1.0x.
        warm_speedup = (warm_base / med) if warm_base else 1.0
        cold_speedup = (cold_base / cold) if cold_base else 1.0
        print(
            f"{workload:<14} {mode_name:<5} {_fmt(cold):>8} {_fmt(med):>8} {_fmt(best):>8} "
            f"{warm_speedup:>8.3f}x {cold_speedup:>8.3f}x"
        )
246+
247+
248+
# Allow running this benchmark directly as a script.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)