Skip to content

Commit 9edf21a

Browse files
committed
Update DSL benchmark, and use latest miniexpr
1 parent 5d64101 commit 9edf21a

2 files changed

Lines changed: 82 additions & 31 deletions

File tree

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ endif()
6363

6464
FetchContent_Declare(miniexpr
6565
GIT_REPOSITORY https://github.com/Blosc/miniexpr.git
66-
GIT_TAG 9478199e66db99c17402f8d1ba6f8912c234adc4
66+
GIT_TAG 257f92ace014f514e2e31bbd95729123e479996f
6767
# SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../miniexpr
6868
)
6969
FetchContent_MakeAvailable(miniexpr)

bench/ndarray/dsl-kernel-bench.py

Lines changed: 81 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,18 @@
1616

1717
lazyexpr_mod = importlib.import_module("blosc2.lazyexpr")
1818
where = np.where
19+
# NumPy functions bound to bare module-level names, so that the kernel
# functions in this file can call them unqualified (e.g. ``log(exp(x))``).
sin, cos, tanh = np.sin, np.cos, np.tanh
sqrt = np.sqrt
exp, expm1 = np.exp, np.expm1
log, log1p = np.log, np.log1p
# NOTE: this shadows the builtin ``abs`` for the rest of the module.
abs = np.abs

# Options forwarded to ``blosc2.lazyudf`` as ``jit=`` and ``jit_backend=``.
DSL_JIT = True
DSL_JIT_BACKEND = "tcc"
1931

2032

2133
@blosc2.dsl_kernel
@@ -132,6 +144,31 @@ def expr_sincos_identity() -> str:
132144
return "sin(x) ** 2 + cos(x) ** 2"
133145

134146

147+
# Kernel for the "transcend1" benchmark case; main() pairs it with
# expr_transcend1() and the results of both are compared numerically.
# NOTE(review): the expression is left verbatim since the decorator
# presumably derives the DSL from this function's source -- confirm.
@blosc2.dsl_kernel
def kernel_transcend1(x):
    return log(exp(x))
150+
151+
152+
# Kernel for the "transcend2" benchmark case; main() pairs it with
# expr_transcend2() and the results of both are compared numerically.
# NOTE(review): the expression is left verbatim since the decorator
# presumably derives the DSL from this function's source -- confirm.
@blosc2.dsl_kernel
def kernel_transcend2(x):
    return tanh(x)
155+
156+
157+
# Kernel for the "transcend3" benchmark case; main() pairs it with
# expr_transcend3() and the results of both are compared numerically.
# ``abs`` here resolves to the module-level alias of ``np.abs``.
# NOTE(review): the expression is left verbatim since the decorator
# presumably derives the DSL from this function's source -- confirm.
@blosc2.dsl_kernel
def kernel_transcend3(x):
    return log1p(abs(x))
160+
161+
162+
# Kernel for the "transcend4" benchmark case: the log of a sum of five
# elementwise terms. main() pairs it with expr_transcendentals() and the
# results of both are compared numerically.
# NOTE(review): the expression is left verbatim since the decorator
# presumably derives the DSL from this function's source -- confirm.
@blosc2.dsl_kernel
def kernel_transcend4(x):
    return log(exp(x) + tanh(x) + log1p(abs(x)) + sqrt(abs(x)) + expm1(x))
165+
166+
167+
# Kernel for the "sincos_id" benchmark case; it mirrors the string returned
# by expr_sincos_identity(), "sin(x) ** 2 + cos(x) ** 2".
# NOTE(review): the expression is left verbatim since the decorator
# presumably derives the DSL from this function's source -- confirm.
@blosc2.dsl_kernel
def kernel_sincos_identity(x):
    return sin(x) ** 2 + cos(x) ** 2
170+
171+
135172
@contextlib.contextmanager
136173
def miniexpr_enabled(enabled: bool):
137174
old = lazyexpr_mod.try_miniexpr
@@ -152,6 +189,37 @@ def time_it(fn, niter=3):
152189
return best, out
153190

154191

