@@ -56,6 +56,7 @@ def set_path_mode(mode: str) -> bool:
5656
5757def run_case (
5858 mode : str ,
59+ block_backend : str ,
5960 repeats : int ,
6061 shape_a : tuple [int , ...],
6162 shape_b : tuple [int , ...],
@@ -73,8 +74,10 @@ def run_case(
7374 warnings .simplefilter ("ignore" , RuntimeWarning )
7475 expected = np .matmul (a_np , b_np )
7576 original_flag = set_path_mode (mode )
77+ original_block_backend = blosc2 .blosc2_ext .get_matmul_block_backend ()
7678 original_set_pref_matmul = blosc2 .NDArray ._set_pref_matmul
7779 selected_paths = []
80+ selected_block_backend = None
7881 times = []
7982 result = None
8083
@@ -83,7 +86,9 @@ def wrapped_set_pref_matmul(self, inputs, fp_accuracy):
8386 return original_set_pref_matmul (self , inputs , fp_accuracy )
8487
8588 blosc2 .NDArray ._set_pref_matmul = wrapped_set_pref_matmul
89+ blosc2 .blosc2_ext .set_matmul_block_backend (block_backend )
8690 try :
91+ selected_block_backend = blosc2 .blosc2_ext .get_selected_matmul_block_backend ()
8792 for _ in range (repeats ):
8893 before = len (selected_paths )
8994 t0 = time .perf_counter ()
@@ -97,6 +102,7 @@ def wrapped_set_pref_matmul(self, inputs, fp_accuracy):
97102 finally :
98103 blosc2 .NDArray ._set_pref_matmul = original_set_pref_matmul
99104 linalg .try_miniexpr = original_flag
105+ blosc2 .blosc2_ext .set_matmul_block_backend (original_block_backend )
100106
101107 if result is None :
102108 raise RuntimeError ("matmul did not produce a result" )
@@ -114,6 +120,8 @@ def wrapped_set_pref_matmul(self, inputs, fp_accuracy):
114120 "gflops_best" : expected_gflops (shape_a , shape_b , best ),
115121 "gflops_median" : expected_gflops (shape_a , shape_b , median ),
116122 "correct" : True ,
123+ "configured_block_backend" : block_backend ,
124+ "selected_block_backend" : selected_block_backend ,
117125 "selected_paths" : selected_paths ,
118126 "selected_path" : selected_paths [0 ] if selected_paths and len (set (selected_paths )) == 1 else "mixed" ,
119127 }
@@ -132,6 +140,12 @@ def main() -> None:
132140 parser .add_argument ("--blocks-out" , default = "100,100" , help = "Comma-separated block shape for output." )
133141 parser .add_argument ("--repeats" , type = int , default = 250 )
134142 parser .add_argument ("--modes" , nargs = "+" , default = ["chunked" , "fast" , "auto" ], choices = ["chunked" , "fast" , "auto" ])
143+ parser .add_argument (
144+ "--block-backend" ,
145+ default = "auto" ,
146+ choices = ["auto" , "naive" , "accelerate" ],
147+ help = "Kernel backend for the fast matmul block path." ,
148+ )
135149 parser .add_argument ("--json" , action = "store_true" , help = "Emit full JSON instead of a compact text summary." )
136150 args = parser .parse_args ()
137151
@@ -150,6 +164,7 @@ def main() -> None:
150164 results .append (
151165 run_case (
152166 mode ,
167+ args .block_backend ,
153168 args .repeats ,
154169 shape_a ,
155170 shape_b ,
@@ -173,6 +188,7 @@ def main() -> None:
173188 "blocks_b" : blocks_b ,
174189 "chunks_out" : chunks_out ,
175190 "blocks_out" : blocks_out ,
191+ "block_backend" : args .block_backend ,
176192 "results" : results ,
177193 }
178194
@@ -184,36 +200,27 @@ def main() -> None:
184200 print (json .dumps (summary , indent = 2 , sort_keys = True ))
185201 return
186202
187- print (
188- "case" ,
189- json .dumps (
190- {
191- "shape_a" : shape_a ,
192- "shape_b" : shape_b ,
193- "dtype" : str (dtype ),
194- "chunks_out" : chunks_out ,
195- "blocks_out" : blocks_out ,
196- },
197- sort_keys = True ,
198- ),
199- )
203+ print ("Matmul path comparison" )
204+ print (f" A shape: { shape_a } " )
205+ print (f" B shape: { shape_b } " )
206+ print (f" dtype: { dtype } " )
207+ print (f" chunks A/B/out: { chunks_a } / { chunks_b } / { chunks_out } " )
208+ print (f" blocks A/B/out: { blocks_a } / { blocks_b } / { blocks_out } " )
209+ print (f" repeats: { args .repeats } " )
210+ print (f" fast block backend: { args .block_backend } " )
200211 for item in results :
212+ gflops_best = "-" if item ["gflops_best" ] is None else f"{ item ['gflops_best' ]:.3f} "
201213 print (
202- "result" ,
203- json .dumps (
204- {
205- "mode" : item ["mode" ],
206- "best_s" : round (item ["best_s" ], 6 ),
207- "median_s" : round (item ["median_s" ], 6 ),
208- "gflops_best" : None if item ["gflops_best" ] is None else round (item ["gflops_best" ], 3 ),
209- "correct" : item ["correct" ],
210- "selected_path" : item ["selected_path" ],
211- },
212- sort_keys = True ,
213- ),
214+ f"{ item ['mode' ]:>7} : "
215+ f"best={ item ['best_s' ]:.6f} s "
216+ f"median={ item ['median_s' ]:.6f} s "
217+ f"gflops={ gflops_best } "
218+ f"path={ item ['selected_path' ]} "
219+ f"block_backend={ item ['selected_block_backend' ]} "
220+ f"correct={ item ['correct' ]} "
214221 )
215222 if "speedup_fast_vs_chunked" in summary :
216- print ("speedup" , json . dumps ({ "fast_vs_chunked" : round ( summary [" speedup_fast_vs_chunked" ], 3 )}, sort_keys = True ) )
223+ print (f"Speedup fast vs chunked: { summary [' speedup_fast_vs_chunked' ]:.3f } x" )
217224
218225
219226if __name__ == "__main__" :
0 commit comments