Skip to content

Commit 2a57a30

Browse files
clee1994Google-ML-Automation
authored andcommitted
Fixing kernel to accomodate compiler change.
PiperOrigin-RevId: 895535377
1 parent dbc4d58 commit 2a57a30

1 file changed

Lines changed: 5 additions & 5 deletions

File tree

src/maxtext/kernels/gather_reduce_sc.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -460,14 +460,14 @@ def perform_add(
460460
weights_evens = tpu.unpack_subelements(
461461
_F32[16],
462462
raw_weights,
463-
0,
463+
1,
464464
ir.Attribute.parse("#tpu.pack_format<interleaved>"),
465465
)
466466
# part=0 corresponds to sc_tpu.unpackf result 0 (Odd indices)
467467
weights_odds = tpu.unpack_subelements(
468468
_F32[16],
469469
raw_weights,
470-
1,
470+
0,
471471
ir.Attribute.parse("#tpu.pack_format<interleaved>"),
472472
)
473473

@@ -541,13 +541,13 @@ def get_row_val(row_idx):
541541
vec_f32_evens = tpu.unpack_subelements(
542542
_F32[16],
543543
vec_bf16_2x16,
544-
0,
544+
1,
545545
ir.Attribute.parse("#tpu.pack_format<interleaved>"),
546546
)
547547
vec_f32_odds = tpu.unpack_subelements(
548548
_F32[16],
549549
vec_bf16_2x16,
550-
1,
550+
0,
551551
ir.Attribute.parse("#tpu.pack_format<interleaved>"),
552552
)
553553
parity_of_row = vector.extract(
@@ -620,7 +620,7 @@ def get_row_val(row_idx):
620620

621621
packed = tpu.pack_subelements(
622622
_BF16[2, vreg_size],
623-
[row8, row0],
623+
[row0, row8],
624624
[0, 1],
625625
ir.Attribute.parse("#tpu.pack_format<interleaved>"),
626626
)

0 commit comments

Comments
 (0)