Skip to content

Commit 637ce8d

Browse files
committed
Extended allowable cases
1 parent 1d384cf commit 637ce8d

2 files changed

Lines changed: 44 additions & 38 deletions

File tree

src/blosc2/blosc2_ext.pyx

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2182,12 +2182,13 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
21822182
cdef int blocknitems[2]
21832183
cdef int startA, startB, expected_blocknitems
21842184
cdef blosc2_context* dctx
2185-
cdef int i, j, block_i, block_j, ncols, block_ncols, Bblock_ncols, Bncols
2185+
cdef int i, j, block_i, block_j, chunk_i, chunk_j, ncols, block_ncols, Bblock_ncols, Bncols, Ablock_ncols, Ancols
21862186
cdef int nchunkA = 0, nchunkB = 0, nblockA = 0, nblockB = 0, offsetA = 0, offsetB = 0, offset = 0
21872187
out_arr = udata.array
21882188
cdef int ndim = out_arr.ndim
21892189
cdef int nchunk_ = nchunk
21902190
cdef int coord, batch, batch_, batches = 1
2191+
cdef int out_chunk_nrows, out_chunk_ncols, out_block_nrows, out_block_ncols
21912192

21922193
# batches = sum(strides[i]*elcoords[i])
21932194
for i in range(ndim - 2):
@@ -2201,12 +2202,10 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
22012202
nchunkB += coord * udata.chunks_strides[2][i]
22022203

22032204
ncols = udata.chunks_strides[0][ndim - 2]
2205+
Ancols = udata.chunks_strides[1][ndim - 2]
22042206
Bncols = udata.chunks_strides[2][ndim - 2]
2205-
2206-
i = nchunk_ // ncols # ncols * i + j
2207-
j = nchunk_ % ncols
2208-
chunk_startA = nchunkA + i * ncols
2209-
chunk_startB = nchunkB + j
2207+
out_chunk_nrows = out_arr.chunkshape[ndim - 2]
2208+
out_chunk_ncols = out_arr.chunkshape[ndim - 1]
22102209

22112210
# nblock = sum(strides[i]*blockcoords[i])
22122211
cdef int nblock_ = nblock
@@ -2217,18 +2216,14 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
22172216
nblockB += coord * udata.blocks_strides[2][i]
22182217

22192218
block_ncols = udata.blocks_strides[0][ndim - 2]
2219+
Ablock_ncols = udata.blocks_strides[1][ndim - 2]
22202220
Bblock_ncols = udata.blocks_strides[2][ndim - 2]
2221-
2222-
block_i = nblock_ // block_ncols
2223-
block_j = nblock_ % block_ncols
2224-
block_startA = nblockA + block_i * block_ncols
2225-
block_startB = nblockB + block_j
2221+
out_block_nrows = out_arr.blockshape[ndim - 2]
2222+
out_block_ncols = out_arr.blockshape[ndim - 1]
22262223

22272224
dctx = blosc2_create_dctx(BLOSC2_DPARAMS_DEFAULTS)
22282225

22292226
first_run = True
2230-
nchunkA = chunk_startA
2231-
nchunkB = chunk_startB
22322227
while True: # chunk loop
22332228
for i in range(2):
22342229
chunk_idx = nchunkA if i == 0 else nchunkB
@@ -2244,16 +2239,28 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
22442239
if i == 0:
22452240
q = ndarr.blockshape[ndim - 1]
22462241
p = ndarr.blockshape[ndim - 2]
2242+
# nchunk_ = chunks_in_row * chunk_row + chunk_col
2243+
# convert from chunk_idx to element idx chunk_i (row)
2244+
chunk_i = nchunk_ // ncols * out_chunk_nrows
2245+
chunk_startA = nchunkA + chunk_i // ndarr.chunkshape[ndim - 2] * Ancols
2246+
nchunkA = chunk_startA
2247+
# nblock_ = blocks_in_chunkrow * block_row + block_col
2248+
# convert from block_idx to element idx block_i (row)
2249+
block_i = nblock_ // block_ncols * out_block_nrows
2250+
block_startA = nblockA + block_i // p * Ablock_ncols
22472251
else: # i = 1
22482252
r = ndarr.blockshape[ndim - 1]
2253+
# convert from chunk_idx to element idx chunk_j (col)
2254+
chunk_j = nchunk_ % ncols * out_chunk_ncols
2255+
chunk_startB = nchunkB + chunk_j // ndarr.chunkshape[ndim - 1]
2256+
nchunkB = chunk_startB
2257+
# convert from block_idx to element idx block_j (col)
2258+
block_j = nblock_ % block_ncols * out_block_ncols
2259+
block_startB = nblockB + block_j // r
22492260
input_buffers[i] = malloc(block_nbytes[i])
22502261
if input_buffers[i] == NULL:
22512262
raise MemoryError("miniexpr: cannot allocate input block buffer")
22522263
blocknitems[i] = block_nbytes[i] // <int> ndarr.sc.typesize
2253-
if i == 0:
2254-
expected_blocknitems = blocknitems[i]
2255-
elif blocknitems[i] != expected_blocknitems:
2256-
raise ValueError("miniexpr: inconsistent block element counts across inputs")
22572264

