Skip to content

Commit 13063c2

Browse files
committed
Use Accelerate for fast matmul blocks on macOS
1 parent f2678e8 commit 13063c2

3 files changed

Lines changed: 111 additions & 28 deletions

File tree

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ add_custom_command(
4343

4444
# ...and add it to the target
4545
Python_add_library(blosc2_ext MODULE blosc2_ext.c WITH_SOABI)
46+
target_sources(blosc2_ext PRIVATE src/blosc2/matmul_kernels.c)
47+
target_include_directories(blosc2_ext PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src/blosc2)
4648

4749
# We need to link against NumPy
4850
target_link_libraries(blosc2_ext PRIVATE Python::NumPy)
@@ -70,6 +72,9 @@ FetchContent_MakeAvailable(miniexpr)
7072

7173
# Link against miniexpr static library
7274
target_link_libraries(blosc2_ext PRIVATE miniexpr_static)
75+
if(APPLE)
76+
target_link_libraries(blosc2_ext PRIVATE "-framework Accelerate")
77+
endif()
7378

7479
target_compile_features(blosc2_ext PRIVATE c_std_11)
7580
if(WIN32 AND CMAKE_C_COMPILER_ID STREQUAL "Clang")

bench/ndarray/matmul_path_compare.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def set_path_mode(mode: str) -> bool:
5656

5757
def run_case(
5858
mode: str,
59+
block_backend: str,
5960
repeats: int,
6061
shape_a: tuple[int, ...],
6162
shape_b: tuple[int, ...],
@@ -73,8 +74,10 @@ def run_case(
7374
warnings.simplefilter("ignore", RuntimeWarning)
7475
expected = np.matmul(a_np, b_np)
7576
original_flag = set_path_mode(mode)
77+
original_block_backend = blosc2.blosc2_ext.get_matmul_block_backend()
7678
original_set_pref_matmul = blosc2.NDArray._set_pref_matmul
7779
selected_paths = []
80+
selected_block_backend = None
7881
times = []
7982
result = None
8083

@@ -83,7 +86,9 @@ def wrapped_set_pref_matmul(self, inputs, fp_accuracy):
8386
return original_set_pref_matmul(self, inputs, fp_accuracy)
8487

8588
blosc2.NDArray._set_pref_matmul = wrapped_set_pref_matmul
89+
blosc2.blosc2_ext.set_matmul_block_backend(block_backend)
8690
try:
91+
selected_block_backend = blosc2.blosc2_ext.get_selected_matmul_block_backend()
8792
for _ in range(repeats):
8893
before = len(selected_paths)
8994
t0 = time.perf_counter()
@@ -97,6 +102,7 @@ def wrapped_set_pref_matmul(self, inputs, fp_accuracy):
97102
finally:
98103
blosc2.NDArray._set_pref_matmul = original_set_pref_matmul
99104
linalg.try_miniexpr = original_flag
105+
blosc2.blosc2_ext.set_matmul_block_backend(original_block_backend)
100106

101107
if result is None:
102108
raise RuntimeError("matmul did not produce a result")
@@ -114,6 +120,8 @@ def wrapped_set_pref_matmul(self, inputs, fp_accuracy):
114120
"gflops_best": expected_gflops(shape_a, shape_b, best),
115121
"gflops_median": expected_gflops(shape_a, shape_b, median),
116122
"correct": True,
123+
"configured_block_backend": block_backend,
124+
"selected_block_backend": selected_block_backend,
117125
"selected_paths": selected_paths,
118126
"selected_path": selected_paths[0] if selected_paths and len(set(selected_paths)) == 1 else "mixed",
119127
}
@@ -132,6 +140,12 @@ def main() -> None:
132140
parser.add_argument("--blocks-out", default="100,100", help="Comma-separated block shape for output.")
133141
parser.add_argument("--repeats", type=int, default=250)
134142
parser.add_argument("--modes", nargs="+", default=["chunked", "fast", "auto"], choices=["chunked", "fast", "auto"])
143+
parser.add_argument(
144+
"--block-backend",
145+
default="auto",
146+
choices=["auto", "naive", "accelerate"],
147+
help="Kernel backend for the fast matmul block path.",
148+
)
135149
parser.add_argument("--json", action="store_true", help="Emit full JSON instead of a compact text summary.")
136150
args = parser.parse_args()
137151

@@ -150,6 +164,7 @@ def main() -> None:
150164
results.append(
151165
run_case(
152166
mode,
167+
args.block_backend,
153168
args.repeats,
154169
shape_a,
155170
shape_b,
@@ -173,6 +188,7 @@ def main() -> None:
173188
"blocks_b": blocks_b,
174189
"chunks_out": chunks_out,
175190
"blocks_out": blocks_out,
191+
"block_backend": args.block_backend,
176192
"results": results,
177193
}
178194

@@ -184,36 +200,27 @@ def main() -> None:
184200
print(json.dumps(summary, indent=2, sort_keys=True))
185201
return
186202

187-
print(
188-
"case",
189-
json.dumps(
190-
{
191-
"shape_a": shape_a,
192-
"shape_b": shape_b,
193-
"dtype": str(dtype),
194-
"chunks_out": chunks_out,
195-
"blocks_out": blocks_out,
196-
},
197-
sort_keys=True,
198-
),
199-
)
203+
print("Matmul path comparison")
204+
print(f" A shape: {shape_a}")
205+
print(f" B shape: {shape_b}")
206+
print(f" dtype: {dtype}")
207+
print(f" chunks A/B/out: {chunks_a} / {chunks_b} / {chunks_out}")
208+
print(f" blocks A/B/out: {blocks_a} / {blocks_b} / {blocks_out}")
209+
print(f" repeats: {args.repeats}")
210+
print(f" fast block backend: {args.block_backend}")
200211
for item in results:
212+
gflops_best = "-" if item["gflops_best"] is None else f"{item['gflops_best']:.3f}"
201213
print(
202-
"result",
203-
json.dumps(
204-
{
205-
"mode": item["mode"],
206-
"best_s": round(item["best_s"], 6),
207-
"median_s": round(item["median_s"], 6),
208-
"gflops_best": None if item["gflops_best"] is None else round(item["gflops_best"], 3),
209-
"correct": item["correct"],
210-
"selected_path": item["selected_path"],
211-
},
212-
sort_keys=True,
213-
),
214+
f"{item['mode']:>7}: "
215+
f"best={item['best_s']:.6f}s "
216+
f"median={item['median_s']:.6f}s "
217+
f"gflops={gflops_best} "
218+
f"path={item['selected_path']} "
219+
f"block_backend={item['selected_block_backend']} "
220+
f"correct={item['correct']}"
214221
)
215222
if "speedup_fast_vs_chunked" in summary:
216-
print("speedup", json.dumps({"fast_vs_chunked": round(summary["speedup_fast_vs_chunked"], 3)}, sort_keys=True))
223+
print(f"Speedup fast vs chunked: {summary['speedup_fast_vs_chunked']:.3f}x")
217224

218225

219226
if __name__ == "__main__":

src/blosc2/blosc2_ext.pyx

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ import dataclasses
1212
import ast
1313
import atexit
1414
import pathlib
15+
import time
16+
import warnings
1517

1618
import _ctypes
1719

@@ -62,6 +64,21 @@ ctypedef fused T:
6264
cdef extern from "<stdio.h>":
6365
int printf(const char *format, ...) nogil
6466

67+
cdef extern from "matmul_kernels.h":
68+
ctypedef enum b2_matmul_backend:
69+
B2_MATMUL_BACKEND_AUTO
70+
B2_MATMUL_BACKEND_NAIVE
71+
B2_MATMUL_BACKEND_ACCELERATE
72+
73+
int b2_has_accelerate() nogil
74+
void b2_set_matmul_backend(int backend) nogil
75+
int b2_get_matmul_backend() nogil
76+
int b2_get_selected_matmul_backend() nogil
77+
const char *b2_get_matmul_backend_name() nogil
78+
const char *b2_get_selected_matmul_backend_name() nogil
79+
int b2_gemm_accelerate_f32(const float *a, const float *b, float *c, int m, int k, int n) nogil
80+
int b2_gemm_accelerate_f64(const double *a, const double *b, double *c, int m, int k, int n) nogil
81+
6582
cdef extern from "blosc2.h":
6683

6784
ctypedef enum:
@@ -2384,6 +2401,7 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
23842401
cdef int nchunk_ = nchunk
23852402
cdef int coord, batch, batch_, batches = 1
23862403
cdef int out_chunk_nrows, out_chunk_ncols, out_block_nrows, out_block_ncols
2404+
cdef int selected_backend = b2_get_selected_matmul_backend()
23872405

23882406
# batches = sum(strides[i]*elcoords[i])
23892407
for i in range(ndim - 2):
@@ -2487,9 +2505,43 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
24872505
offset += coord * udata.el_strides[0][i]
24882506
if typecode == 0:
24892507
if typesize == 4:
2490-
rc = matmul_block_kernel[float](<float*>input_buffers[0] + offsetA, <float*>input_buffers[1] + offsetB, <float*>params_output + offset, p, q, r)
2508+
if selected_backend == B2_MATMUL_BACKEND_ACCELERATE:
2509+
rc = b2_gemm_accelerate_f32(
2510+
<float*>input_buffers[0] + offsetA,
2511+
<float*>input_buffers[1] + offsetB,
2512+
<float*>params_output + offset,
2513+
p,
2514+
q,
2515+
r,
2516+
)
2517+
else:
2518+
rc = matmul_block_kernel[float](
2519+
<float*>input_buffers[0] + offsetA,
2520+
<float*>input_buffers[1] + offsetB,
2521+
<float*>params_output + offset,
2522+
p,
2523+
q,
2524+
r,
2525+
)
24912526
else:
2492-
rc = matmul_block_kernel[double](<double*>input_buffers[0] + offsetA, <double*>input_buffers[1] + offsetB, <double*>params_output + offset, p, q, r)
2527+
if selected_backend == B2_MATMUL_BACKEND_ACCELERATE:
2528+
rc = b2_gemm_accelerate_f64(
2529+
<double*>input_buffers[0] + offsetA,
2530+
<double*>input_buffers[1] + offsetB,
2531+
<double*>params_output + offset,
2532+
p,
2533+
q,
2534+
r,
2535+
)
2536+
else:
2537+
rc = matmul_block_kernel[double](
2538+
<double*>input_buffers[0] + offsetA,
2539+
<double*>input_buffers[1] + offsetB,
2540+
<double*>params_output + offset,
2541+
p,
2542+
q,
2543+
r,
2544+
)
24932545
elif typecode == 1:
24942546
if typesize == 4:
24952547
rc = matmul_block_kernel[int32_t](<int32_t*>input_buffers[0] + offsetA, <int32_t*>input_buffers[1] + offsetB, <int32_t*>params_output + offset, p, q, r)
@@ -3999,3 +4051,22 @@ def squeeze(arr1: NDArray, axis_mask: list[bool]) -> blosc2.NDArray:
39994051
new_base = arr1 if arr1.base is None else arr1.base
40004052
return blosc2.NDArray(_schunk=PyCapsule_New(view.sc, <char *> "blosc2_schunk*", NULL),
40014053
_array=PyCapsule_New(view, <char *> "b2nd_array_t*", NULL), _base=new_base)
4054+
4055+
4056+
def set_matmul_block_backend(mode):
4057+
if mode == "auto":
4058+
b2_set_matmul_backend(B2_MATMUL_BACKEND_AUTO)
4059+
elif mode == "naive":
4060+
b2_set_matmul_backend(B2_MATMUL_BACKEND_NAIVE)
4061+
elif mode == "accelerate":
4062+
b2_set_matmul_backend(B2_MATMUL_BACKEND_ACCELERATE)
4063+
else:
4064+
raise ValueError("mode must be 'auto', 'naive', or 'accelerate'")
4065+
4066+
4067+
def get_matmul_block_backend():
4068+
return b2_get_matmul_backend_name().decode("utf-8")
4069+
4070+
4071+
def get_selected_matmul_block_backend():
4072+
return b2_get_selected_matmul_backend_name().decode("utf-8")

0 commit comments

Comments
 (0)