Skip to content

Commit e039f41

Browse files
committed
New benchmarks for transcendental functions
1 parent 0d874c6 commit e039f41

1 file changed

Lines changed: 67 additions & 6 deletions

File tree

bench/ndarray/dsl-kernel-bench.py

Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#######################################################################
77

88
import contextlib
9+
import argparse
910
import time
1011

1112
import numpy as np
@@ -111,6 +112,26 @@ def expr_nested2() -> str:
111112
return " + ".join(terms)
112113

113114

115+
def expr_transcendentals() -> str:
116+
return "log(exp(x) + tanh(x) + log1p(abs(x)) + sqrt(abs(x)) + expm1(x))"
117+
118+
119+
def expr_transcend1() -> str:
120+
return "log(exp(x))"
121+
122+
123+
def expr_transcend2() -> str:
124+
return "tanh(x)"
125+
126+
127+
def expr_transcend3() -> str:
128+
return "log1p(abs(x))"
129+
130+
131+
def expr_sincos_identity() -> str:
132+
return "sin(x) ** 2 + cos(x) ** 2"
133+
134+
114135
@contextlib.contextmanager
115136
def miniexpr_enabled(enabled: bool):
116137
old = lazyexpr_mod.try_miniexpr
@@ -213,6 +234,10 @@ def format_row(row):
213234

214235

215236
def main():
237+
parser = argparse.ArgumentParser()
238+
parser.add_argument("--transcend", action="store_true", help="Run only the transcendental lazyexpr cases")
239+
args = parser.parse_args()
240+
216241
n = 10_000
217242
dtype = np.float32
218243
cparams = blosc2.CParams(codec=blosc2.Codec.BLOSCLZ, clevel=1)
@@ -229,12 +254,48 @@ def main():
229254
("nested2", kernel_nested2, expr_nested2()),
230255
]
231256

232-
headers, fmt, sep = table_formatter()
233-
print(fmt.format(*headers), flush=True)
234-
print(sep, flush=True)
235-
for name, kernel, expr in cases:
236-
row = bench_case(name, kernel, expr, a, b, dtype, gb)
237-
print(fmt.format(*format_row(row)), flush=True)
257+
transcendental_cases = [
258+
("transcend1", expr_transcend1()),
259+
("transcend2", expr_transcend2()),
260+
("transcend3", expr_transcend3()),
261+
("transcend4", expr_transcendentals()),
262+
("sincos_id", expr_sincos_identity()),
263+
]
264+
265+
if not args.transcend:
266+
headers, fmt, sep = table_formatter()
267+
print(fmt.format(*headers), flush=True)
268+
print(sep, flush=True)
269+
for name, kernel, expr in cases:
270+
row = bench_case(name, kernel, expr, a, b, dtype, gb)
271+
print(fmt.format(*format_row(row)), flush=True)
272+
273+
if not args.transcend:
274+
print()
275+
print("Transcendental lazyexpr cases", flush=True)
276+
print("Case |Base ms|Base GB/s|Expr ms|Expr GB/s|Expr/Base", flush=True)
277+
print("------------+-------+---------+-------+---------+---------", flush=True)
278+
with miniexpr_enabled(False):
279+
for name, expr in transcendental_cases:
280+
lazy_expr_base = blosc2.lazyexpr(expr, {"x": a})
281+
res_base = lazy_expr_base.compute()
282+
base_time, _ = time_it(lambda: lazy_expr_base.compute())
283+
284+
with miniexpr_enabled(True):
285+
lazy_expr_fast = blosc2.lazyexpr(expr, {"x": a})
286+
res_fast = lazy_expr_fast.compute()
287+
expr_time, _ = time_it(lambda: lazy_expr_fast.compute())
288+
289+
np.testing.assert_allclose(res_fast[...], res_base[...], rtol=1e-5, atol=2e-6)
290+
print(
291+
f"{name:<12}|"
292+
f"{base_time * 1000:>7.2f}|"
293+
f"{(a.nbytes * 2 / 1e9) / base_time:>9.2f}|"
294+
f"{expr_time * 1000:>7.2f}|"
295+
f"{(a.nbytes * 2 / 1e9) / expr_time:>9.2f}|"
296+
f"{base_time / expr_time:>8.2f}x",
297+
flush=True,
298+
)
238299

239300

240301
if __name__ == "__main__":

0 commit comments

Comments
 (0)