22582265
first_run = False
22592266
nblockA = block_startA
@@ -2297,11 +2304,11 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
22972304
batch += 1
22982305
nblockA += 1
22992306
nblockB += Bblock_ncols
2300-
if (nblockA % block_ncols == 0):
2307+
if (nblockA % Ablock_ncols == 0):
23012308
break
23022309
nchunkA += 1
23032310
nchunkB += Bncols
2304-
if (nchunkA % ncols == 0):
2311+
if (nchunkA % Ancols == 0):
23052312
break
23062313

23072314

@@ -3280,7 +3287,7 @@ cdef class NDArray:
32803287
cstrides = bstrides = estrides = 1
32813288
for idx in range(2, self.array.ndim + 1):
32823289
i = inp.ndim - idx
3283-
if inp.shape[i + 1] == 1 or i < 0:
3290+
if (inp.shape[i + 1] == 1 and i < inp.ndim - 3) or i < 0:
32843291
udata.chunks_strides[j][i] = 0
32853292
udata.blocks_strides[j][i] = 0
32863293
udata.el_strides[j][i] = 0

src/blosc2/linalg.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -125,14 +125,13 @@ def matmul(x1: blosc2.Array, x2: blosc2.NDArray, **kwargs: Any) -> blosc2.NDArra
125125
if any(op.dtype != ops[0].dtype for op in ops): # TODO: Remove this condition
126126
use_miniexpr = False
127127

128+
# TODO: We can relax this to even just load according to result blockshape, but that's difficult.
128129
# Just force same chunk/block shapes
129-
same_chunks = all(op.chunks == result.chunks for op in (x1, x2))
130-
same_blocks = all(op.blocks == result.blocks for op in (x1, x2))
131-
same_shape = all(op.shape == result.shape for op in (x1, x2))
132-
133-
use_miniexpr &= same_blocks & same_chunks & same_shape
130+
# same_chunks = all(op.chunks == result.chunks for op in (x1, x2))
131+
# same_blocks = all(op.blocks == result.blocks for op in (x1, x2))
132+
# same_shape = all(op.shape == result.shape for op in (x1, x2))
134133

135-
# TODO: We can relax this to even just load according to result blockshape, but that's difficult.
134+
# use_miniexpr &= same_blocks & same_chunks & same_shape
136135
# Two easier cases are presented below
137136
# Case 1: Might want to restrict loading across chunk boundaries, in which case would require:
138137
# x1.chunks[-2] % result.blocks[-2] == 0
@@ -146,18 +145,18 @@ def matmul(x1: blosc2.Array, x2: blosc2.NDArray, **kwargs: Any) -> blosc2.NDArra
146145
# (M, K) x (K, N) = (M, N)
147146
# so can load block-by-block for inputs and calculate block of output
148147
# Also need to avoid loading across chunk boundaries
149-
# chunks_aligned = x1.chunks[-2] % x1.blocks[-2] == 0
150-
# chunks_aligned &= x2.chunks[-1] % x2.blocks[-1] == 0
151-
# chunks_aligned &= x2.chunks[-2] % x1.blocks[-1] == 0
152-
# same_blocks = x2.blocks[-2] == x1.blocks[-1]
153-
# same_blocks &= x2.blocks[-1] == result.blocks[-1]
154-
# same_blocks &= result.blocks[-2] == x1.blocks[-2]
155-
# try:
156-
# result_blocks = np.broadcast_shapes(x1.blocks, x2.blocks)
157-
# if not (same_blocks and chunks_aligned and result_blocks[:-2] == result.blocks[:-2]):
158-
# use_miniexpr = False
159-
# except ValueError:
160-
# use_miniexpr = False
148+
chunks_aligned = x1.chunks[-2] % x1.blocks[-2] == 0
149+
chunks_aligned &= x2.chunks[-1] % x2.blocks[-1] == 0
150+
chunks_aligned &= x2.chunks[-2] % x1.blocks[-1] == 0
151+
same_blocks = x2.blocks[-2] == x1.blocks[-1]
152+
same_blocks &= x2.blocks[-1] == result.blocks[-1]
153+
same_blocks &= result.blocks[-2] == x1.blocks[-2]
154+
try:
155+
result_blocks = np.broadcast_shapes(x1.blocks, x2.blocks)
156+
if not (same_blocks and chunks_aligned and result_blocks[:-2] == result.blocks[:-2]):
157+
use_miniexpr = False
158+
except ValueError:
159+
use_miniexpr = False
161160

162161
use_miniexpr &= x1.dtype.kind in ("i", "f")
163162
use_miniexpr &= x2.dtype.kind in ("i", "f")

0 commit comments

Comments
 (0)