192+
def bench_transcend_case(name, kernel, expr, a):
    """Benchmark one single-operand case three ways and return its timings.

    The computation is run as (1) a lazy expression with miniexpr disabled
    (the baseline), (2) the same lazy expression with miniexpr enabled, and
    (3) a DSL kernel evaluated through ``blosc2.lazyudf``. The results of
    (2) and (3) are checked against the baseline before anything is returned.

    Parameters
    ----------
    name : str
        Label stored under ``"case"`` in the returned row.
    kernel
        Kernel object handed to ``blosc2.lazyudf``. Its ``dsl_source``
        attribute must not be ``None``.
    expr : str
        Expression passed to ``blosc2.lazyexpr`` with ``a`` bound to ``x``.
    a
        Input operand; must expose ``nbytes`` and ``dtype``.

    Returns
    -------
    dict
        Keys ``"case"``, ``"baseline"``, ``"lazyexpr"``, ``"dsl"`` (times as
        returned by ``time_it``) and ``"baseline_gbps"``, ``"lazyexpr_gbps"``,
        ``"dsl_gbps"`` (``gb`` divided by the matching time).

    Raises
    ------
    RuntimeError
        If ``kernel.dsl_source`` is ``None``.
    AssertionError
        If the miniexpr or DSL result differs from the baseline beyond
        ``rtol=1e-5, atol=2e-6``.
    """
    # Same guard as bench_case(): without it a kernel whose DSL extraction
    # failed would still be timed and reported in the "dsl" columns.
    if kernel.dsl_source is None:
        raise RuntimeError(f"DSL extraction failed for {name}")

    # The operand's bytes are counted twice -- presumably one read of the
    # input plus one write of the output; confirm against bench_case().
    gb = a.nbytes * 2 / 1e9

    with miniexpr_enabled(False):
        lazy_expr_base = blosc2.lazyexpr(expr, {"x": a})
        # Untimed first run; its result is the reference for the checks below.
        res_base = lazy_expr_base.compute()
        base_time, _ = time_it(lambda: lazy_expr_base.compute())

    with miniexpr_enabled(True):
        lazy_expr_fast = blosc2.lazyexpr(expr, {"x": a})
        res_fast = lazy_expr_fast.compute()
        expr_time, _ = time_it(lambda: lazy_expr_fast.compute())

    lazy_dsl = blosc2.lazyudf(
        kernel, (a,), dtype=a.dtype, jit=DSL_JIT, jit_backend=DSL_JIT_BACKEND
    )
    res_dsl = lazy_dsl.compute()
    dsl_time, _ = time_it(lambda: lazy_dsl.compute())

    # Both accelerated paths must agree with the baseline.
    np.testing.assert_allclose(res_fast[...], res_base[...], rtol=1e-5, atol=2e-6)
    np.testing.assert_allclose(res_dsl[...], res_base[...], rtol=1e-5, atol=2e-6)

    return {
        "case": name,
        "baseline": base_time,
        "lazyexpr": expr_time,
        "dsl": dsl_time,
        "baseline_gbps": gb / base_time,
        "lazyexpr_gbps": gb / expr_time,
        "dsl_gbps": gb / dsl_time,
    }
221+
222+
155223
def bench_case(name, kernel, expr, a, b, dtype, gb):
156224
if kernel.dsl_source is None:
157225
raise RuntimeError(f"DSL extraction failed for {name}")
@@ -166,7 +234,7 @@ def bench_case(name, kernel, expr, a, b, dtype, gb):
166234
_ = lazy_expr_fast.compute()
167235
expr_time, _ = time_it(lambda: lazy_expr_fast.compute())
168236

169-
lazy_dsl = blosc2.lazyudf(kernel, (a, b), dtype=dtype)
237+
lazy_dsl = blosc2.lazyudf(kernel, (a, b), dtype=dtype, jit=DSL_JIT, jit_backend=DSL_JIT_BACKEND)
170238
res_dsl = lazy_dsl.compute()
171239
dsl_time, _ = time_it(lambda: lazy_dsl.compute())
172240

@@ -238,7 +306,7 @@ def main():
238306
parser.add_argument("--transcend", action="store_true", help="Run only the transcendental lazyexpr cases")
239307
args = parser.parse_args()
240308

241-
n = 10_000
309+
n = 1_000
242310
dtype = np.float32
243311
cparams = blosc2.CParams(codec=blosc2.Codec.BLOSCLZ, clevel=1)
244312

@@ -255,11 +323,11 @@ def main():
255323
]
256324

257325
transcendental_cases = [
258-
("transcend1", expr_transcend1()),
259-
("transcend2", expr_transcend2()),
260-
("transcend3", expr_transcend3()),
261-
("transcend4", expr_transcendentals()),
262-
("sincos_id", expr_sincos_identity()),
326+
("transcend1", kernel_transcend1, expr_transcend1()),
327+
("transcend2", kernel_transcend2, expr_transcend2()),
328+
("transcend3", kernel_transcend3, expr_transcend3()),
329+
("transcend4", kernel_transcend4, expr_transcendentals()),
330+
("sincos_id", kernel_sincos_identity, expr_sincos_identity()),
263331
]
264332

265333
if not args.transcend:
@@ -273,29 +341,12 @@ def main():
273341
if not args.transcend:
274342
print()
275343
print("Transcendental lazyexpr cases", flush=True)
276-
print("Case |Base ms|Base GB/s|Expr ms|Expr GB/s|Expr/Base", flush=True)
277-
print("------------+-------+---------+-------+---------+---------", flush=True)
278-
with miniexpr_enabled(False):
279-
for name, expr in transcendental_cases:
280-
lazy_expr_base = blosc2.lazyexpr(expr, {"x": a})
281-
res_base = lazy_expr_base.compute()
282-
base_time, _ = time_it(lambda: lazy_expr_base.compute())
283-
284-
with miniexpr_enabled(True):
285-
lazy_expr_fast = blosc2.lazyexpr(expr, {"x": a})
286-
res_fast = lazy_expr_fast.compute()
287-
expr_time, _ = time_it(lambda: lazy_expr_fast.compute())
288-
289-
np.testing.assert_allclose(res_fast[...], res_base[...], rtol=1e-5, atol=2e-6)
290-
print(
291-
f"{name:<12}|"
292-
f"{base_time * 1000:>7.2f}|"
293-
f"{(a.nbytes * 2 / 1e9) / base_time:>9.2f}|"
294-
f"{expr_time * 1000:>7.2f}|"
295-
f"{(a.nbytes * 2 / 1e9) / expr_time:>9.2f}|"
296-
f"{base_time / expr_time:>8.2f}x",
297-
flush=True,
298-
)
344+
headers, fmt, sep = table_formatter()
345+
print(fmt.format(*headers), flush=True)
346+
print(sep, flush=True)
347+
for name, kernel, expr in transcendental_cases:
348+
row = bench_transcend_case(name, kernel, expr, a)
349+
print(fmt.format(*format_row(row)), flush=True)
299350

300351

301352
if __name__ == "__main__":

0 commit comments

Comments
 (0)