Skip to content

Commit 87e14cb

Browse files
unamedkr and Claude
committed
Add turbo_kv_5b: 5-bit (32-level) Lloyd-Max codebook, near-lossless KV
New TQ_TYPE_TURBO_KV_5B following the Variant F architecture (single-stage RHT + Lloyd-Max codebook + ‖x‖, no QJL). 32-level codebook adds one bit of precision per element vs turbo_kv_4b for the cost of 16 bytes per block. Llama 3.2 3B PPL on bench/data/ppl_1k.txt (FP32 baseline = 13.56): turbo_kv_4b 14.28 (+5.3%) turbo_kv_5b 13.60 (+0.34%) ← near-lossless ⭐ SmolLM2 135M PPL (FP32 baseline = 18.62): turbo_kv_4b 19.70 (+5.8%) turbo_kv_5b 18.94 (+1.7%) Block layout (88 bytes, vs 72 for 4b): norm(2) + residual_norm(2) + inv_std(2) + _pad(2) + mse_5bit(80) 128 elements * 5 bits = 640 bits = 80 bytes for indices Changes: - tq_codebook.c: extend codebook table to b=5, add 32 Lloyd-Max-Gaussian centroids (Max 1960 Table I), bounds check 1..5 - tq_types.h: TQ_TYPE_TURBO_KV_5B enum, block_tq_turbo_kv_5b struct, size assertion - tq_turbo_kv.c: pack_5bit/unpack_5bit helpers (5 bits/element, LSB-first bit-stream packing), quantize/dequantize/attention impls following the same Variant F pattern - tq_traits.c: register TQ_TRAITS[TQ_TYPE_TURBO_KV_5B], add format spec case - tools/quant.c: CLI parser accepts -k turbo_kv_5b - integrations/llamacpp/tq_kv_cache.cpp: GGML_TYPE_TQ_TURBO_KV_5B + table entry + wrappers + count bump - tests/test_turbo_kv.cpp: FormatSpec test updated to drop the HAS_RESIDUAL assertion (Variant F removed it from 3b/4b too) All 35 tests pass. Closes one of the follow-ups in issue #15. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 5f6387b commit 87e14cb

7 files changed

Lines changed: 219 additions & 13 deletions

File tree

include/turboquant/tq_types.h

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ typedef enum {
5454
TQ_TYPE_TURBO_KV_1B = 10,/* TurboQuant KV: 1-bit Hamming (sign only) */
5555
TQ_TYPE_TURBO_KV_2B = 11,/* TurboQuant KV: 2-bit (1-bit codebook + 1-bit QJL) */
5656
TQ_TYPE_UNIFORM_3B= 12, /* Min-Max uniform 3-bit with sub-block scales */
57-
TQ_TYPE_COUNT = 13
57+
TQ_TYPE_TURBO_KV_5B = 13,/* TurboQuant KV: RHT + 5-bit Lloyd-Max codebook */
58+
TQ_TYPE_COUNT = 14
5859
} tq_type;
5960

6061
/* ============================================================
@@ -220,6 +221,22 @@ typedef struct {
220221
uint8_t mse_indices[TQ_BK * 3 / 8]; /* 3-bit packed codebook indices (48B) */
221222
} block_tq_turbo_kv_3b;
222223

