@@ -55,6 +55,7 @@ def set_path_mode(mode: str) -> bool:
5555
5656
5757def run_case (
58+ label : str ,
5859 mode : str ,
5960 block_backend : str ,
6061 repeats : int ,
@@ -113,6 +114,7 @@ def wrapped_set_pref_matmul(self, inputs, fp_accuracy):
113114 best = min (times )
114115 median = statistics .median (times )
115116 return {
117+ "label" : label ,
116118 "mode" : mode ,
117119 "times_s" : times ,
118120 "best_s" : best ,
@@ -163,6 +165,7 @@ def main() -> None:
163165 for mode in args .modes :
164166 results .append (
165167 run_case (
168+ mode ,
166169 mode ,
167170 args .block_backend ,
168171 args .repeats ,
@@ -178,6 +181,27 @@ def main() -> None:
178181 )
179182 )
180183
184+ if args .block_backend == "auto" and "fast" in args .modes :
185+ fast_naive = run_case (
186+ "fast-naive" ,
187+ "fast" ,
188+ "naive" ,
189+ args .repeats ,
190+ shape_a ,
191+ shape_b ,
192+ dtype ,
193+ chunks_a ,
194+ chunks_b ,
195+ blocks_a ,
196+ blocks_b ,
197+ chunks_out ,
198+ blocks_out ,
199+ )
200+ if fast_naive ["selected_block_backend" ] != next (
201+ item ["selected_block_backend" ] for item in results if item ["mode" ] == "fast"
202+ ):
203+ results .append (fast_naive )
204+
181205 summary = {
182206 "shape_a" : shape_a ,
183207 "shape_b" : shape_b ,
@@ -192,9 +216,13 @@ def main() -> None:
192216 "results" : results ,
193217 }
194218
195- best_by_mode = {item ["mode" ]: item ["best_s" ] for item in results }
196- if "chunked" in best_by_mode and "fast" in best_by_mode :
197- summary ["speedup_fast_vs_chunked" ] = best_by_mode ["chunked" ] / best_by_mode ["fast" ]
219+ best_by_label = {item ["label" ]: item ["best_s" ] for item in results }
220+ if "chunked" in best_by_label and "fast" in best_by_label :
221+ summary ["speedup_fast_vs_chunked" ] = best_by_label ["chunked" ] / best_by_label ["fast" ]
222+ if "chunked" in best_by_label and "fast-naive" in best_by_label :
223+ summary ["speedup_fast_naive_vs_chunked" ] = best_by_label ["chunked" ] / best_by_label ["fast-naive" ]
224+ if "fast" in best_by_label and "fast-naive" in best_by_label :
225+ summary ["speedup_fast_vs_fast_naive" ] = best_by_label ["fast-naive" ] / best_by_label ["fast" ]
198226
199227 if args .json :
200228 print (json .dumps (summary , indent = 2 , sort_keys = True ))
@@ -208,10 +236,13 @@ def main() -> None:
208236 print (f" blocks A/B/out: { blocks_a } / { blocks_b } / { blocks_out } " )
209237 print (f" repeats: { args .repeats } " )
210238 print (f" fast block backend: { args .block_backend } " )
211- for item in results :
239+ display_order = ["chunked" , "fast-naive" , "fast" , "auto" ]
240+ ordered_results = sorted (results , key = lambda item : display_order .index (item ["label" ]) if item ["label" ] in display_order else len (display_order ))
241+
242+ for item in ordered_results :
212243 gflops_best = "-" if item ["gflops_best" ] is None else f"{ item ['gflops_best' ]:.3f} "
213244 print (
214- f"{ item ['mode ' ]:>7 } : "
245+ f"{ item ['label ' ]:>10 } : "
215246 f"best={ item ['best_s' ]:.6f} s "
216247 f"median={ item ['median_s' ]:.6f} s "
217248 f"gflops={ gflops_best } "
@@ -221,6 +252,10 @@ def main() -> None:
221252 )
222253 if "speedup_fast_vs_chunked" in summary :
223254 print (f"Speedup fast vs chunked: { summary ['speedup_fast_vs_chunked' ]:.3f} x" )
255+ if "speedup_fast_naive_vs_chunked" in summary :
256+ print (f"Speedup fast-naive vs chunked: { summary ['speedup_fast_naive_vs_chunked' ]:.3f} x" )
257+ if "speedup_fast_vs_fast_naive" in summary :
258+ print (f"Speedup fast vs fast-naive: { summary ['speedup_fast_vs_fast_naive' ]:.3f} x" )
224259
225260
226261if __name__ == "__main__" :
0 commit comments