Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hw/rtl/tcu/VX_tcu_pkg.sv
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ package VX_tcu_pkg;
localparam TCU_WG_RC = TCU_RC; // WGMMA C accumulator starts at same base
localparam TCU_WG_RA = 24; // WGMMA A register base (fixed f24..f27)
localparam TCU_RA = 10;
localparam TCU_RB = (TCU_NRB == 4) ? 28 : 24;
localparam TCU_RB = (TCU_NRB == 4) ? 27 : 23;

localparam TCU_UOPS = TCU_M_STEPS * TCU_N_STEPS * TCU_K_STEPS;

Expand Down
28 changes: 14 additions & 14 deletions kernel/include/vx_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -602,14 +602,14 @@ struct wmma_context {
if constexpr (FragB::NR == 8) {

// frag_b: registers f23-f30 (note: f23-f27 are callee-saved fs7-fs11 in the RISC-V ABI; only f28-f30 are caller-saved)
register float fb0 __asm__("f24") = frag_b.data[0];
register float fb1 __asm__("f25") = frag_b.data[1];
register float fb2 __asm__("f26") = frag_b.data[2];
register float fb3 __asm__("f27") = frag_b.data[3];
register float fb4 __asm__("f28") = frag_b.data[4];
register float fb5 __asm__("f29") = frag_b.data[5];
register float fb6 __asm__("f30") = frag_b.data[6];
register float fb7 __asm__("f31") = frag_b.data[7];
register float fb0 __asm__("f23") = frag_b.data[0];
register float fb1 __asm__("f24") = frag_b.data[1];
register float fb2 __asm__("f25") = frag_b.data[2];
register float fb3 __asm__("f26") = frag_b.data[3];
register float fb4 __asm__("f27") = frag_b.data[4];
register float fb5 __asm__("f28") = frag_b.data[5];
register float fb6 __asm__("f29") = frag_b.data[6];
register float fb7 __asm__("f30") = frag_b.data[7];

if constexpr (is_mx) {
register float fma0 __asm__("f8") = frag_a.mx_meta[0];
Expand All @@ -621,7 +621,7 @@ struct wmma_context {
register float fma2 __asm__("f20") = frag_a.mx_meta[2];
register float fma3 __asm__("f21") = frag_a.mx_meta[3];
register float fmb2 __asm__("f22") = frag_b.mx_meta[2];
register float fmb3 __asm__("f23") = frag_b.mx_meta[3];
register float fmb3 __asm__("f31") = frag_b.mx_meta[3];

__asm__ volatile (".insn r %[insn], 0, 2, x%[fmd], x%[fms], x%[flags]"
: "+f"(fd0), "+f"(fd1), "+f"(fd2), "+f"(fd3), "+f"(fd4), "+f"(fd5), "+f"(fd6), "+f"(fd7)
Expand Down Expand Up @@ -651,10 +651,10 @@ struct wmma_context {
static_assert(FragB::NR == 4, "Unsupported number of registers for FragB");

// frag_b: registers f27-f30 (note: f27 is callee-saved fs11 in the RISC-V ABI; f28-f30 are caller-saved)
register float fb0 __asm__("f28") = frag_b.data[0];
register float fb1 __asm__("f29") = frag_b.data[1];
register float fb2 __asm__("f30") = frag_b.data[2];
register float fb3 __asm__("f31") = frag_b.data[3];
register float fb0 __asm__("f27") = frag_b.data[0];
register float fb1 __asm__("f28") = frag_b.data[1];
register float fb2 __asm__("f29") = frag_b.data[2];
register float fb3 __asm__("f30") = frag_b.data[3];

if constexpr (is_mx) {
register float fma0 __asm__("f8") = frag_a.mx_meta[0];
Expand All @@ -666,7 +666,7 @@ struct wmma_context {
register float fma2 __asm__("f20") = frag_a.mx_meta[2];
register float fma3 __asm__("f21") = frag_a.mx_meta[3];
register float fmb2 __asm__("f22") = frag_b.mx_meta[2];
register float fmb3 __asm__("f23") = frag_b.mx_meta[3];
register float fmb3 __asm__("f31") = frag_b.mx_meta[3];

__asm__ volatile (".insn r %[insn], 0, 2, x%[fmd], x%[fms], x%[flags]"
: "+f"(fd0), "+f"(fd1), "+f"(fd2), "+f"(fd3), "+f"(fd4), "+f"(fd5), "+f"(fd6), "+f"(fd7)
Expand Down
2 changes: 1 addition & 1 deletion sim/simx/decode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1127,7 +1127,7 @@ void Emulator::decode(uint32_t code, uint32_t wid, uint64_t uuid) {
static_assert(cfg::num_meta_loads <= 2, "sparse metadata decode assumes at most two loads");

constexpr uint32_t rc_base = 0, ra_base = 10;
constexpr uint32_t rb_base = (cfg::NRB == 4) ? 28 : 24;
constexpr uint32_t rb_base = (cfg::NRB == 4) ? 27 : 23;
constexpr uint32_t sparse_k_steps = cfg::k_steps / 2;
uint32_t fmt_d = rd, fmt_s = rs1;
bool is_sparse = (rs2 & 1) != 0;
Expand Down