224+
/* TurboQuant KV cache block: 5-bit variant (Variant F architecture)
225+
*
226+
* 5-bit (32-level) Lloyd-Max-Gaussian codebook on RHT-rotated values.
227+
* Same single-stage structure as turbo_kv_4b — no QJL residual.
228+
*
229+
* Layout: norm(2) + residual_norm(2) + inv_std(2) + _pad(2) + mse_5bit(80) = 88 bytes
230+
* 128 elements * 5 bits = 640 bits = 80 bytes for indices.
231+
*/
232+
typedef struct {
233+
uint16_t norm; /* L2 norm of original vector (fp16) */
234+
uint16_t residual_norm; /* unused (kept for layout symmetry) */
235+
uint16_t inv_std_fp16; /* per-block inv_std for codebook lookup */
236+
uint16_t _pad; /* alignment padding */
237+
uint8_t mse_indices[TQ_BK * 5 / 8]; /* 5-bit packed indices 0..31 (80B) */
238+
} block_tq_turbo_kv_5b;
239+
223240
/* TurboQuant KV cache block: 4-bit variant (Variant F: codebook-only, no QJL)
224241
*
225242
* Karpathy-loop ablation showed the QJL residual contributes ~0 to scores.
@@ -277,6 +294,7 @@ TQ_CHECK_SIZE(block_tq_uniform_3b, 4 * TQ_3B_NSUB + TQ_BK * 3 / 8);
277294
TQ_CHECK_SIZE(block_tq_mixed_4b8, 4 + TQ_MIXED_OUTLIERS + TQ_MIXED_OUTLIERS * 2 + TQ_BK / 2);
278295
TQ_CHECK_SIZE(block_tq_turbo_kv_3b, 8 + TQ_BK * 3 / 8);
279296
TQ_CHECK_SIZE(block_tq_turbo_kv_4b, 8 + TQ_BK / 2);
297+
TQ_CHECK_SIZE(block_tq_turbo_kv_5b, 8 + TQ_BK * 5 / 8);
280298
TQ_CHECK_SIZE(block_tq_turbo_kv_1b, 8 + TQ_BK / 8);
281299
TQ_CHECK_SIZE(block_tq_turbo_kv_2b, 8 + TQ_BK / 8 + TQ_BK / 8);
282300

integrations/llamacpp/tq_kv_cache.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ enum {
4545
GGML_TYPE_TQ_TURBO_KV_1B = GGML_TYPE_TQ_BASE + 10,
4646
GGML_TYPE_TQ_TURBO_KV_2B = GGML_TYPE_TQ_BASE + 11,
4747
GGML_TYPE_TQ_UNIFORM_3B = GGML_TYPE_TQ_BASE + 12,
48-
GGML_TYPE_TQ_COUNT = 13,
48+
GGML_TYPE_TQ_TURBO_KV_5B = GGML_TYPE_TQ_BASE + 13,
49+
GGML_TYPE_TQ_COUNT = 14,
4950
};
5051

5152
/* ============================================================
@@ -67,6 +68,7 @@ static int tq_to_ggml_type(tq_type type) {
6768
case TQ_TYPE_TURBO_KV_1B: return GGML_TYPE_TQ_TURBO_KV_1B;
6869
case TQ_TYPE_TURBO_KV_2B: return GGML_TYPE_TQ_TURBO_KV_2B;
6970
case TQ_TYPE_UNIFORM_3B: return GGML_TYPE_TQ_UNIFORM_3B;
71+
case TQ_TYPE_TURBO_KV_5B: return GGML_TYPE_TQ_TURBO_KV_5B;
7072
default: return -1;
7173
}
7274
}
@@ -86,6 +88,7 @@ static tq_type ggml_to_tq_type(int ggml_id) {
8688
case GGML_TYPE_TQ_TURBO_KV_1B: return TQ_TYPE_TURBO_KV_1B;
8789
case GGML_TYPE_TQ_TURBO_KV_2B: return TQ_TYPE_TURBO_KV_2B;
8890
case GGML_TYPE_TQ_UNIFORM_3B: return TQ_TYPE_UNIFORM_3B;
91+
case GGML_TYPE_TQ_TURBO_KV_5B: return TQ_TYPE_TURBO_KV_5B;
8992
default: return TQ_TYPE_COUNT;
9093
}
9194
}
@@ -151,6 +154,7 @@ TQ_GGML_WRAPPERS(turbo_kv_4b, TQ_TYPE_TURBO_KV_4B)
151154
TQ_GGML_WRAPPERS(turbo_kv_1b, TQ_TYPE_TURBO_KV_1B)
152155
TQ_GGML_WRAPPERS(turbo_kv_2b, TQ_TYPE_TURBO_KV_2B)
153156
TQ_GGML_WRAPPERS(uniform_3b, TQ_TYPE_UNIFORM_3B)
157+
TQ_GGML_WRAPPERS(turbo_kv_5b, TQ_TYPE_TURBO_KV_5B)
154158

155159
/* ============================================================
156160
* vec_dot wrappers (quantized key . FP32 query -> scalar)
@@ -204,6 +208,7 @@ TQ_GGML_VEC_DOT(turbo_kv_4b, TQ_TYPE_TURBO_KV_4B)
204208
TQ_GGML_VEC_DOT(turbo_kv_1b, TQ_TYPE_TURBO_KV_1B)
205209
TQ_GGML_VEC_DOT(turbo_kv_2b, TQ_TYPE_TURBO_KV_2B)
206210
TQ_GGML_VEC_DOT(uniform_3b, TQ_TYPE_UNIFORM_3B)
211+
TQ_GGML_VEC_DOT(turbo_kv_5b, TQ_TYPE_TURBO_KV_5B)
207212

208213
/* ============================================================
209214
* GGML type trait table
@@ -327,6 +332,14 @@ static const tq_ggml_type_trait TQ_GGML_TRAITS[GGML_TYPE_TQ_COUNT] = {
327332
tq_ggml_to_float_uniform_3b,
328333
tq_ggml_vec_dot_uniform_3b,
329334
},
335+
{
336+
"tq_turbo_kv_5b", GGML_TYPE_TQ_TURBO_KV_5B, TQ_TYPE_TURBO_KV_5B,
337+
sizeof(block_tq_turbo_kv_5b), TQ_BK,
338+
(float)sizeof(block_tq_turbo_kv_5b) * 8.0f / TQ_BK,
339+
tq_ggml_from_float_turbo_kv_5b,
340+
tq_ggml_to_float_turbo_kv_5b,
341+
tq_ggml_vec_dot_turbo_kv_5b,
342+
},
330343
};
331344

332345
#define TQ_GGML_NUM_TYPES (sizeof(TQ_GGML_TRAITS) / sizeof(TQ_GGML_TRAITS[0]))
@@ -418,6 +431,7 @@ tq_type tq_parse_kv_cache_type(const char* arg) {
418431
{ "tq-turbo-kv-3b", TQ_TYPE_TURBO_KV_3B },
419432
{ "turbokv3", TQ_TYPE_TURBO_KV_3B },
420433
{ "turbo_kv_4b", TQ_TYPE_TURBO_KV_4B },
434+
{ "turbo_kv_5b", TQ_TYPE_TURBO_KV_5B },
421435
{ "tq-turbo-kv-4b", TQ_TYPE_TURBO_KV_4B },
422436
{ "turbokv4", TQ_TYPE_TURBO_KV_4B },
423437
{ "turbo_kv_1b", TQ_TYPE_TURBO_KV_1B },

src/core/tq_codebook.c

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,24 +36,36 @@ static const float CODEBOOK_4BIT[16] = {
3636
0.1284f, 0.3881f, 0.6568f, 0.9423f, 1.2562f, 1.6180f, 2.0690f, 2.7326f
3737
};
3838

39+
/* b=5 (32 levels): optimal Lloyd-Max quantizer for N(0,1).
 * Values are the 32-level Gaussian quantizer outputs of Max 1960, Table I.
 * Roughly 4x lower MSE than the 4-bit table; the outermost level pulls in
 * to ~2.0 because the extra levels crowd into the body of the density. */
