Skip to content

Commit e411e3d

Browse files
unamedkr authored
and claude committed
pillar1.5(R3) ★: NEOX-ordering RoPE for pure Qwen3 + Qwen3.6 batched
Root cause of long-prompt UTF-8 garbage on Qwen3 family: RoPE ordering mismatch with HF reference. - llama.cpp: LLM_ARCH_QWEN3 / QWEN3MOE → LLAMA_ROPE_TYPE_NEOX (half-split pairs: q[i] pairs with q[i+head_dim/2]) - Our engine's tq_rope and batched-prefill RoPE both used LLaMA-style interleaved pairs (q[2i], q[2i+1]) - R34 fixed this ONLY for the partial-rotary path (Qwen3.5/3.6 hybrid); pure Qwen3 (full rotary) and tq_forward_batch were never converted. Changes: 1. New tq_rope_neox() in tq_ops.c — half-split variant with same TLS sin/cos caching as tq_rope. 2. tq_engine.h exports the new entry point. 3. Per-token self_attn_forward full-rotary else branch: detect Qwen3 via gguf arch string or delta_n_heads > 0, dispatch to tq_rope_neox. TQ_ROPE_PAIRS=1 opt-out for legacy Qwen2/LLaMA. 4. tq_forward_batch full-rotary RoPE path: same detection and half-split rotation for learned-freq (rope_freqs) and fallback (tq_rope_neox) branches. Real-world (Qwen3-0.6B Q4, 50-word synthetic input): BEFORE: "lenameuously...catchØ�Williamson" UTF-8 garbage AFTER (batched): " Let me try to understand this" AFTER (per-token): " ... and so on... etc. So, the problem is..." Pillar 1 R3 BPE fix was necessary (tokens correct) but not sufficient (RoPE still wrong). Together, pure Qwen3 + Qwen3.5-4B long-prompt coherence restored. Qwen3.6-35B short prompts remain coherent; long-prompt partial coherence needs further investigation (likely DeltaNet-specific accumulation issue separate from RoPE). Regression: 15/15 test_models + 4/4 test_tokenizer PASS. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent b6b5f09 commit e411e3d

3 files changed

Lines changed: 117 additions & 12 deletions

File tree

