@@ -55,7 +55,10 @@ def set_path_mode(mode: str) -> bool:
5555
5656
5757def run_case (
58+ label : str ,
5859 mode : str ,
60+ block_backend : str ,
61+ warmup : int ,
5962 repeats : int ,
6063 shape_a : tuple [int , ...],
6164 shape_b : tuple [int , ...],
@@ -73,8 +76,10 @@ def run_case(
7376 warnings .simplefilter ("ignore" , RuntimeWarning )
7477 expected = np .matmul (a_np , b_np )
7578 original_flag = set_path_mode (mode )
79+ original_block_backend = blosc2 .blosc2_ext .get_matmul_block_backend ()
7680 original_set_pref_matmul = blosc2 .NDArray ._set_pref_matmul
7781 selected_paths = []
82+ selected_block_backend = None
7883 times = []
7984 result = None
8085
@@ -83,7 +88,17 @@ def wrapped_set_pref_matmul(self, inputs, fp_accuracy):
8388 return original_set_pref_matmul (self , inputs , fp_accuracy )
8489
8590 blosc2 .NDArray ._set_pref_matmul = wrapped_set_pref_matmul
91+ blosc2 .blosc2_ext .set_matmul_block_backend (block_backend )
8692 try :
93+ selected_block_backend = blosc2 .blosc2_ext .get_selected_matmul_block_backend ()
94+ for _ in range (warmup ):
95+ before = len (selected_paths )
96+ with warnings .catch_warnings ():
97+ # NumPy + Accelerate can emit spurious matmul RuntimeWarnings on macOS arm64.
98+ warnings .simplefilter ("ignore" , RuntimeWarning )
99+ result = blosc2 .matmul (a , b , chunks = chunks_out , blocks = blocks_out )
100+ if len (selected_paths ) == before :
101+ selected_paths .append ("chunked" )
87102 for _ in range (repeats ):
88103 before = len (selected_paths )
89104 t0 = time .perf_counter ()
@@ -97,6 +112,7 @@ def wrapped_set_pref_matmul(self, inputs, fp_accuracy):
97112 finally :
98113 blosc2 .NDArray ._set_pref_matmul = original_set_pref_matmul
99114 linalg .try_miniexpr = original_flag
115+ blosc2 .blosc2_ext .set_matmul_block_backend (original_block_backend )
100116
101117 if result is None :
102118 raise RuntimeError ("matmul did not produce a result" )
@@ -106,32 +122,90 @@ def wrapped_set_pref_matmul(self, inputs, fp_accuracy):
106122
107123 best = min (times )
108124 median = statistics .median (times )
125+ selected_path = selected_paths [0 ] if selected_paths and len (set (selected_paths )) == 1 else "mixed"
126+ reported_block_backend = selected_block_backend if selected_path != "chunked" else None
109127 return {
128+ "label" : label ,
110129 "mode" : mode ,
111130 "times_s" : times ,
112131 "best_s" : best ,
113132 "median_s" : median ,
114133 "gflops_best" : expected_gflops (shape_a , shape_b , best ),
115134 "gflops_median" : expected_gflops (shape_a , shape_b , median ),
116135 "correct" : True ,
136+ "configured_block_backend" : block_backend ,
137+ "selected_block_backend" : reported_block_backend ,
117138 "selected_paths" : selected_paths ,
118- "selected_path" : selected_paths [0 ] if selected_paths and len (set (selected_paths )) == 1 else "mixed" ,
139+ "selected_path" : selected_path ,
140+ }
141+
142+
143+ def run_numpy_case (
144+ warmup : int ,
145+ repeats : int ,
146+ shape_a : tuple [int , ...],
147+ shape_b : tuple [int , ...],
148+ dtype : np .dtype ,
149+ chunks_a : tuple [int , ...] | None ,
150+ chunks_b : tuple [int , ...] | None ,
151+ blocks_a : tuple [int , ...] | None ,
152+ blocks_b : tuple [int , ...] | None ,
153+ ):
154+ _ , _ , a_np , b_np = build_arrays (shape_a , shape_b , dtype , chunks_a , chunks_b , blocks_a , blocks_b )
155+ times = []
156+ result = None
157+ for _ in range (warmup ):
158+ with warnings .catch_warnings ():
159+ warnings .simplefilter ("ignore" , RuntimeWarning )
160+ result = np .matmul (a_np , b_np )
161+ for _ in range (repeats ):
162+ t0 = time .perf_counter ()
163+ with warnings .catch_warnings ():
164+ warnings .simplefilter ("ignore" , RuntimeWarning )
165+ result = np .matmul (a_np , b_np )
166+ times .append (time .perf_counter () - t0 )
167+
168+ if result is None :
169+ raise RuntimeError ("numpy.matmul did not produce a result" )
170+
171+ best = min (times )
172+ median = statistics .median (times )
173+ return {
174+ "label" : "numpy" ,
175+ "mode" : "numpy" ,
176+ "times_s" : times ,
177+ "best_s" : best ,
178+ "median_s" : median ,
179+ "gflops_best" : expected_gflops (shape_a , shape_b , best ),
180+ "gflops_median" : expected_gflops (shape_a , shape_b , median ),
181+ "correct" : True ,
182+ "configured_block_backend" : None ,
183+ "selected_block_backend" : None ,
184+ "selected_paths" : ["numpy" ] * repeats ,
185+ "selected_path" : "numpy" ,
119186 }
120187
121188
122189def main () -> None :
123190 parser = argparse .ArgumentParser (description = "Compare chunked and fast blosc2.matmul paths." )
124- parser .add_argument ("--shape-a" , default = "400,400 " , help = "Comma-separated shape for A." )
125- parser .add_argument ("--shape-b" , default = "400,400 " , help = "Comma-separated shape for B." )
191+ parser .add_argument ("--shape-a" , default = "2000,2000 " , help = "Comma-separated shape for A." )
192+ parser .add_argument ("--shape-b" , default = "2000,2000 " , help = "Comma-separated shape for B." )
126193 parser .add_argument ("--dtype" , default = "float32" , choices = ["float32" , "float64" , "int32" , "int64" ])
127- parser .add_argument ("--chunks-a" , default = "200,200 " , help = "Comma-separated chunk shape for A." )
128- parser .add_argument ("--chunks-b" , default = "200,200 " , help = "Comma-separated chunk shape for B." )
194+ parser .add_argument ("--chunks-a" , default = "500,500 " , help = "Comma-separated chunk shape for A." )
195+ parser .add_argument ("--chunks-b" , default = "500,500 " , help = "Comma-separated chunk shape for B." )
129196 parser .add_argument ("--blocks-a" , default = "100,100" , help = "Comma-separated block shape for A." )
130197 parser .add_argument ("--blocks-b" , default = "100,100" , help = "Comma-separated block shape for B." )
131- parser .add_argument ("--chunks-out" , default = "200,200 " , help = "Comma-separated chunk shape for output." )
198+ parser .add_argument ("--chunks-out" , default = "500,500 " , help = "Comma-separated chunk shape for output." )
132199 parser .add_argument ("--blocks-out" , default = "100,100" , help = "Comma-separated block shape for output." )
133- parser .add_argument ("--repeats" , type = int , default = 250 )
200+ parser .add_argument ("--warmup" , type = int , default = 2 )
201+ parser .add_argument ("--repeats" , type = int , default = 3 )
134202 parser .add_argument ("--modes" , nargs = "+" , default = ["chunked" , "fast" , "auto" ], choices = ["chunked" , "fast" , "auto" ])
203+ parser .add_argument (
204+ "--block-backend" ,
205+ default = "auto" ,
206+ choices = ["auto" , "naive" , "accelerate" , "cblas" ],
207+ help = "Kernel backend for the fast matmul block path." ,
208+ )
135209 parser .add_argument ("--json" , action = "store_true" , help = "Emit full JSON instead of a compact text summary." )
136210 args = parser .parse_args ()
137211
@@ -145,11 +219,27 @@ def main() -> None:
145219 blocks_out = parse_int_tuple (args .blocks_out ) if args .blocks_out else None
146220 dtype = np .dtype (args .dtype )
147221
222+ print ("Matmul path comparison" )
223+ print (f" A shape: { shape_a } " )
224+ print (f" B shape: { shape_b } " )
225+ print (f" dtype: { dtype } " )
226+ print (f" chunks A/B/out: { chunks_a } / { chunks_b } / { chunks_out } " )
227+ print (f" blocks A/B/out: { blocks_a } / { blocks_b } / { blocks_out } " )
228+ print (f" warmup: { args .warmup } " )
229+ print (f" repeats: { args .repeats } " )
230+ print (f" fast block backend: { args .block_backend } " )
231+ print (f" matmul library: { blosc2 .get_matmul_library ()} " )
232+ print ()
233+ print ("Results:" )
234+
148235 results = []
149236 for mode in args .modes :
150237 results .append (
151238 run_case (
152239 mode ,
240+ mode ,
241+ args .block_backend ,
242+ args .warmup ,
153243 args .repeats ,
154244 shape_a ,
155245 shape_b ,
@@ -163,6 +253,42 @@ def main() -> None:
163253 )
164254 )
165255
256+ if args .block_backend == "auto" and "fast" in args .modes :
257+ fast_naive = run_case (
258+ "fast-naive" ,
259+ "fast" ,
260+ "naive" ,
261+ args .warmup ,
262+ args .repeats ,
263+ shape_a ,
264+ shape_b ,
265+ dtype ,
266+ chunks_a ,
267+ chunks_b ,
268+ blocks_a ,
269+ blocks_b ,
270+ chunks_out ,
271+ blocks_out ,
272+ )
273+ if fast_naive ["selected_block_backend" ] != next (
274+ item ["selected_block_backend" ] for item in results if item ["mode" ] == "fast"
275+ ):
276+ results .append (fast_naive )
277+
278+ results .append (
279+ run_numpy_case (
280+ args .warmup ,
281+ args .repeats ,
282+ shape_a ,
283+ shape_b ,
284+ dtype ,
285+ chunks_a ,
286+ chunks_b ,
287+ blocks_a ,
288+ blocks_b ,
289+ )
290+ )
291+
166292 summary = {
167293 "shape_a" : shape_a ,
168294 "shape_b" : shape_b ,
@@ -173,47 +299,55 @@ def main() -> None:
173299 "blocks_b" : blocks_b ,
174300 "chunks_out" : chunks_out ,
175301 "blocks_out" : blocks_out ,
302+ "block_backend" : args .block_backend ,
176303 "results" : results ,
177304 }
178305
179- best_by_mode = {item ["mode" ]: item ["best_s" ] for item in results }
180- if "chunked" in best_by_mode and "fast" in best_by_mode :
181- summary ["speedup_fast_vs_chunked" ] = best_by_mode ["chunked" ] / best_by_mode ["fast" ]
306+ best_by_label = {item ["label" ]: item ["best_s" ] for item in results }
307+ if "chunked" in best_by_label and "fast" in best_by_label :
308+ summary ["speedup_fast_vs_chunked" ] = best_by_label ["chunked" ] / best_by_label ["fast" ]
309+ if "chunked" in best_by_label and "fast-naive" in best_by_label :
310+ summary ["speedup_fast_naive_vs_chunked" ] = best_by_label ["chunked" ] / best_by_label ["fast-naive" ]
311+ if "fast" in best_by_label and "fast-naive" in best_by_label :
312+ summary ["speedup_fast_vs_fast_naive" ] = best_by_label ["fast-naive" ] / best_by_label ["fast" ]
313+ if "numpy" in best_by_label and "fast" in best_by_label :
314+ summary ["speedup_fast_vs_numpy" ] = best_by_label ["numpy" ] / best_by_label ["fast" ]
315+ if "numpy" in best_by_label and "auto" in best_by_label :
316+ summary ["speedup_auto_vs_numpy" ] = best_by_label ["numpy" ] / best_by_label ["auto" ]
182317
183318 if args .json :
184319 print (json .dumps (summary , indent = 2 , sort_keys = True ))
185320 return
186321
187- print (
188- "case" ,
189- json .dumps (
190- {
191- "shape_a" : shape_a ,
192- "shape_b" : shape_b ,
193- "dtype" : str (dtype ),
194- "chunks_out" : chunks_out ,
195- "blocks_out" : blocks_out ,
196- },
197- sort_keys = True ,
198- ),
199- )
200- for item in results :
322+ display_order = ["chunked" , "fast-naive" , "fast" , "auto" , "numpy" ]
323+ ordered_results = sorted (results , key = lambda item : display_order .index (item ["label" ]) if item ["label" ] in display_order else len (display_order ))
324+
325+ for item in ordered_results :
326+ gflops_best = "-" if item ["gflops_best" ] is None else f"{ item ['gflops_best' ]:.3f} "
327+ if item ["label" ] == "numpy" :
328+ backend_info = f"library={ blosc2 .get_matmul_library ()} "
329+ else :
330+ block_backend = item ["selected_block_backend" ] if item ["selected_block_backend" ] is not None else "-"
331+ backend_info = f"block_backend={ block_backend } "
201332 print (
202- "result" ,
203- json .dumps (
204- {
205- "mode" : item ["mode" ],
206- "best_s" : round (item ["best_s" ], 6 ),
207- "median_s" : round (item ["median_s" ], 6 ),
208- "gflops_best" : None if item ["gflops_best" ] is None else round (item ["gflops_best" ], 3 ),
209- "correct" : item ["correct" ],
210- "selected_path" : item ["selected_path" ],
211- },
212- sort_keys = True ,
213- ),
333+ f"{ item ['label' ]:>10} : "
334+ f"best={ item ['best_s' ]:.6f} s "
335+ f"median={ item ['median_s' ]:.6f} s "
336+ f"gflops={ gflops_best } "
337+ f"path={ item ['selected_path' ]} "
338+ f"{ backend_info } "
339+ f"correct={ item ['correct' ]} "
214340 )
215341 if "speedup_fast_vs_chunked" in summary :
216- print ("speedup" , json .dumps ({"fast_vs_chunked" : round (summary ["speedup_fast_vs_chunked" ], 3 )}, sort_keys = True ))
342+ print (f"Speedup fast vs chunked: { summary ['speedup_fast_vs_chunked' ]:.3f} x" )
343+ if "speedup_fast_naive_vs_chunked" in summary :
344+ print (f"Speedup fast-naive vs chunked: { summary ['speedup_fast_naive_vs_chunked' ]:.3f} x" )
345+ if "speedup_fast_vs_fast_naive" in summary :
346+ print (f"Speedup fast vs fast-naive: { summary ['speedup_fast_vs_fast_naive' ]:.3f} x" )
347+ if "speedup_fast_vs_numpy" in summary :
348+ print (f"Speedup fast vs numpy: { summary ['speedup_fast_vs_numpy' ]:.3f} x" )
349+ if "speedup_auto_vs_numpy" in summary :
350+ print (f"Speedup auto vs numpy: { summary ['speedup_auto_vs_numpy' ]:.3f} x" )
217351
218352
219353if __name__ == "__main__" :
0 commit comments