@@ -2182,12 +2182,13 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
21822182 cdef int blocknitems[2 ]
21832183 cdef int startA, startB, expected_blocknitems
21842184 cdef blosc2_context* dctx
2185- cdef int i, j, block_i, block_j, ncols, block_ncols, Bblock_ncols, Bncols
2185+ cdef int i, j, block_i, block_j, chunk_i, chunk_j, ncols, block_ncols, Bblock_ncols, Bncols, Ablock_ncols, Ancols
21862186 cdef int nchunkA = 0 , nchunkB = 0 , nblockA = 0 , nblockB = 0 , offsetA = 0 , offsetB = 0 , offset = 0
21872187 out_arr = udata.array
21882188 cdef int ndim = out_arr.ndim
21892189 cdef int nchunk_ = nchunk
21902190 cdef int coord, batch, batch_, batches = 1
2191+ cdef int out_chunk_nrows, out_chunk_ncols, out_block_nrows, out_block_ncols
21912192
21922193 # batches = sum(strides[i]*elcoords[i])
21932194 for i in range (ndim - 2 ):
@@ -2201,12 +2202,10 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
22012202 nchunkB += coord * udata.chunks_strides[2 ][i]
22022203
22032204 ncols = udata.chunks_strides[0 ][ndim - 2 ]
2205+ Ancols = udata.chunks_strides[1 ][ndim - 2 ]
22042206 Bncols = udata.chunks_strides[2 ][ndim - 2 ]
2205-
2206- i = nchunk_ // ncols # ncols * i + j
2207- j = nchunk_ % ncols
2208- chunk_startA = nchunkA + i * ncols
2209- chunk_startB = nchunkB + j
2207+ out_chunk_nrows = out_arr.chunkshape[ndim - 2 ]
2208+ out_chunk_ncols = out_arr.chunkshape[ndim - 1 ]
22102209
22112210 # nblock = sum(strides[i]*blockcoords[i])
22122211 cdef int nblock_ = nblock
@@ -2217,18 +2216,14 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
22172216 nblockB += coord * udata.blocks_strides[2 ][i]
22182217
22192218 block_ncols = udata.blocks_strides[0 ][ndim - 2 ]
2219+ Ablock_ncols = udata.blocks_strides[1 ][ndim - 2 ]
22202220 Bblock_ncols = udata.blocks_strides[2 ][ndim - 2 ]
2221-
2222- block_i = nblock_ // block_ncols
2223- block_j = nblock_ % block_ncols
2224- block_startA = nblockA + block_i * block_ncols
2225- block_startB = nblockB + block_j
2221+ out_block_nrows = out_arr.blockshape[ndim - 2 ]
2222+ out_block_ncols = out_arr.blockshape[ndim - 1 ]
22262223
22272224 dctx = blosc2_create_dctx(BLOSC2_DPARAMS_DEFAULTS)
22282225
22292226 first_run = True
2230- nchunkA = chunk_startA
2231- nchunkB = chunk_startB
22322227 while True : # chunk loop
22332228 for i in range (2 ):
22342229 chunk_idx = nchunkA if i == 0 else nchunkB
@@ -2244,16 +2239,28 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
22442239 if i == 0 :
22452240 q = ndarr.blockshape[ndim - 1 ]
22462241 p = ndarr.blockshape[ndim - 2 ]
2242+ # nchunk_ = chunks_in_row * chunk_row + chunk_col
2243+ # convert from chunk_idx to element idx chunk_i (row)
2244+ chunk_i = nchunk_ // ncols * out_chunk_nrows
2245+ chunk_startA = nchunkA + chunk_i // ndarr.chunkshape[ndim - 2 ] * Ancols
2246+ nchunkA = chunk_startA
2247+ # nblock_ = blocks_in_chunkrow * block_row + block_col
2248+ # convert from block_idx to element idx block_i (row)
2249+ block_i = nblock_ // block_ncols * out_block_nrows
2250+ block_startA = nblockA + block_i // p * Ablock_ncols
22472251 else : # i = 1
22482252 r = ndarr.blockshape[ndim - 1 ]
2253+ # convert from chunk_idx to element idx chunk_j (col)
2254+ chunk_j = nchunk_ % ncols * out_chunk_ncols
2255+ chunk_startB = nchunkB + chunk_j // ndarr.chunkshape[ndim - 1 ]
2256+ nchunkB = chunk_startB
2257+ # convert from block_idx to element idx block_j (col)
2258+ block_j = nblock_ % block_ncols * out_block_ncols
2259+ block_startB = nblockB + block_j // r
22492260 input_buffers[i] = malloc(block_nbytes[i])
22502261 if input_buffers[i] == NULL :
22512262 raise MemoryError (" miniexpr: cannot allocate input block buffer" )
22522263 blocknitems[i] = block_nbytes[i] // < int > ndarr.sc.typesize
2253- if i == 0 :
2254- expected_blocknitems = blocknitems[i]
2255- elif blocknitems[i] != expected_blocknitems:
2256- raise ValueError (" miniexpr: inconsistent block element counts across inputs" )
22572264
22582265 first_run = False
22592266 nblockA = block_startA
@@ -2297,11 +2304,11 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
22972304 batch += 1
22982305 nblockA += 1
22992306 nblockB += Bblock_ncols
2300- if (nblockA % block_ncols == 0 ):
2307+ if (nblockA % Ablock_ncols == 0 ):
23012308 break
23022309 nchunkA += 1
23032310 nchunkB += Bncols
2304- if (nchunkA % ncols == 0 ):
2311+ if (nchunkA % Ancols == 0 ):
23052312 break
23062313
23072314
@@ -3280,7 +3287,7 @@ cdef class NDArray:
32803287 cstrides = bstrides = estrides = 1
32813288 for idx in range (2 , self .array.ndim + 1 ):
32823289 i = inp.ndim - idx
3283- if inp.shape[i + 1 ] == 1 or i < 0 :
3290+ if ( inp.shape[i + 1 ] == 1 and i < inp.ndim - 3 ) or i < 0 :
32843291 udata.chunks_strides[j][i] = 0
32853292 udata.blocks_strides[j][i] = 0
32863293 udata.el_strides[j][i] = 0
0 commit comments