1616
# Handle on the blosc2.lazyexpr module; miniexpr_enabled() reads its
# `try_miniexpr` attribute.
lazyexpr_mod = importlib.import_module("blosc2.lazyexpr")

# Module-level aliases for NumPy functions so the DSL kernels in this file can
# call them by bare name (sin(x), log1p(x), ...).
where = np.where
sin = np.sin
cos = np.cos
tanh = np.tanh
sqrt = np.sqrt
exp = np.exp
expm1 = np.expm1
log = np.log
log1p = np.log1p
# NOTE: shadows the builtin abs() for the whole module; the kernels below call
# abs(x) and therefore resolve to np.abs.
abs = np.abs

# JIT settings forwarded to every blosc2.lazyudf() call in this benchmark via
# the `jit` and `jit_backend` keyword arguments.
DSL_JIT = True
# NOTE(review): presumably selects a TinyCC-based backend - confirm against the
# blosc2 documentation.
DSL_JIT_BACKEND = "tcc"
1931
2032
2133@blosc2 .dsl_kernel
@@ -132,6 +144,31 @@ def expr_sincos_identity() -> str:
132144 return "sin(x) ** 2 + cos(x) ** 2"
133145
134146
# DSL kernel for the "transcend1" case: returns log(exp(x)).
# It is listed next to expr_transcend1() in the transcendental case table, and
# bench_transcend_case() asserts that both produce matching results.
@blosc2.dsl_kernel
def kernel_transcend1(x):
    return log(exp(x))
150+
151+
# DSL kernel for the "transcend2" case: returns tanh(x).
# It is listed next to expr_transcend2() in the transcendental case table, and
# bench_transcend_case() asserts that both produce matching results.
@blosc2.dsl_kernel
def kernel_transcend2(x):
    return tanh(x)
155+
156+
# DSL kernel for the "transcend3" case: returns log1p(abs(x)).
# `abs` here is the module-level alias for np.abs, not the Python builtin.
# It is listed next to expr_transcend3() in the transcendental case table, and
# bench_transcend_case() asserts that both produce matching results.
@blosc2.dsl_kernel
def kernel_transcend3(x):
    return log1p(abs(x))
160+
161+
# DSL kernel for the "transcend4" case: the log of a sum of five terms,
# exp(x) + tanh(x) + log1p(abs(x)) + sqrt(abs(x)) + expm1(x).
# `abs` here is the module-level alias for np.abs, not the Python builtin.
# It is listed next to expr_transcendentals() in the transcendental case table,
# and bench_transcend_case() asserts that both produce matching results.
@blosc2.dsl_kernel
def kernel_transcend4(x):
    return log(exp(x) + tanh(x) + log1p(abs(x)) + sqrt(abs(x)) + expm1(x))
165+
166+
# DSL kernel for the "sincos_id" case: returns sin(x)**2 + cos(x)**2, the same
# formula as the string returned by expr_sincos_identity().
# bench_transcend_case() asserts that the kernel and the expression produce
# matching results.
@blosc2.dsl_kernel
def kernel_sincos_identity(x):
    return sin(x) ** 2 + cos(x) ** 2
170+
171+
135172@contextlib .contextmanager
136173def miniexpr_enabled (enabled : bool ):
137174 old = lazyexpr_mod .try_miniexpr
@@ -152,6 +189,37 @@ def time_it(fn, niter=3):
152189 return best , out
153190
154191
def bench_transcend_case(name, kernel, expr, a):
    """Benchmark one single-operand transcendental case three ways.

    The same computation is timed as:

    1. a ``blosc2.lazyexpr`` with miniexpr disabled (the baseline),
    2. the same ``blosc2.lazyexpr`` with miniexpr enabled, and
    3. the DSL ``kernel`` run through ``blosc2.lazyudf`` with the module-level
       ``DSL_JIT`` / ``DSL_JIT_BACKEND`` settings.

    Each variant is computed once before timing; that first result is what the
    correctness checks compare.

    Args:
        name: Label for the case, returned under the ``"case"`` key and used
            in error messages.
        kernel: Function decorated with ``blosc2.dsl_kernel`` that takes the
            single operand.
        expr: Expression string written in terms of the variable ``x``.
        a: Operand bound to ``x``; must expose ``nbytes`` and ``dtype``.

    Returns:
        A dict with the case name, the three timings as returned by
        ``time_it`` (``"baseline"``, ``"lazyexpr"``, ``"dsl"``) and the
        matching throughput figures (``"baseline_gbps"``, ``"lazyexpr_gbps"``,
        ``"dsl_gbps"``).

    Raises:
        RuntimeError: If ``kernel.dsl_source`` is ``None``.
        AssertionError: If the miniexpr or DSL result differs from the
            baseline beyond ``rtol=1e-5, atol=2e-6``.
    """
    # Same guard as bench_case(): without DSL source there is no DSL path to
    # measure, so fail loudly instead of reporting a misleading "dsl" timing.
    if kernel.dsl_source is None:
        raise RuntimeError(f"DSL extraction failed for {name}")

    # Throughput is normalised to twice the operand size.
    # NOTE(review): presumably one read of `a` plus one write of the result -
    # confirm this matches how the other benchmark tables compute GB/s.
    gb = a.nbytes * 2 / 1e9

    with miniexpr_enabled(False):
        lazy_expr_base = blosc2.lazyexpr(expr, {"x": a})
        res_base = lazy_expr_base.compute()
        base_time, _ = time_it(lambda: lazy_expr_base.compute())

    with miniexpr_enabled(True):
        lazy_expr_fast = blosc2.lazyexpr(expr, {"x": a})
        res_fast = lazy_expr_fast.compute()
        expr_time, _ = time_it(lambda: lazy_expr_fast.compute())

    lazy_dsl = blosc2.lazyudf(kernel, (a,), dtype=a.dtype, jit=DSL_JIT, jit_backend=DSL_JIT_BACKEND)
    res_dsl = lazy_dsl.compute()
    dsl_time, _ = time_it(lambda: lazy_dsl.compute())

    # Both accelerated paths must agree with the baseline before their
    # timings are reported.
    np.testing.assert_allclose(res_fast[...], res_base[...], rtol=1e-5, atol=2e-6)
    np.testing.assert_allclose(res_dsl[...], res_base[...], rtol=1e-5, atol=2e-6)

    return {
        "case": name,
        "baseline": base_time,
        "lazyexpr": expr_time,
        "dsl": dsl_time,
        "baseline_gbps": gb / base_time,
        "lazyexpr_gbps": gb / expr_time,
        "dsl_gbps": gb / dsl_time,
    }
