
GEMM partition_N > 1 produces incorrect results (affects Llama final vocab projection) #83

@tonyjie

Summary

The GEMM operator produces incorrect results when partition_N > 1. Only the first partition (C_0) is computed correctly; partitions C_1 through C_{partition_N-1} produce wrong output. This directly affects the Llama 3.2 1B model, which uses partition_N=4 for its final vocab projection (128256 outputs).

There are two separate bugs:

  1. forward() returns wrong output shape when partition_N > 1 with static weights
  2. Runlist entries 1+ produce wrong results when multiple entries share the same XRT kernel handle

Affected Code

  • iron/operators/gemm/op.py: _partition_B(), forward(), _execute_aie_operation()
  • iron/applications/llama_3.2_1b/src/model_with_json.py, line 186: partition_N=4

Reproduction

cd /path/to/IRON
source ironenv/bin/activate && source /opt/xilinx/xrt/setup.sh
python repro_partition_n_bug.py
repro_partition_n_bug.py:
#!/usr/bin/env python3
"""Reproduce partition_N=4 GEMM bugs in IRON.

This script demonstrates two bugs in the GEMM operator when using partition_N > 1,
which is used by the Llama 3.2 1B final vocab projection (model_with_json.py line 186).

Usage:
    cd /path/to/IRON
    source ironenv/bin/activate && source /opt/xilinx/xrt/setup.sh
    python repro_partition_n_bug.py

Requirements: IRON with mlir_aie and XRT installed. No test framework changes needed.
"""

import torch
import numpy as np
from pathlib import Path
from ml_dtypes import bfloat16

from iron.operators.gemm.op import AIEGEMM
from iron.operators.gemm.reference import generate_golden_reference
from iron.common import AIEContext
from iron.common.utils import torch_to_numpy


def check_partition(output_2d, ref_2d, label):
    """Compare NPU output vs CPU reference (both as float32 2D arrays)."""
    out = output_2d.reshape(-1)
    ref = ref_2d.reshape(-1)
    n = min(len(out), len(ref))
    corr = float(np.corrcoef(out[:n], ref[:n])[0, 1])
    max_err = float(np.max(np.abs(out[:n] - ref[:n])))
    mean_err = float(np.mean(np.abs(out[:n] - ref[:n])))
    status = "PASS" if corr > 0.99 else "FAIL"
    print(f"  {label}: corr={corr:.5f}, max_err={max_err:.1f}, mean_err={mean_err:.2f}  [{status}]")
    return corr


# ---------- Configuration (matches Llama model_with_json.py lines 178-196) ----------
M, K, N = 2048, 2048, 128256
PARTITION_N = 4
N_PER_PART = N // PARTITION_N  # 32064
BUILD_DIR = Path("build_repro").resolve()

print("=" * 70)
print("IRON GEMM partition_N Bug Reproduction")
print("=" * 70)
print(f"Problem: M={M}, K={K}, N={N}, partition_N={PARTITION_N}")
print(f"Matches: Llama 3.2 1B final vocab projection (model_with_json.py)")
print(f"Build dir: {BUILD_DIR}")
print()

ref = generate_golden_reference(M=M, K=K, N=N, b_col_maj=True, partition_N=PARTITION_N)

# ======================== BUG 1: forward() returns wrong shape ========================
print("=" * 70)
print("BUG 1: forward() returns wrong output shape with partition_N > 1")
print("=" * 70)

ctx1 = AIEContext()
ctx1.build_dir = BUILD_DIR

op1 = AIEGEMM(
    M=M, K=K, N=N,
    tile_m=64, tile_k=64, tile_n=64,
    num_aie_columns=8,
    prio_accuracy=False, emulate_bf16_mmul_with_bfp16=True,
    b_col_maj=True, use_static_weight=True, partition_N=PARTITION_N,
    context=ctx1,
)

full_B = torch.cat(ref["input_b"], dim=0)  # (N, K) in b_col_maj format
op1.weight = full_B.T  # Model does: op.weight = out_head.T

ctx1.compile_all()
ctx1.prepare_runtime()

A_input = torch.randn(1, M, K, dtype=torch.bfloat16) * 4
result = op1.forward(A_input)

print(f"  Expected output shape: (1, {M}, {N})")
print(f"  Actual output shape:   {tuple(result.shape)}")
print()
if result.shape[-1] != N:
    print(f"  BUG CONFIRMED: forward() returns {result.shape[-1]} columns instead of {N}.")
    print(f"  Root cause: _partition_B() (op.py) overwrites self.static_weight_shape")
    print(f"  to single-partition size ({op1.N}, {K}), then forward() divides")
    print(f"  by partition_N again, yielding N_part = {op1.N // PARTITION_N}.")
    print()
    print(f"  The Llama model calls out_head_prefill(x) which hits this path.")
    print(f"  Logits shape is (batch, seq_len, {result.shape[-1]}) instead of")
    print(f"  (batch, seq_len, {N}), silently truncating the vocabulary.")
else:
    print("  Shape is correct. If running on unpatched code, expect shape")
    print(f"  (1, {M}, {op1.N}) instead -- see bug description.")
print()

# ======================== BUG 2: Only partition 0 produces correct results ========================
print("=" * 70)
print("BUG 2: Only C_0 is correct when partition_N > 1 in single context")
print("=" * 70)
print()
print("Reading individual partition buffers directly (bypassing forward())...")

ctx2 = AIEContext()
ctx2.build_dir = BUILD_DIR

op2 = AIEGEMM(
    M=M, K=K, N=N,
    tile_m=64, tile_k=64, tile_n=64,
    num_aie_columns=8,
    prio_accuracy=False, emulate_bf16_mmul_with_bfp16=True,
    b_col_maj=True, use_static_weight=True, partition_N=PARTITION_N,
    context=ctx2,
)
full_B = torch.cat(ref["input_b"], dim=0)
op2.weight = full_B.T

ctx2.compile_all()
ctx2.prepare_runtime()
op2.write_buffer("A", torch_to_numpy(ref["input"]))
op2.run_runlist()

print(f"  N_per_partition={N_PER_PART}, N_padded={op2.N}, padding={op2.N - N_PER_PART}")
print()

# Read each C_i with correct 2D shape (accounting for N padding)
for i in range(PARTITION_N):
    out_2d = np.array(op2.read_buffer(f"C_{i}", (op2.M, op2.N)), dtype=np.float32)
    out_valid = out_2d[:M, :N_PER_PART]
    ref_valid = torch_to_numpy(ref["output"][i]).reshape(M, N_PER_PART).astype(np.float32)
    check_partition(out_valid, ref_valid, f"C_{i} (vocab {i*N_PER_PART}-{(i+1)*N_PER_PART-1})")
print()

# ======================== CONTROL: Standalone partitions all work ========================
print("=" * 70)
print("CONTROL: Each partition works correctly as standalone GEMM (partition_N=1)")
print("=" * 70)
print()

for i in range(PARTITION_N):
    ctx_i = AIEContext()
    ctx_i.build_dir = BUILD_DIR

    op_i = AIEGEMM(
        M=M, K=K, N=N_PER_PART,
        tile_m=64, tile_k=64, tile_n=64,
        num_aie_columns=8,
        prio_accuracy=False, emulate_bf16_mmul_with_bfp16=True,
        b_col_maj=True, use_static_weight=True, partition_N=1,
        context=ctx_i,
    )
    op_i.weight = ref["input_b"][i].T  # Single partition weight

    ctx_i.compile_all()
    ctx_i.prepare_runtime()
    op_i.write_buffer("A", torch_to_numpy(ref["input"]))
    op_i.run_runlist()

    out_2d = np.array(op_i.read_buffer("C_0", (op_i.M, op_i.N)), dtype=np.float32)
    out_valid = out_2d[:M, :N_PER_PART]
    ref_valid = torch_to_numpy(ref["output"][i]).reshape(M, N_PER_PART).astype(np.float32)
    check_partition(out_valid, ref_valid, f"Standalone partition {i}")

print()
print("=" * 70)
print("CONCLUSION")
print("=" * 70)
print("""
Bug 1 (forward() shape): _partition_B() overwrites self.static_weight_shape
  to single-partition size. forward() reads N from this corrupted shape and
  divides by partition_N again, returning (M, N_padded_per_part) instead of
  (M, N_full). The Llama model's final vocab GEMM silently operates on a
  truncated vocabulary.

Bug 2 (partition correctness): When partition_N > 1, all 4 runlist entries
  share the same XRT kernel handle and instruction binary (insts.bin). The
  NPU's DMA descriptors bind to buffer addresses from the first invocation
  and are not re-resolved for subsequent entries. Only C_0 (first partition)
  produces correct results; C_1-C_3 read wrong buffer data.

  Each partition works perfectly when run as a standalone GEMM operator with
  its own AIEContext (separate XRT kernel handle + instruction binary).

Impact: The Llama 3.2 1B model's final vocab projection (128256 outputs,
  partition_N=4) produces correct logits only for vocab indices 0-32063.
  The model generates coherent text because common tokens have low indices
  and argmax is noise-tolerant, but output quality is degraded.

No existing test covers partition_N > 1. The Llama app test (test.py) only
  checks returncode == 0 with no output correctness validation.
""")

Expected Output (abridged)

BUG 1: forward() returns wrong output shape with partition_N > 1
  Expected output shape: (1, 2048, 128256)
  Actual output shape:   (1, 2048, 32256)     <-- should be 128256

BUG 2: Only C_0 is correct when partition_N > 1 in single context
  C_0 (vocab 0-32063):       corr=0.99994  [PASS]
  C_1 (vocab 32064-64127):   corr=0.74833  [FAIL]
  C_2 (vocab 64128-96191):   corr=0.74844  [FAIL]
  C_3 (vocab 96192-128255):  corr=0.74172  [FAIL]

CONTROL: Each partition works correctly as standalone GEMM (partition_N=1)
  Standalone partition 0:    corr=0.99994  [PASS]
  Standalone partition 1:    corr=0.99994  [PASS]
  Standalone partition 2:    corr=0.99994  [PASS]
  Standalone partition 3:    corr=0.99994  [PASS]

Bug 1: forward() returns wrong output shape

Root Cause

_partition_B() (op.py line 383) overwrites self.static_weight_shape to the single-partition size:

def _partition_B(self, B):
    ...
    self.static_weight_shape = B_parts[0].shape  # <-- overwrites to (32256, 2048)

Later, forward() reads N from this corrupted shape:

def forward(self, A, B=None):
    B_shape = B.shape if B is not None else self.static_weight_shape  # (32256, 2048)
    K2, N = self._get_B_dims(B_shape)  # N = 32256 (should be 128256)
    N_part = N // self.partition_N     # 32256 / 4 = 8064 (should be 32064)

The output shape becomes (M, 32256) instead of (M, 128256).
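The double division can be replayed on the host without an NPU. The sketch below assumes the per-partition N is padded up to a multiple of tile_n * num_aie_columns (64 * 8 = 512); that rule reproduces the reported numbers (32064 padded to 32256) but is an assumption about IRON's padding logic, and `pad_up` is a hypothetical helper, not an IRON function.

```python
def pad_up(n: int, multiple: int) -> int:
    """Round n up to the next multiple (hypothetical stand-in for IRON's padding)."""
    return -(-n // multiple) * multiple


N_FULL, K, PARTITION_N = 128256, 2048, 4
TILE_N, NUM_COLS = 64, 8  # assumed padding granularity: tile_n * num_aie_columns

n_per_part = N_FULL // PARTITION_N                # 32064 valid columns per partition
n_padded = pad_up(n_per_part, TILE_N * NUM_COLS)  # padded per-partition N, i.e. op.N

# Buggy path: _partition_B() stored the single-partition (N, K) shape, so
# forward() reads N from it and then divides by partition_N a second time.
buggy_shape = (n_padded, K)
buggy_N = buggy_shape[0]
buggy_N_part = buggy_N // PARTITION_N

# Fixed path: the full (N, K) shape is kept, so a single division is correct.
fixed_N = N_FULL
fixed_N_part = fixed_N // PARTITION_N

print(f"padded per-partition N: {n_padded}")
print(f"buggy: N={buggy_N}, N_part={buggy_N_part}")
print(f"fixed: N={fixed_N}, N_part={fixed_N_part}")
```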

Impact on Llama

The model calls self.out_head_prefill(x), which returns logits of shape (batch, seq_len, 32256) instead of (batch, seq_len, 128256). The model then does argmax(logits[:, -1, :]) over only 32256 values -- a scrambled mix of 4 partition results reassembled into the wrong column positions.

Fix

Three changes in op.py:

  1. Initialize static_weight_shape with full dimensions in the correct layout
  2. Remove the self.static_weight_shape = B_parts[0].shape overwrite from _partition_B()
  3. Fix the applicability check N <= self.N to N <= self.N * self.partition_N, and fix _execute_aie_operation() to use self.K, self.N directly for static weights
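A simplified model of the static-weight bookkeeping after the three changes is sketched below. `StaticWeightShapes` is a hypothetical class, not IRON code: it assumes the weight has already been brought into the (N, K) b_col_maj layout, that N divides evenly by partition_N, and that self.N is the padded per-partition N as in op.py. It only illustrates which shape each step should read.

```python
import numpy as np


class StaticWeightShapes:
    """Hypothetical, simplified model of AIEGEMM's static-weight shape handling."""

    def __init__(self, K, N_full, partition_N, n_granule):
        self.K = K
        self.partition_N = partition_N
        n_part = N_full // partition_N
        # Padded per-partition N (plays the role of self.N in op.py).
        self.N = -(-n_part // n_granule) * n_granule
        # Change 1: record the FULL weight shape, in (N, K) b_col_maj layout.
        self.static_weight_shape = (N_full, K)

    def partition_B(self, B):
        """Split an (N, K) weight into partition_N row blocks, each padded to self.N rows.

        Change 2: static_weight_shape is deliberately NOT overwritten here.
        np.split requires N to divide evenly by partition_N.
        """
        parts = np.split(B, self.partition_N, axis=0)
        return [np.pad(p, ((0, self.N - p.shape[0]), (0, 0))) for p in parts]

    def forward_dims(self):
        """Return (N, N_part) as forward() should compute them for static weights."""
        N = self.static_weight_shape[0]
        # Change 3: compare against the total padded capacity of all partitions.
        assert N <= self.N * self.partition_N, "weight too large for this operator"
        return N, N // self.partition_N
```

With small sizes (K=8, N_full=40, partition_N=4, n_granule=4), each partition holds 10 valid rows padded to 12, and forward_dims() still reports the full N=40 with N_part=10.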

Bug 2: Only first partition produces correct results

Root Cause

When partition_N=4, set_up_runtime() creates 4 runlist entries sharing the same XRT kernel handle and instruction binary (insts.bin):

for i in range(partition_N):
    self.add_to_runlist("gemm", "A", f"B_{i}", f"C_{i}")

All 4 entries use the same xrt_kernel object and insts_bo. The working hypothesis is that the NPU's instruction sequence contains DMA descriptors that bind to buffer addresses on first execution: when the kernel is re-invoked with different B/C buffer objects, the NPU does not re-resolve the DMA addresses and instead reuses the cached descriptors from the first invocation.

Evidence

  • Partition 0: corr=0.99994 (correct)
  • Partitions 1-3: corr=0.74 (wrong -- not random, not C_0's result, but corrupted data)
  • Each partition as a standalone GEMM with its own AIEContext: all 4 produce corr=0.99994

What was tried (none fixed Bug 2)

| Approach | Result |
|---|---|
| AIEContext(use_runlist=False) (sequential kernel calls) | Same -- partitions 1-3 wrong |
| Fresh get_kernel_handle() per partition | Same -- partitions 1-3 wrong |
| Per-partition runlist swap (self.runlist = [runlist[i]]) | Same -- partitions 1-3 wrong |
| Separate A buffer per partition | Worse -- all 4 partitions wrong (buffer pool aliasing) |
| Separate AIEContext per partition (partition_N=1 each) | Works -- all 4 correct |
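Since only the last row of the table worked, one possible stopgap until the root cause is fixed is to run each vocab slice as its own partition_N=1 operator and stitch the outputs on the host. The sketch below shows only the host-side stitching. `run_partitioned` and the `run_single` hook are hypothetical names: on the NPU, `run_single` would wrap the standalone AIEGEMM flow from the CONTROL section of the repro script (one AIEContext per partition, at the cost of compiling and loading partition_N separate operators). `cpu_gemm_padded` is a CPU stand-in used only to exercise the logic.

```python
from typing import Callable, Sequence

import numpy as np


def run_partitioned(
    A: np.ndarray,
    B_parts: Sequence[np.ndarray],
    run_single: Callable[[np.ndarray, np.ndarray], np.ndarray],
) -> np.ndarray:
    """Run each (N_part, K) weight slice as its own GEMM and concatenate along N.

    run_single(A, B_i) is a hypothetical hook that returns C_i, possibly padded
    along M and/or N (as read back via read_buffer with (op.M, op.N)).
    """
    M = A.shape[0]
    outputs = []
    for B_i in B_parts:
        n_valid = B_i.shape[0]              # b_col_maj: rows of B_i are output columns
        C_i = run_single(A, B_i)
        outputs.append(C_i[:M, :n_valid])   # drop any padding before stitching
    return np.concatenate(outputs, axis=1)  # (M, total valid N)


def cpu_gemm_padded(A: np.ndarray, B_i: np.ndarray, n_pad: int = 16) -> np.ndarray:
    """CPU stand-in for one standalone NPU GEMM that pads its output along N."""
    C = A @ B_i.T
    return np.pad(C, ((0, 0), (0, n_pad)))
```

Because each partition is trimmed to its own valid width before concatenation, padding in any C_i cannot shift later partitions into the wrong vocab positions.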

Why it's not caught by existing tests

  • No GEMM test uses partition_N > 1
  • The Llama app test (iron/applications/llama_3.2_1b/test.py) only checks returncode == 0 with no output correctness validation
  • The model generates coherent-looking text despite the bug because common English tokens have low vocab indices (within partition 0's range) and argmax is noise-tolerant
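A regression test for partition_N > 1 would need to validate every C_i rather than only the process return code. A minimal sketch of the comparison step is below; it reuses the correlation criterion from check_partition() in the repro script (corr > 0.99 passes). The helper names are hypothetical, and in a real test the `outputs` list would come from read_buffer(f"C_{i}", ...) trimmed to the valid region, as in the repro.

```python
import numpy as np


def partition_correlations(outputs, refs):
    """Pearson correlation of each partition output against its CPU reference."""
    return [
        float(np.corrcoef(np.ravel(out), np.ravel(ref))[0, 1])
        for out, ref in zip(outputs, refs)
    ]


def failing_partitions(outputs, refs, threshold=0.99):
    """Indices of partitions that do not exceed the correlation threshold."""
    corrs = partition_correlations(outputs, refs)
    return [i for i, c in enumerate(corrs) if c <= threshold]
```

In a pytest test, `assert failing_partitions(outputs, refs) == []` would fail on the current code with the indices of C_1 through C_3, instead of passing silently as the returncode-only check does.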

Environment

  • IRON: devel branch
  • mlir-aie: v1.2.1
  • XRT: /opt/xilinx/xrt
  • Device: Ryzen AI NPU (npu2)
