Skip to content

Commit 46b0842

Browse files
committed
Limit matmul fast path to floating-point 2D cases and fall back to chunked path
1 parent 2ae5387 commit 46b0842

2 files changed

Lines changed: 49 additions & 14 deletions

File tree

src/blosc2/linalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ def _matmul_can_use_fast_path(
8888
if result_blocks[:-2] != result.blocks[:-2]:
8989
return False
9090

91-
if x1.dtype.kind not in ("i", "f"):
91+
if x1.dtype.kind != "f":
9292
return False
93-
if x2.dtype.kind not in ("i", "f"):
93+
if x2.dtype.kind != "f":
9494
return False
9595
return x1.dtype == x2.dtype
9696

tests/ndarray/test_linalg.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#######################################################################
77

88
import inspect
9+
import warnings
910
from itertools import permutations
1011

1112
import numpy as np
@@ -97,12 +98,40 @@ def wrapped_set_pref_matmul(self, inputs, fp_accuracy):
9798
monkeypatch.setattr(blosc2.NDArray, "_set_pref_matmul", wrapped_set_pref_matmul)
9899
try:
99100
_toggle_miniexpr(True)
100-
a = blosc2.ones(shape=(400, 400), dtype=np.int64, chunks=(200, 200), blocks=(100, 100))
101-
b = blosc2.full(shape=(400, 400), fill_value=2, dtype=np.int64, chunks=(200, 200), blocks=(100, 100))
101+
a = blosc2.ones(shape=(400, 400), dtype=np.float64, chunks=(200, 200), blocks=(100, 100))
102+
b = blosc2.full(
103+
shape=(400, 400), fill_value=2, dtype=np.float64, chunks=(200, 200), blocks=(100, 100)
104+
)
102105

103-
c = blosc2.matmul(a, b, chunks=(200, 200), blocks=(100, 100))
106+
with warnings.catch_warnings():
107+
warnings.simplefilter("ignore", RuntimeWarning)
108+
c = blosc2.matmul(a, b, chunks=(200, 200), blocks=(100, 100))
109+
expected = np.matmul(a[:], b[:])
104110

105111
assert calls == [((400, 400), (400, 400), (400, 400))]
112+
np.testing.assert_allclose(c[:], expected, rtol=1e-6, atol=1e-6)
113+
finally:
114+
_toggle_miniexpr(old_flag)
115+
116+
117+
def test_matmul_falls_back_for_integer_inputs(monkeypatch):
118+
old_flag = utils_mod.try_miniexpr
119+
calls = []
120+
original = blosc2.NDArray._set_pref_matmul
121+
122+
def wrapped_set_pref_matmul(self, inputs, fp_accuracy):
123+
calls.append((self.shape, inputs["x1"].shape, inputs["x2"].shape))
124+
return original(self, inputs, fp_accuracy)
125+
126+
monkeypatch.setattr(blosc2.NDArray, "_set_pref_matmul", wrapped_set_pref_matmul)
127+
try:
128+
_toggle_miniexpr(True)
129+
a = blosc2.ones(shape=(200, 200), dtype=np.int64, chunks=(100, 100), blocks=(50, 50))
130+
b = blosc2.full(shape=(200, 200), fill_value=2, dtype=np.int64, chunks=(100, 100), blocks=(50, 50))
131+
132+
c = blosc2.matmul(a, b, chunks=(100, 100), blocks=(50, 50))
133+
134+
assert calls == []
106135
np.testing.assert_allclose(c[:], np.matmul(a[:], b[:]), rtol=1e-6, atol=1e-6)
107136
finally:
108137
_toggle_miniexpr(old_flag)
@@ -120,15 +149,18 @@ def wrapped_set_pref_matmul(self, inputs, fp_accuracy):
120149
monkeypatch.setattr(blosc2.NDArray, "_set_pref_matmul", wrapped_set_pref_matmul)
121150
try:
122151
_toggle_miniexpr(True)
123-
a = blosc2.ones(shape=(2, 40, 40), dtype=np.int64, chunks=(1, 20, 20), blocks=(1, 10, 10))
152+
a = blosc2.ones(shape=(2, 40, 40), dtype=np.float64, chunks=(1, 20, 20), blocks=(1, 10, 10))
124153
b = blosc2.full(
125-
shape=(2, 40, 40), fill_value=2, dtype=np.int64, chunks=(1, 20, 20), blocks=(1, 10, 10)
154+
shape=(2, 40, 40), fill_value=2, dtype=np.float64, chunks=(1, 20, 20), blocks=(1, 10, 10)
126155
)
127156

128-
c = blosc2.matmul(a, b, chunks=(1, 20, 20), blocks=(1, 10, 10))
157+
with warnings.catch_warnings():
158+
warnings.simplefilter("ignore", RuntimeWarning)
159+
c = blosc2.matmul(a, b, chunks=(1, 20, 20), blocks=(1, 10, 10))
160+
expected = np.matmul(a[:], b[:])
129161

130162
assert calls == []
131-
np.testing.assert_allclose(c[:], np.matmul(a[:], b[:]), rtol=1e-6, atol=1e-6)
163+
np.testing.assert_allclose(c[:], expected, rtol=1e-6, atol=1e-6)
132164
finally:
133165
_toggle_miniexpr(old_flag)
134166

@@ -142,13 +174,16 @@ def failing_set_pref_matmul(self, inputs, fp_accuracy):
142174
monkeypatch.setattr(blosc2.NDArray, "_set_pref_matmul", failing_set_pref_matmul)
143175
try:
144176
_toggle_miniexpr(True)
145-
a = blosc2.ones(shape=(200, 200), dtype=np.int64, chunks=(100, 100), blocks=(50, 50))
146-
b = blosc2.full(shape=(200, 200), fill_value=2, dtype=np.int64, chunks=(100, 100), blocks=(50, 50))
177+
a = blosc2.ones(shape=(200, 200), dtype=np.float64, chunks=(100, 100), blocks=(50, 50))
178+
b = blosc2.full(shape=(200, 200), fill_value=2, dtype=np.float64, chunks=(100, 100), blocks=(50, 50))
147179

148-
with pytest.warns(RuntimeWarning, match="falling back to chunked path"):
149-
c = blosc2.matmul(a, b, chunks=(100, 100), blocks=(50, 50))
180+
with warnings.catch_warnings():
181+
warnings.filterwarnings("ignore", message=".*encountered in matmul", category=RuntimeWarning)
182+
with pytest.warns(RuntimeWarning, match="falling back to chunked path"):
183+
c = blosc2.matmul(a, b, chunks=(100, 100), blocks=(50, 50))
184+
expected = np.matmul(a[:], b[:])
150185

151-
np.testing.assert_allclose(c[:], np.matmul(a[:], b[:]), rtol=1e-6, atol=1e-6)
186+
np.testing.assert_allclose(c[:], expected, rtol=1e-6, atol=1e-6)
152187
finally:
153188
_toggle_miniexpr(old_flag)
154189

0 commit comments

Comments (0)