File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -460,14 +460,14 @@ def perform_add(
460460 weights_evens = tpu .unpack_subelements (
461461 _F32 [16 ],
462462 raw_weights ,
463- 0 ,
463+ 1 ,
464464 ir .Attribute .parse ("#tpu.pack_format<interleaved>" ),
465465 )
466466 # part=0 corresponds to sc_tpu.unpackf result 0 (Odd indices)
467467 weights_odds = tpu .unpack_subelements (
468468 _F32 [16 ],
469469 raw_weights ,
470- 1 ,
470+ 0 ,
471471 ir .Attribute .parse ("#tpu.pack_format<interleaved>" ),
472472 )
473473
@@ -541,13 +541,13 @@ def get_row_val(row_idx):
541541 vec_f32_evens = tpu .unpack_subelements (
542542 _F32 [16 ],
543543 vec_bf16_2x16 ,
544- 0 ,
544+ 1 ,
545545 ir .Attribute .parse ("#tpu.pack_format<interleaved>" ),
546546 )
547547 vec_f32_odds = tpu .unpack_subelements (
548548 _F32 [16 ],
549549 vec_bf16_2x16 ,
550- 1 ,
550+ 0 ,
551551 ir .Attribute .parse ("#tpu.pack_format<interleaved>" ),
552552 )
553553 parity_of_row = vector .extract (
@@ -620,7 +620,7 @@ def get_row_val(row_idx):
620620
621621 packed = tpu .pack_subelements (
622622 _BF16 [2 , vreg_size ],
623- [row8 , row0 ],
623+ [row0 , row8 ],
624624 [0 , 1 ],
625625 ir .Attribute .parse ("#tpu.pack_format<interleaved>" ),
626626 )
You can’t perform that action at this time.
0 commit comments