
Commit b6b5f09

unamedkr and claude committed
pillar1.5(R1): restore QK-norm for pure Qwen3 (R40 was over-broad)
R40 disabled QK-norm for all "qwen" arch GGUFs. That was correct for
Qwen3.5/3.6 HYBRID (DeltaNet + self-attn, delta_n_heads > 0) — those
degrade when QK-norm is applied to their 10 self-attn layers. But pure
Qwen3 (0.6B..32B, self-attn only) REQUIRES q_norm/k_norm. Without them,
long-prompt attention (pos>=1) has unnormalized Q·K scores, the residual
stream explodes at layer 2 (norm 5396 vs HF 10), and the output is UTF-8
garbage.

Found via HF reference diff methodology (tools/pillar1/diff_layers.py,
also added here):
- Layer-by-layer cosine/L2 at pos=1 with 144-token input
- Layer 0 cosine 0.98 at pos=0 but 0.74 at pos=1 → attention at pos>=1 broken
- h2 norm: ours 5396 vs HF 10 → catastrophic residual stream
- With TQ_FORCE_QK_NORM=1: h2 norm normalizes to ~11 (close to HF)

Fix: restrict the QK-norm disable to `delta_n_heads > 0` only. Drop the
over-broad GGUF-arch name match. Pure Qwen3 now applies q_norm/k_norm
per-head as HF does.

Real-world output (per-token prefill, 50-word synthetic prompt):
- BEFORE: "lenameuously... catchØ�Williamson" (UTF-8 garbage)
- AFTER: " word11223: word3: Word length?" (pattern-matching English)

Regression: 15/15 test_models + 4/4 test_tokenizer PASS.

Known remaining: the batched prefill path (tq_forward_batch) is still
broken independently. That path DOES apply QK-norm unconditionally
(line 3615) but still produces garbage — separate bug for follow-on R2+.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 1adc6d2 commit b6b5f09

2 files changed

Lines changed: 70 additions & 6 deletions


src/engine/tq_transformer.c

Lines changed: 11 additions & 6 deletions
@@ -1201,14 +1201,19 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
      * TQ_NO_QK_NORM=1 forces off (diagnostic).
      * TQ_FORCE_QK_NORM=1 forces on (Gemma fallback for Qwen if ever
      * the convention is fixed). */
+    /* Pillar 1.5 R1 fix: the R40 arch-conditional disable was too broad.
+     * Pure Qwen3 (0.6B..32B, self-attn only) REQUIRES q_norm/k_norm —
+     * without them, long-prompt (pos>=1) attention corrupts via
+     * un-normalized Q·K scores, producing norm explosion at layer 2+
+     * and UTF-8 garbage output.
+     *
+     * Qwen3.5/3.6 HYBRID (DeltaNet + self-attn, delta_n_heads > 0) was
+     * empirically shown at R40 to degrade with QK-norm applied to the
+     * 10 self-attn layers. Keep that disabled. */
     int _qknorm_disabled = (getenv("TQ_NO_QK_NORM") != NULL);
-    int _is_qwen = (c->delta_n_heads > 0); /* Qwen hybrid: always */
-    if (model->gguf_ctx) {
-        tq_gguf_ctx_t* gctx = (tq_gguf_ctx_t*)model->gguf_ctx;
-        if (strstr(gctx->arch, "qwen") != NULL) _is_qwen = 1;
-    }
+    int _is_qwen_hybrid = (c->delta_n_heads > 0); /* Qwen3.5/3.6 hybrid */
     int _apply_qknorm = !_qknorm_disabled;
-    if (_is_qwen && !getenv("TQ_FORCE_QK_NORM")) _apply_qknorm = 0;
+    if (_is_qwen_hybrid && !getenv("TQ_FORCE_QK_NORM")) _apply_qknorm = 0;

     if (layer->q_norm && _apply_qknorm) {
         for (int h = 0; h < n_heads; h++) {

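For orientation, the following is a minimal numpy sketch of the per-head QK-norm that pure Qwen3 expects and that this fix restores (illustration only, not the engine code): it assumes the HF Qwen3 convention of an RMSNorm over head_dim with a learned per-channel weight, applied to every query and key head before RoPE; the eps value of 1e-6 is an assumption.

import numpy as np

def rms_norm(x, weight, eps=1e-6):
    # RMS-normalize over the last axis (head_dim), then scale by the learned weight.
    var = np.mean(x * x, axis=-1, keepdims=True)
    return x / np.sqrt(var + eps) * weight

def apply_qk_norm(q, k, q_norm_w, k_norm_w):
    # q: (n_heads, head_dim), k: (n_kv_heads, head_dim) for one position.
    # Each head is normalized independently; q_norm_w / k_norm_w have length head_dim.
    return rms_norm(q, q_norm_w), rms_norm(k, k_norm_w)

Omitting this step leaves the unnormalized Q·K scores the commit message describes, which is what drove the layer-2 residual norm to 5396 vs HF's 10.
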
tools/pillar1/diff_layers.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+#!/usr/bin/env python3
+"""Layer-by-layer diff between HF reference and our engine's dumps.
+
+Input:
+  tools/pillar1/hf_dump_long.npz  (emb, h0..h27, logits per-position)
+  /tmp/qdump/*.bin                (our engine's pos=143 dumps, raw float32)
+
+Output: per-layer table of cosine, max_abs_diff, L2_relative."""
+import numpy as np, os, sys, glob
+
+HF_NPZ = sys.argv[1] if len(sys.argv) > 1 else "tools/pillar1/hf_dump_long.npz"
+US_DIR = sys.argv[2] if len(sys.argv) > 2 else "/tmp/qdump"
+POS = int(sys.argv[3]) if len(sys.argv) > 3 else 143
+
+hf = np.load(HF_NPZ)
+print(f"HF npz keys: {list(hf.keys())[:5]}... shape_h0={hf['h0'].shape}")
+print(f"Reading our dumps from {US_DIR} at position {POS}")
+print()
+print(f"{'slot':<12} {'dim':>6} {'our_norm':>10} {'hf_norm':>10} {'max_abs':>10} {'L2_rel':>10} {'cosine':>8}")
+print("-" * 70)
+
+def read_bin(path):
+    return np.fromfile(path, dtype=np.float32)
+
+slots = ["emb"] + [f"h{i}" for i in range(28)] + ["post_norm"]
+for slot in slots:
+    bin_path = os.path.join(US_DIR, f"{slot}.bin")
+    if not os.path.exists(bin_path):
+        continue
+    ours = read_bin(bin_path)
+    if slot not in hf.files:
+        continue
+    hf_arr = hf[slot]
+    if hf_arr.ndim == 2:
+        ref = hf_arr[POS]  # last-position vector for this layer
+    else:
+        ref = hf_arr
+    if ours.shape != ref.shape:
+        print(f"{slot}: shape mismatch us={ours.shape} hf={ref.shape}")
+        continue
+    diff = ours - ref
+    max_abs = np.max(np.abs(diff))
+    l2 = np.linalg.norm(diff)
+    hf_norm = np.linalg.norm(ref)
+    us_norm = np.linalg.norm(ours)
+    l2_rel = l2 / max(hf_norm, 1e-9)
+    cos = np.dot(ours, ref) / max(us_norm * hf_norm, 1e-9)
+    print(f"{slot:<12} {len(ours):>6} {us_norm:>10.3f} {hf_norm:>10.3f} {max_abs:>10.4f} {l2_rel:>10.4%} {cos:>8.4f}")
+
+# Compare top-5 logits (our dump logits.bin is FP32 full-vocab)
+print()
+logits_path = os.path.join(US_DIR, "logits.bin")
+if os.path.exists(logits_path):
+    ours_l = read_bin(logits_path)
+    hf_l = hf["logits"][POS] if hf["logits"].ndim == 2 else hf["logits"]
+    top5_us = np.argsort(-ours_l)[:5]
+    top5_hf = np.argsort(-hf_l)[:5]
+    print(f"HF top-5 logits: {[(int(t), f'{hf_l[t]:.2f}') for t in top5_hf]}")
+    print(f"Us top-5 logits: {[(int(t), f'{ours_l[t]:.2f}') for t in top5_us]}")
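The HF-side dump that produces tools/pillar1/hf_dump_long.npz is not part of this commit. A hypothetical sketch of how such a file could be generated with transformers follows; the script itself, the model id "Qwen/Qwen3-0.6B", and the prompt handling are assumptions, not the actual tooling.

#!/usr/bin/env python3
# Hypothetical companion script (not in this commit): dump HF hidden states
# into the npz layout diff_layers.py expects (emb, h0..h27, logits, per position).
import numpy as np, sys, torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = sys.argv[1] if len(sys.argv) > 1 else "Qwen/Qwen3-0.6B"   # assumed model
prompt = open(sys.argv[2]).read() if len(sys.argv) > 2 else "Hello"  # assumed prompt source

tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
model.eval()

ids = tok(prompt, return_tensors="pt")
with torch.no_grad():
    out = model(**ids, output_hidden_states=True)

# hidden_states[0] is the embedding output; hidden_states[i+1] is layer i's output.
arrays = {"emb": out.hidden_states[0][0].numpy()}
for i, h in enumerate(out.hidden_states[1:]):
    arrays[f"h{i}"] = h[0].numpy()        # shape (seq_len, hidden_dim)
arrays["logits"] = out.logits[0].numpy()  # shape (seq_len, vocab)
np.savez("tools/pillar1/hf_dump_long.npz", **arrays)

With the engine's per-position dumps written to /tmp/qdump, the table above is then produced by something like `python3 tools/pillar1/diff_layers.py tools/pillar1/hf_dump_long.npz /tmp/qdump 143`.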