221+
222+
155223def bench_case (name , kernel , expr , a , b , dtype , gb ):
156224 if kernel .dsl_source is None :
157225 raise RuntimeError (f"DSL extraction failed for { name } " )
@@ -166,7 +234,7 @@ def bench_case(name, kernel, expr, a, b, dtype, gb):
166234 _ = lazy_expr_fast .compute ()
167235 expr_time , _ = time_it (lambda : lazy_expr_fast .compute ())
168236
169- lazy_dsl = blosc2 .lazyudf (kernel , (a , b ), dtype = dtype )
237+ lazy_dsl = blosc2 .lazyudf (kernel , (a , b ), dtype = dtype , jit = DSL_JIT , jit_backend = DSL_JIT_BACKEND )
170238 res_dsl = lazy_dsl .compute ()
171239 dsl_time , _ = time_it (lambda : lazy_dsl .compute ())
172240
@@ -238,7 +306,7 @@ def main():
238306 parser .add_argument ("--transcend" , action = "store_true" , help = "Run only the transcendental lazyexpr cases" )
239307 args = parser .parse_args ()
240308
241- n = 10_000
309+ n = 1_000
242310 dtype = np .float32
243311 cparams = blosc2 .CParams (codec = blosc2 .Codec .BLOSCLZ , clevel = 1 )
244312
@@ -255,11 +323,11 @@ def main():
255323 ]
256324
257325 transcendental_cases = [
258- ("transcend1" , expr_transcend1 ()),
259- ("transcend2" , expr_transcend2 ()),
260- ("transcend3" , expr_transcend3 ()),
261- ("transcend4" , expr_transcendentals ()),
262- ("sincos_id" , expr_sincos_identity ()),
326+ ("transcend1" , kernel_transcend1 , expr_transcend1 ()),
327+ ("transcend2" , kernel_transcend2 , expr_transcend2 ()),
328+ ("transcend3" , kernel_transcend3 , expr_transcend3 ()),
329+ ("transcend4" , kernel_transcend4 , expr_transcendentals ()),
330+ ("sincos_id" , kernel_sincos_identity , expr_sincos_identity ()),
263331 ]
264332
265333 if not args .transcend :
@@ -273,29 +341,12 @@ def main():
273341 if not args .transcend :
274342 print ()
275343 print ("Transcendental lazyexpr cases" , flush = True )
276- print ("Case |Base ms|Base GB/s|Expr ms|Expr GB/s|Expr/Base" , flush = True )
277- print ("------------+-------+---------+-------+---------+---------" , flush = True )
278- with miniexpr_enabled (False ):
279- for name , expr in transcendental_cases :
280- lazy_expr_base = blosc2 .lazyexpr (expr , {"x" : a })
281- res_base = lazy_expr_base .compute ()
282- base_time , _ = time_it (lambda : lazy_expr_base .compute ())
283-
284- with miniexpr_enabled (True ):
285- lazy_expr_fast = blosc2 .lazyexpr (expr , {"x" : a })
286- res_fast = lazy_expr_fast .compute ()
287- expr_time , _ = time_it (lambda : lazy_expr_fast .compute ())
288-
289- np .testing .assert_allclose (res_fast [...], res_base [...], rtol = 1e-5 , atol = 2e-6 )
290- print (
291- f"{ name :<12} |"
292- f"{ base_time * 1000 :>7.2f} |"
293- f"{ (a .nbytes * 2 / 1e9 ) / base_time :>9.2f} |"
294- f"{ expr_time * 1000 :>7.2f} |"
295- f"{ (a .nbytes * 2 / 1e9 ) / expr_time :>9.2f} |"
296- f"{ base_time / expr_time :>8.2f} x" ,
297- flush = True ,
298- )
344+ headers , fmt , sep = table_formatter ()
345+ print (fmt .format (* headers ), flush = True )
346+ print (sep , flush = True )
347+ for name , kernel , expr in transcendental_cases :
348+ row = bench_transcend_case (name , kernel , expr , a )
349+ print (fmt .format (* format_row (row )), flush = True )
299350
300351
301352if __name__ == "__main__" :
0 commit comments