include/turboquant/tq_engine.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,8 @@ void tq_matmul_rht_q4q2(float* out, const float* x,
652652
void tq_rmsnorm(float* out, const float* x, const float* weight, int n, float eps);
653653
void tq_rope(float* q, float* k, int pos, int head_dim,
654654
int n_heads, int n_kv_heads, float freq_base);
655+
void tq_rope_neox(float* q, float* k, int pos, int head_dim,
656+
int n_heads, int n_kv_heads, float freq_base);
655657
void tq_silu(float* x, int n);
656658
void tq_gelu_tanh(float* x, int n);
657659
void tq_softmax(float* x, int n);

src/engine/tq_ops.c

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1698,8 +1698,11 @@ void tq_rmsnorm(float* out, const float* x, const float* weight, int n, float ep
16981698
* Rotary Positional Embedding (RoPE)
16991699
*
17001700
* Applies rotation to pairs (q[2i], q[2i+1]) based on position.
1701-
* Compatible with LLaMA / Qwen RoPE convention.
1702-
* ============================================================ */
1701+
* Compatible with LLaMA / Qwen2 RoPE convention.
1702+
*
1703+
* NOTE: Qwen3 family (pure Qwen3 AND hybrid Qwen3.5/3.6) uses
1704+
* LLAMA_ROPE_TYPE_NEOX / IMROPE — half-split pairs (q[i], q[i+half]).
1705+
* Use tq_rope_neox for those.
 * ============================================================ */
17031706
void tq_rope(float* q, float* k, int pos, int head_dim,
17041707
int n_heads, int n_kv_heads, float freq_base) {
17051708
/* TLS sin/cos cache keyed on (pos, freq_base, head_dim). Identical
@@ -1751,6 +1754,57 @@ void tq_rope(float* q, float* k, int pos, int head_dim,
17511754
}
17521755
}
17531756

1757+
/* ============================================================
1758+
* NEOX-style RoPE (Pillar 1.5 R3): half-split pairs (q[i], q[i+half]).
1759+
* llama.cpp maps Qwen3 / Qwen3MOE / Qwen35 / Qwen35MOE → NEOX/IMROPE.
1760+
* LLaMA 2 / Qwen2 use tq_rope (interleaved pairs).
1761+
* Missing this on batched prefill + full-rotary per-token path was
1762+
* root cause of Qwen3 long-prompt UTF-8 garbage (R7/R8 bug).
1763+
* ============================================================ */
1764+
void tq_rope_neox(float* q, float* k, int pos, int head_dim,
1765+
int n_heads, int n_kv_heads, float freq_base) {
1766+
int half = head_dim / 2;
1767+
static __thread float tls_cos[256];
1768+
static __thread float tls_sin[256];
1769+
static __thread int tls_pos = -1;
1770+
static __thread float tls_base = 0.0f;
1771+
static __thread int tls_dim = 0;
1772+
if (half <= 256 &&
1773+
(tls_pos != pos || tls_base != freq_base || tls_dim != head_dim)) {
1774+
for (int i = 0; i < half; i++) {
1775+
float freq = 1.0f / powf(freq_base, 2.0f * i / head_dim);
1776+
float theta = pos * freq;
1777+
tls_cos[i] = cosf(theta);
1778+
tls_sin[i] = sinf(theta);
1779+
}
1780+
tls_pos = pos;
1781+
tls_base = freq_base;
1782+
tls_dim = head_dim;
1783+
}
1784+
for (int h = 0; h < n_heads; h++) {
1785+
float* qh = q + h * head_dim;
1786+
for (int i = 0; i < half; i++) {
1787+
float cos_t = tls_cos[i];
1788+
float sin_t = tls_sin[i];
1789+
float q0 = qh[i];
1790+
float q1 = qh[i + half];
1791+
qh[i] = q0 * cos_t - q1 * sin_t;
1792+
qh[i + half] = q0 * sin_t + q1 * cos_t;
1793+
}
1794+
}
1795+
for (int h = 0; h < n_kv_heads; h++) {
1796+
float* kh = k + h * head_dim;
1797+
for (int i = 0; i < half; i++) {
1798+
float cos_t = tls_cos[i];
1799+
float sin_t = tls_sin[i];
1800+
float k0 = kh[i];
1801+
float k1 = kh[i + half];
1802+
kh[i] = k0 * cos_t - k1 * sin_t;
1803+
kh[i + half] = k0 * sin_t + k1 * cos_t;
1804+
}
1805+
}
1806+
}
1807+
17541808
/* ============================================================
17551809
* SiLU activation: x[i] = x[i] * sigmoid(x[i])
17561810
* Also known as swish activation.

src/engine/tq_transformer.c

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1551,7 +1551,23 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
15511551
}
15521552
}
15531553
} else {
1554-
tq_rope(s->q, s->k, pos, head_dim, n_heads, n_kv_heads, rope_base);
1554+
/* Pillar 1.5 R3: Qwen3 / Qwen3MOE / Qwen3.5 / Qwen3.6 all use
1555+
* NEOX-ordering RoPE (llama.cpp: LLM_ARCH_QWEN3* → ROPE_NEOX).
1556+
* Detect via GGUF arch string or delta_n_heads (hybrid family).
1557+
* Opt-out: TQ_ROPE_PAIRS=1 reverts to LLaMA pairs for legacy. */
1558+
int use_neox = 0;
1559+
if (model->gguf_ctx) {
1560+
tq_gguf_ctx_t* gctx = (tq_gguf_ctx_t*)model->gguf_ctx;
1561+
if (strstr(gctx->arch, "qwen3") != NULL
1562+
|| strstr(gctx->arch, "qwen35") != NULL) use_neox = 1;
1563+
}
1564+
if (c->delta_n_heads > 0) use_neox = 1; /* Qwen3.5/3.6 hybrid */
1565+
if (getenv("TQ_ROPE_PAIRS")) use_neox = 0;
1566+
if (use_neox) {
1567+
tq_rope_neox(s->q, s->k, pos, head_dim, n_heads, n_kv_heads, rope_base);
1568+
} else {
1569+
tq_rope(s->q, s->k, pos, head_dim, n_heads, n_kv_heads, rope_base);
1570+
}
15551571
}
15561572
}
15571573

@@ -3613,6 +3629,10 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
36133629
for (int i = 0; i < kv_dim; i++) VB[(size_t)n * kv_dim + i] += layer->v_bias[i];
36143630
}
36153631
/* 2b. QK-norm (Qwen3 — NULL for Llama). */
3632+
if (l == 0 && pos_start == 0 && getenv("TQ_DEBUG_PREFILL")) {
3633+
fprintf(stderr, "[batch-qknorm] L0 q_norm=%p k_norm=%p\n",
3634+
(void*)layer->q_norm, (void*)layer->k_norm);
3635+
}
36163636
if (layer->q_norm) {
36173637
for (int n = 0; n < N; n++) {
36183638
for (int h = 0; h < c->n_heads; h++) {
@@ -3633,7 +3653,19 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
36333653
/* 3. RoPE + KV cache write (per-token).
36343654
* Mirror tq_forward's RoPE selection: if model->rope_freqs is set
36353655
* (Llama 3.x learned RoPE scaling, 64 freq factors), apply per-pair
3636-
* factor; otherwise plain interleaved RoPE. */
3656+
* factor; otherwise plain RoPE (NEOX for Qwen3, LLaMA pairs for others).
3657+
*
3658+
* Pillar 1.5 R3: Qwen3 family uses NEOX-ordering per llama.cpp
3659+
* (LLM_ARCH_QWEN3 -> ROPE_NEOX). Without this on batched prefill,
3660+
* long-prompt attention was corrupted -> UTF-8 garbage output. */
3661+
int batch_use_neox = 0;
3662+
if (model->gguf_ctx) {
3663+
tq_gguf_ctx_t* gctx = (tq_gguf_ctx_t*)model->gguf_ctx;
3664+
if (strstr(gctx->arch, "qwen3") != NULL
3665+
|| strstr(gctx->arch, "qwen35") != NULL) batch_use_neox = 1;
3666+
}
3667+
if (c->delta_n_heads > 0) batch_use_neox = 1;
3668+
if (getenv("TQ_ROPE_PAIRS")) batch_use_neox = 0;
36373669
for (int n = 0; n < N; n++) {
36383670
float* qn = QB + (size_t)n * q_dim;
36393671
float* kn = KB + (size_t)n * kv_dim;
@@ -3651,9 +3683,15 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
36513683
float freq = base / model->rope_freqs[i];
36523684
float theta = pos * freq;
36533685
float ct = cosf(theta), st = sinf(theta);
3654-
float q0 = qh[2*i], q1 = qh[2*i+1];
3655-
qh[2*i] = q0 * ct - q1 * st;
3656-
qh[2*i+1] = q0 * st + q1 * ct;
3686+
if (batch_use_neox) {
3687+
float q0 = qh[i], q1 = qh[i + rope_pairs];
3688+
qh[i] = q0 * ct - q1 * st;
3689+
qh[i + rope_pairs] = q0 * st + q1 * ct;
3690+
} else {
3691+
float q0 = qh[2*i], q1 = qh[2*i+1];
3692+
qh[2*i] = q0 * ct - q1 * st;
3693+
qh[2*i+1] = q0 * st + q1 * ct;
3694+
}
36573695
}
36583696
}
36593697
for (int h = 0; h < c->n_kv_heads; h++) {
@@ -3663,14 +3701,25 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
36633701
float freq = base / model->rope_freqs[i];
36643702
float theta = pos * freq;
36653703
float ct = cosf(theta), st = sinf(theta);
3666-
float k0 = kh[2*i], k1 = kh[2*i+1];
3667-
kh[2*i] = k0 * ct - k1 * st;
3668-
kh[2*i+1] = k0 * st + k1 * ct;
3704+
if (batch_use_neox) {
3705+
float k0 = kh[i], k1 = kh[i + rope_pairs];
3706+
kh[i] = k0 * ct - k1 * st;
3707+
kh[i + rope_pairs] = k0 * st + k1 * ct;
3708+
} else {
3709+
float k0 = kh[2*i], k1 = kh[2*i+1];
3710+
kh[2*i] = k0 * ct - k1 * st;
3711+
kh[2*i+1] = k0 * st + k1 * ct;
3712+
}
36693713
}
36703714
}
36713715
} else {
3672-
tq_rope(qn, kn, pos, c->head_dim, c->n_heads, c->n_kv_heads,
3673-
c->rope_freq_base);
3716+
if (batch_use_neox) {
3717+
tq_rope_neox(qn, kn, pos, c->head_dim, c->n_heads, c->n_kv_heads,
3718+
c->rope_freq_base);
3719+
} else {
3720+
tq_rope(qn, kn, pos, c->head_dim, c->n_heads, c->n_kv_heads,
3721+
c->rope_freq_base);
3722+
}
36743723
}
36753724
if (n == 0 && l == 0 && dbg) {
36763725
fprintf(stderr, "[batch] L0 QB (post-RoPE) tok0 [0:8] = ");

0 commit comments

Comments
 (0)