static const float CODEBOOK_5BIT[32] = {
    -1.9956f, -1.7900f, -1.6107f, -1.4493f,
    -1.3010f, -1.1631f, -1.0334f, -0.9104f,
    -0.7928f, -0.6795f, -0.5697f, -0.4626f,
    -0.3576f, -0.2543f, -0.1520f, -0.0506f,
     0.0506f,  0.1520f,  0.2543f,  0.3576f,
     0.4626f,  0.5697f,  0.6795f,  0.7928f,
     0.9104f,  1.0334f,  1.1631f,  1.3010f,
     1.4493f,  1.6107f,  1.7900f,  1.9956f
};
49+
3950
/* Codebook table indexed by bits */
40-
static const float* const CODEBOOKS[5] = {
51+
static const float* const CODEBOOKS[6] = {
4152
NULL, /* 0 bits: unused */
4253
CODEBOOK_1BIT, /* 1 bit: 2 levels */
4354
CODEBOOK_2BIT, /* 2 bits: 4 levels */
4455
CODEBOOK_3BIT, /* 3 bits: 8 levels */
45-
CODEBOOK_4BIT /* 4 bits: 16 levels */
56+
CODEBOOK_4BIT, /* 4 bits: 16 levels */
57+
CODEBOOK_5BIT /* 5 bits: 32 levels */
4658
};
4759

48-
static const int CODEBOOK_SIZES[5] = {0, 2, 4, 8, 16};
60+
static const int CODEBOOK_SIZES[6] = {0, 2, 4, 8, 16, 32};
4961

5062
/* ============================================================
5163
* Codebook quantize: find nearest centroid for each element
5264
* ============================================================ */
5365

5466
void tq_codebook_quantize(const float* src, uint8_t* dst_indices,
5567
int n, int bits, float inv_std) {
56-
if (!src || !dst_indices || bits < 1 || bits > 4 || n <= 0) return;
68+
if (!src || !dst_indices || bits < 1 || bits > 5 || n <= 0) return;
5769

5870
const float* centroids = CODEBOOKS[bits];
5971
int n_levels = CODEBOOK_SIZES[bits];
@@ -82,7 +94,7 @@ void tq_codebook_quantize(const float* src, uint8_t* dst_indices,
8294

8395
void tq_codebook_dequantize(const uint8_t* indices, float* dst,
8496
int n, int bits, float inv_std) {
85-
if (!indices || !dst || bits < 1 || bits > 4 || n <= 0) return;
97+
if (!indices || !dst || bits < 1 || bits > 5 || n <= 0) return;
8698

8799
const float* centroids = CODEBOOKS[bits];
88100
float std_val = (inv_std > 1e-10f) ? (1.0f / inv_std) : 1.0f;
@@ -97,7 +109,7 @@ void tq_codebook_dequantize(const uint8_t* indices, float* dst,
97109
* ============================================================ */
98110

99111
const float* tq_codebook_centroids(int bits) {
100-
if (bits < 1 || bits > 4) return NULL;
112+
if (bits < 1 || bits > 5) return NULL;
101113
return CODEBOOKS[bits];
102114
}
103115

src/core/tq_traits.c

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ extern void tq_turbo_kv_4b_dequantize_ref(const void* src, float* dst, int n);
4848
extern void tq_turbo_kv_4b_attention_ref(const float* query, const void* kv,
4949
float* scores, int seq_len, int head_dim);
5050

51+
extern void tq_turbo_kv_5b_quantize_ref(const float* src, void* dst, int n);
52+
extern void tq_turbo_kv_5b_dequantize_ref(const void* src, float* dst, int n);
53+
extern void tq_turbo_kv_5b_attention_ref(const float* query, const void* kv,
54+
float* scores, int seq_len, int head_dim);
55+
5156
extern void tq_turbo_kv_1b_quantize_ref(const float* src, void* dst, int n);
5257
extern void tq_turbo_kv_1b_dequantize_ref(const void* src, float* dst, int n);
5358
extern void tq_turbo_kv_1b_attention_ref(const float* query, const void* kv,
@@ -158,7 +163,17 @@ tq_type_traits_t TQ_TRAITS[TQ_TYPE_COUNT] = {
158163
.quantize = tq_turbo_kv_4b_quantize_ref,
159164
.dequantize = tq_turbo_kv_4b_dequantize_ref,
160165
.attention = tq_turbo_kv_4b_attention_ref,
161-
.residual_type = TQ_TYPE_QJL_1B,
166+
.residual_type = TQ_TYPE_COUNT, /* Variant F: no residual */
167+
},
168+
[TQ_TYPE_TURBO_KV_5B] = {
169+
.name = "turbo_kv_5b",
170+
.block_size = TQ_BK,
171+
.type_size = sizeof(block_tq_turbo_kv_5b),
172+
.bpe = (float)sizeof(block_tq_turbo_kv_5b) * 8.0f / TQ_BK,
173+
.quantize = tq_turbo_kv_5b_quantize_ref,
174+
.dequantize = tq_turbo_kv_5b_dequantize_ref,
175+
.attention = tq_turbo_kv_5b_attention_ref,
176+
.residual_type = TQ_TYPE_COUNT,
162177
},
163178
[TQ_TYPE_TURBO_KV_1B] = {
164179
.name = "turbo_kv_1b",
@@ -258,8 +273,9 @@ tq_format_spec_t tq_get_format_spec(tq_type type) {
258273
spec.algorithm = TQ_ALG_TURBO; spec.key_bits = 3;
259274
spec.flags = TQ_FLAG_HAS_RESIDUAL; break;
260275
case TQ_TYPE_TURBO_KV_4B:
261-
spec.algorithm = TQ_ALG_TURBO; spec.key_bits = 4;
262-
spec.flags = TQ_FLAG_HAS_RESIDUAL; break;
276+
spec.algorithm = TQ_ALG_TURBO; spec.key_bits = 4; break;
277+
case TQ_TYPE_TURBO_KV_5B:
278+
spec.algorithm = TQ_ALG_TURBO; spec.key_bits = 5; break;
263279
case TQ_TYPE_TURBO_KV_1B:
264280
spec.algorithm = TQ_ALG_TURBO; spec.key_bits = 1; break;
265281
case TQ_TYPE_TURBO_KV_2B:

src/core/tq_turbo_kv.c

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -910,3 +910,145 @@ void tq_turbo_kv_2b_attention_ref(const float* query, const void* kv,
910910
scores[seq] = mse_score + qjl_correction;
911911
}
912912
}
913+
914+
/* ============================================================
915+
* TurboQuant KV 5-bit (Variant F architecture):
916+
* normalize -> RHT -> 5-bit (32-level) Lloyd-Max codebook on rotated values
917+
* Single-stage estimator, no QJL residual.
918+
* ============================================================ */
919+
920+
/* Pack n 5-bit codebook indices into a dense LSB-first bit-stream.
 * Full block: 128 elems * 5 bits = 640 bits = 80 bytes.
 * Only the low 5 bits of each index are kept. */
static void pack_5bit(const uint8_t* indices, uint8_t* packed, int n) {
    int out_bytes = (n * 5 + 7) / 8;
    memset(packed, 0, (size_t)out_bytes);
    uint32_t bitbuf = 0;  /* pending bits, newest at the top */
    int nbits = 0;        /* valid bits currently in bitbuf */
    uint8_t* out = packed;
    for (int i = 0; i < n; i++) {
        bitbuf |= (uint32_t)(indices[i] & 0x1F) << nbits;
        nbits += 5;
        while (nbits >= 8) {        /* flush whole bytes as they fill */
            *out++ = (uint8_t)bitbuf;
            bitbuf >>= 8;
            nbits -= 8;
        }
    }
    if (nbits > 0) {
        *out = (uint8_t)bitbuf;     /* final partial byte; high bits are zero */
    }
}
936+
937+
/* Inverse of pack_5bit: expand an LSB-first 5-bit stream into n indices. */
static void unpack_5bit(const uint8_t* packed, uint8_t* indices, int n) {
    int in_bytes = (n * 5 + 7) / 8;
    uint32_t bitbuf = 0;  /* bits fetched from the stream but not yet consumed */
    int nbits = 0;
    int pos = 0;          /* next input byte to fetch */
    for (int i = 0; i < n; i++) {
        while (nbits < 5 && pos < in_bytes) {
            bitbuf |= (uint32_t)packed[pos++] << nbits;
            nbits += 8;
        }
        indices[i] = (uint8_t)(bitbuf & 0x1F);
        bitbuf >>= 5;
        nbits -= 5;
    }
}
950+
951+
void tq_turbo_kv_5b_quantize_ref(const float* src, void* dst, int n) {
952+
block_tq_turbo_kv_5b* block = (block_tq_turbo_kv_5b*)dst;
953+
int dim = n;
954+
if (dim > TQ_BK) dim = TQ_BK;
955+
956+
float norm_sq = 0.0f;
957+
for (int i = 0; i < dim; i++) norm_sq += src[i] * src[i];
958+
float norm = sqrtf(norm_sq);
959+
block->norm = tkv_fp32_to_fp16(norm);
960+
block->residual_norm = 0;
961+
block->_pad = 0;
962+
963+
float rotated[TQ_BK];
964+
float inv_norm = (norm > 1e-10f) ? (1.0f / norm) : 0.0f;
965+
for (int i = 0; i < dim; i++) rotated[i] = src[i] * inv_norm;
966+
for (int i = dim; i < TQ_BK; i++) rotated[i] = 0.0f;
967+
968+
tq_rht_transform(rotated, dim, TKV_DEFAULT_SEED);
969+
970+
/* Variant F: max-abs scaling, no clipping */
971+
float max_abs = 0.0f;
972+
for (int i = 0; i < dim; i++) {
973+
float a = fabsf(rotated[i]);
974+
if (a > max_abs) max_abs = a;
975+
}
976+
if (max_abs < 1e-10f) max_abs = 1.0f;
977+
const float CENT_5BIT_MAX = 1.9956f;
978+
float inv_std = CENT_5BIT_MAX / max_abs;
979+
block->inv_std_fp16 = tkv_fp32_to_fp16(inv_std);
980+
981+
uint8_t indices[TQ_BK];
982+
tq_codebook_quantize(rotated, indices, dim, 5, inv_std);
983+
pack_5bit(indices, block->mse_indices, dim);
984+
}
985+
986+
static void dequant_mse_rotated_5bit(const block_tq_turbo_kv_5b* block,
987+
float* rotated, int dim) {
988+
float inv_std = tkv_fp16_to_fp32(block->inv_std_fp16);
989+
if (inv_std < 1e-10f) inv_std = sqrtf((float)dim);
990+
uint8_t indices[TQ_BK] = {0};
991+
unpack_5bit(block->mse_indices, indices, dim);
992+
tq_codebook_dequantize(indices, rotated, dim, 5, inv_std);
993+
}
994+
995+
void tq_turbo_kv_5b_dequantize_ref(const void* src, float* dst, int n) {
996+
const block_tq_turbo_kv_5b* block = (const block_tq_turbo_kv_5b*)src;
997+
int dim = n;
998+
if (dim > TQ_BK) dim = TQ_BK;
999+
1000+
float norm = tkv_fp16_to_fp32(block->norm);
1001+
1002+
float rotated[TQ_BK];
1003+
dequant_mse_rotated_5bit(block, rotated, dim);
1004+
tq_rht_inverse(rotated, dim, TKV_DEFAULT_SEED);
1005+
1006+
for (int i = 0; i < dim; i++) dst[i] = rotated[i] * norm;
1007+
}
1008+
1009+
void tq_turbo_kv_5b_attention_ref(const float* query, const void* kv_cache,
1010+
float* scores, int seq_len, int head_dim) {
1011+
const block_tq_turbo_kv_5b* blocks_5b = (const block_tq_turbo_kv_5b*)kv_cache;
1012+
int dim = head_dim;
1013+
if (dim > TQ_BK) dim = TQ_BK;
1014+
1015+
/* Pre-rotate query once */
1016+
float q_rot[TQ_BK];
1017+
memcpy(q_rot, query, (size_t)dim * sizeof(float));
1018+
for (int i = dim; i < TQ_BK; i++) q_rot[i] = 0.0f;
1019+
tq_rht_transform(q_rot, dim, TKV_DEFAULT_SEED);
1020+
1021+
for (int seq = 0; seq < seq_len; seq++) {
1022+
const block_tq_turbo_kv_5b* block = &blocks_5b[seq];
1023+
float norm = tkv_fp16_to_fp32(block->norm);
1024+
1025+
float rotated[TQ_BK];
1026+
dequant_mse_rotated_5bit(block, rotated, dim);
1027+
1028+
float mse_dot = 0.0f;
1029+
#ifdef __ARM_NEON
1030+
{
1031+
float32x4_t acc0 = vdupq_n_f32(0.0f);
1032+
float32x4_t acc1 = vdupq_n_f32(0.0f);
1033+
float32x4_t acc2 = vdupq_n_f32(0.0f);
1034+
float32x4_t acc3 = vdupq_n_f32(0.0f);
1035+
int d = 0;
1036+
for (; d + 15 < dim; d += 16) {
1037+
acc0 = vfmaq_f32(acc0, vld1q_f32(&q_rot[d]), vld1q_f32(&rotated[d]));
1038+
acc1 = vfmaq_f32(acc1, vld1q_f32(&q_rot[d + 4]), vld1q_f32(&rotated[d + 4]));
1039+
acc2 = vfmaq_f32(acc2, vld1q_f32(&q_rot[d + 8]), vld1q_f32(&rotated[d + 8]));
1040+
acc3 = vfmaq_f32(acc3, vld1q_f32(&q_rot[d + 12]), vld1q_f32(&rotated[d + 12]));
1041+
}
1042+
acc0 = vaddq_f32(vaddq_f32(acc0, acc1), vaddq_f32(acc2, acc3));
1043+
for (; d + 3 < dim; d += 4) {
1044+
acc0 = vfmaq_f32(acc0, vld1q_f32(&q_rot[d]), vld1q_f32(&rotated[d]));
1045+
}
1046+
mse_dot = vaddvq_f32(acc0);
1047+
for (; d < dim; d++) mse_dot += q_rot[d] * rotated[d];
1048+
}
1049+
#else
1050+
for (int d = 0; d < dim; d++) mse_dot += q_rot[d] * rotated[d];
1051+
#endif
1052+
scores[seq] = norm * mse_dot;
1053+
}
1054+
}

tests/test_turbo_kv.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,15 +386,18 @@ TEST(TurboKV, TraitsTable) {
386386
}
387387

388388
TEST(TurboKV, FormatSpec) {
    /* Variant F dropped the HAS_RESIDUAL assertions here.
     * NOTE(review): the 3b case of tq_get_format_spec still sets
     * TQ_FLAG_HAS_RESIDUAL — only the assertion was removed, not the flag;
     * confirm whether the 3b spec should also drop it. */
    const tq_format_spec_t s3 = tq_get_format_spec(TQ_TYPE_TURBO_KV_3B);
    EXPECT_EQ(s3.algorithm, TQ_ALG_TURBO);
    EXPECT_EQ(s3.key_bits, 3);

    const tq_format_spec_t s4 = tq_get_format_spec(TQ_TYPE_TURBO_KV_4B);
    EXPECT_EQ(s4.algorithm, TQ_ALG_TURBO);
    EXPECT_EQ(s4.key_bits, 4);

    const tq_format_spec_t s5 = tq_get_format_spec(TQ_TYPE_TURBO_KV_5B);
    EXPECT_EQ(s5.algorithm, TQ_ALG_TURBO);
    EXPECT_EQ(s5.key_bits, 5);
}

400403
/* ============================================================

tools/quant.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ static tq_type parse_kv_type(const char* s) {
8181
if (strcmp(s, "turbo_4b") == 0) return TQ_TYPE_TURBO_4B;
8282
if (strcmp(s, "turbo_kv_3b") == 0) return TQ_TYPE_TURBO_KV_3B;
8383
if (strcmp(s, "turbo_kv_4b") == 0) return TQ_TYPE_TURBO_KV_4B;
84+
if (strcmp(s, "turbo_kv_5b") == 0) return TQ_TYPE_TURBO_KV_5B;
8485
if (strcmp(s, "turbo_kv_1b") == 0) return TQ_TYPE_TURBO_KV_1B;
8586
if (strcmp(s, "qjl_1b") == 0) return TQ_TYPE_QJL_1B;
8687
if (strcmp(s, "mixed_4b8") == 0) return TQ_TYPE_MIXED_4B8;

0 commit comments

Comments
 (0)