diff --git a/hw/rtl/tcu/VX_tcu_pkg.sv b/hw/rtl/tcu/VX_tcu_pkg.sv index e7c21a42b..cabcc429b 100644 --- a/hw/rtl/tcu/VX_tcu_pkg.sv +++ b/hw/rtl/tcu/VX_tcu_pkg.sv @@ -151,7 +151,7 @@ package VX_tcu_pkg; localparam TCU_WG_RC = TCU_RC; // WGMMA C accumulator starts at same base localparam TCU_WG_RA = 24; // WGMMA A register base (fixed f24..f27) localparam TCU_RA = 10; - localparam TCU_RB = (TCU_NRB == 4) ? 28 : 24; + localparam TCU_RB = (TCU_NRB == 4) ? 27 : 23; localparam TCU_UOPS = TCU_M_STEPS * TCU_N_STEPS * TCU_K_STEPS; diff --git a/kernel/include/vx_tensor.h b/kernel/include/vx_tensor.h index 1a1467646..81e8bc178 100644 --- a/kernel/include/vx_tensor.h +++ b/kernel/include/vx_tensor.h @@ -602,14 +602,14 @@ struct wmma_context { if constexpr (FragB::NR == 8) { // frag_b: caller-saved registers (f24-f31) - register float fb0 __asm__("f24") = frag_b.data[0]; - register float fb1 __asm__("f25") = frag_b.data[1]; - register float fb2 __asm__("f26") = frag_b.data[2]; - register float fb3 __asm__("f27") = frag_b.data[3]; - register float fb4 __asm__("f28") = frag_b.data[4]; - register float fb5 __asm__("f29") = frag_b.data[5]; - register float fb6 __asm__("f30") = frag_b.data[6]; - register float fb7 __asm__("f31") = frag_b.data[7]; + register float fb0 __asm__("f23") = frag_b.data[0]; + register float fb1 __asm__("f24") = frag_b.data[1]; + register float fb2 __asm__("f25") = frag_b.data[2]; + register float fb3 __asm__("f26") = frag_b.data[3]; + register float fb4 __asm__("f27") = frag_b.data[4]; + register float fb5 __asm__("f28") = frag_b.data[5]; + register float fb6 __asm__("f29") = frag_b.data[6]; + register float fb7 __asm__("f30") = frag_b.data[7]; if constexpr (is_mx) { register float fma0 __asm__("f8") = frag_a.mx_meta[0]; @@ -621,7 +621,7 @@ struct wmma_context { register float fma2 __asm__("f20") = frag_a.mx_meta[2]; register float fma3 __asm__("f21") = frag_a.mx_meta[3]; register float fmb2 __asm__("f22") = frag_b.mx_meta[2]; - register float fmb3 __asm__("f23") = frag_b.mx_meta[3]; + register float fmb3 __asm__("f31") = frag_b.mx_meta[3]; __asm__ volatile (".insn r %[insn], 0, 2, x%[fmd], x%[fms], x%[flags]" : "+f"(fd0), "+f"(fd1), "+f"(fd2), "+f"(fd3), "+f"(fd4), "+f"(fd5), "+f"(fd6), "+f"(fd7) @@ -651,10 +651,10 @@ struct wmma_context { static_assert(FragB::NR == 4, "Unsupported number of registers for FragB"); // frag_b: caller-saved registers (f28-f31) - register float fb0 __asm__("f28") = frag_b.data[0]; - register float fb1 __asm__("f29") = frag_b.data[1]; - register float fb2 __asm__("f30") = frag_b.data[2]; - register float fb3 __asm__("f31") = frag_b.data[3]; + register float fb0 __asm__("f27") = frag_b.data[0]; + register float fb1 __asm__("f28") = frag_b.data[1]; + register float fb2 __asm__("f29") = frag_b.data[2]; + register float fb3 __asm__("f30") = frag_b.data[3]; if constexpr (is_mx) { register float fma0 __asm__("f8") = frag_a.mx_meta[0]; @@ -666,7 +666,7 @@ struct wmma_context { register float fma2 __asm__("f20") = frag_a.mx_meta[2]; register float fma3 __asm__("f21") = frag_a.mx_meta[3]; register float fmb2 __asm__("f22") = frag_b.mx_meta[2]; - register float fmb3 __asm__("f23") = frag_b.mx_meta[3]; + register float fmb3 __asm__("f31") = frag_b.mx_meta[3]; __asm__ volatile (".insn r %[insn], 0, 2, x%[fmd], x%[fms], x%[flags]" : "+f"(fd0), "+f"(fd1), "+f"(fd2), "+f"(fd3), "+f"(fd4), "+f"(fd5), "+f"(fd6), "+f"(fd7) diff --git a/sim/simx/decode.cpp b/sim/simx/decode.cpp index 8ac436e16..d1ae4ba82 100644 --- a/sim/simx/decode.cpp +++ b/sim/simx/decode.cpp @@ -1127,7 +1127,7 @@ void Emulator::decode(uint32_t code, uint32_t wid, uint64_t uuid) { static_assert(cfg::num_meta_loads <= 2, "sparse metadata decode assumes at most two loads"); constexpr uint32_t rc_base = 0, ra_base = 10; - constexpr uint32_t rb_base = (cfg::NRB == 4) ? 28 : 24; + constexpr uint32_t rb_base = (cfg::NRB == 4) ? 27 : 23; constexpr uint32_t sparse_k_steps = cfg::k_steps / 2; uint32_t fmt_d = rd, fmt_s = rs1; bool is_sparse = (rs2 & 1) != 0;