66#######################################################################
77
88import contextlib
9+ import argparse
910import time
1011
1112import numpy as np
@@ -111,6 +112,26 @@ def expr_nested2() -> str:
111112 return " + " .join (terms )
112113
113114
def expr_transcendentals() -> str:
    """Return the combined transcendental benchmark expression.

    The expression takes the ``log`` of a sum of five terms in the single
    variable ``x``: ``exp``, ``tanh``, ``log1p(abs(...))``,
    ``sqrt(abs(...))`` and ``expm1``.
    """
    combined_expression = "log(exp(x) + tanh(x) + log1p(abs(x)) + sqrt(abs(x)) + expm1(x))"
    return combined_expression
117+
118+
def expr_transcend1() -> str:
    """Return the expression string ``log(exp(x))`` in the variable ``x``."""
    log_of_exp = "log(exp(x))"
    return log_of_exp
121+
122+
def expr_transcend2() -> str:
    """Return the expression string ``tanh(x)`` in the variable ``x``."""
    tanh_expression = "tanh(x)"
    return tanh_expression
125+
126+
def expr_transcend3() -> str:
    """Return the expression string ``log1p(abs(x))`` in the variable ``x``."""
    log1p_expression = "log1p(abs(x))"
    return log1p_expression
129+
130+
def expr_sincos_identity() -> str:
    """Return the expression string for ``sin(x)`` squared plus ``cos(x)`` squared."""
    identity_expression = "sin(x) ** 2 + cos(x) ** 2"
    return identity_expression
133+
134+
114135@contextlib .contextmanager
115136def miniexpr_enabled (enabled : bool ):
116137 old = lazyexpr_mod .try_miniexpr
@@ -213,6 +234,10 @@ def format_row(row):
213234
214235
215236def main ():
237+ parser = argparse .ArgumentParser ()
238+ parser .add_argument ("--transcend" , action = "store_true" , help = "Run only the transcendental lazyexpr cases" )
239+ args = parser .parse_args ()
240+
216241 n = 10_000
217242 dtype = np .float32
218243 cparams = blosc2 .CParams (codec = blosc2 .Codec .BLOSCLZ , clevel = 1 )
@@ -229,12 +254,48 @@ def main():
229254 ("nested2" , kernel_nested2 , expr_nested2 ()),
230255 ]
231256
232- headers , fmt , sep = table_formatter ()
233- print (fmt .format (* headers ), flush = True )
234- print (sep , flush = True )
235- for name , kernel , expr in cases :
236- row = bench_case (name , kernel , expr , a , b , dtype , gb )
237- print (fmt .format (* format_row (row )), flush = True )
257+ transcendental_cases = [
258+ ("transcend1" , expr_transcend1 ()),
259+ ("transcend2" , expr_transcend2 ()),
260+ ("transcend3" , expr_transcend3 ()),
261+ ("transcend4" , expr_transcendentals ()),
262+ ("sincos_id" , expr_sincos_identity ()),
263+ ]
264+
265+ if not args .transcend :
266+ headers , fmt , sep = table_formatter ()
267+ print (fmt .format (* headers ), flush = True )
268+ print (sep , flush = True )
269+ for name , kernel , expr in cases :
270+ row = bench_case (name , kernel , expr , a , b , dtype , gb )
271+ print (fmt .format (* format_row (row )), flush = True )
272+
273+ if not args .transcend :
274+ print ()
275+ print ("Transcendental lazyexpr cases" , flush = True )
276+ print ("Case |Base ms|Base GB/s|Expr ms|Expr GB/s|Expr/Base" , flush = True )
277+ print ("------------+-------+---------+-------+---------+---------" , flush = True )
278+ with miniexpr_enabled (False ):
279+ for name , expr in transcendental_cases :
280+ lazy_expr_base = blosc2 .lazyexpr (expr , {"x" : a })
281+ res_base = lazy_expr_base .compute ()
282+ base_time , _ = time_it (lambda : lazy_expr_base .compute ())
283+
284+ with miniexpr_enabled (True ):
285+ lazy_expr_fast = blosc2 .lazyexpr (expr , {"x" : a })
286+ res_fast = lazy_expr_fast .compute ()
287+ expr_time , _ = time_it (lambda : lazy_expr_fast .compute ())
288+
289+ np .testing .assert_allclose (res_fast [...], res_base [...], rtol = 1e-5 , atol = 2e-6 )
290+ print (
291+ f"{ name :<12} |"
292+ f"{ base_time * 1000 :>7.2f} |"
293+ f"{ (a .nbytes * 2 / 1e9 ) / base_time :>9.2f} |"
294+ f"{ expr_time * 1000 :>7.2f} |"
295+ f"{ (a .nbytes * 2 / 1e9 ) / expr_time :>9.2f} |"
296+ f"{ base_time / expr_time :>8.2f} x" ,
297+ flush = True ,
298+ )
238299
239300
240301if __name__ == "__main__" :
0 commit comments