Skip to content

Commit 80a8e87

Browse files
unamedkr and claude committed
Vulkan GPU auto-activation for KV cache operations
When built with TQ_BUILD_VULKAN=ON and a Vulkan device is available, KV cache quantize/attention functions are automatically routed to GPU compute shaders via a runtime traits override.

Changes:
- tools/quant.c: call tq_init_vulkan_backend() on startup
- tq_vulkan_init.c: add tq_vulkan_override_traits() — replaces CPU function pointers in TQ_TRAITS[] with Vulkan GPU versions
- tq_traits.c: make TQ_TRAITS[] non-const for runtime override
- tq_types.h: update extern declaration to match

The full forward pass (matmul, FFN, norms) still runs on CPU. Vulkan handles KV quantize + dequant + attention kernels.

34/34 tests passing.

Addresses #9

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent b749134 commit 80a8e87

4 files changed

Lines changed: 41 additions & 4 deletions

File tree

include/turboquant/tq_types.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,8 @@ typedef struct {
174174
tq_type residual_type; /* pairing for composite types */
175175
} tq_type_traits_t;
176176

177-
/* Global traits table — initialized by tq_init() */
178-
extern const tq_type_traits_t TQ_TRAITS[TQ_TYPE_COUNT];
177+
/* Global traits table — GPU backends (Vulkan/Metal) override at runtime */
178+
extern tq_type_traits_t TQ_TRAITS[TQ_TYPE_COUNT];
179179

180180
/* ============================================================
181181
* Cache block header (for paged cache)

src/backend/vulkan/tq_vulkan_init.c

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,13 +463,39 @@ int tq_init_vulkan_backend(void) {
463463
if (tq_vk_create_pipeline_layout() != 0) return -1;
464464
if (tq_vk_create_pipelines() != 0) return -1;
465465

466-
printf("TQ Vulkan: Initialized on %s (subgroup size %u)\n",
466+
fprintf(stderr, "quant.cpp Vulkan: Initialized on %s (subgroup size %u)\n",
467467
g_vk_state.device_name, g_vk_state.subgroup_size);
468468

469469
g_vk_state.initialized = 1;
470+
471+
/* Override TQ_TRAITS with Vulkan-accelerated quantize/attention functions.
472+
* This makes KV cache operations automatically use GPU when available. */
473+
tq_vulkan_override_traits();
474+
470475
return 0;
471476
}
472477

478+
/*
 * Override CPU traits with Vulkan GPU functions where available.
 *
 * Walks every index in [0, TQ_TYPE_COUNT) and asks
 * tq_vulkan_get_quantize_fn() and tq_vulkan_get_attention_fn() for a
 * replacement. A non-NULL result is stored into TQ_TRAITS[i].quantize or
 * TQ_TRAITS[i].attention respectively and one line is logged to stderr;
 * a NULL result leaves that entry untouched.
 *
 * The lookups hand back void*, so each result is copied into a typed
 * function pointer with memcpy instead of being cast directly.
 *
 * Takes no arguments and returns nothing; its effect is the mutation of the
 * global TQ_TRAITS table.
 * NOTE(review): no synchronization is visible here — presumably this must
 * run before any other thread reads TQ_TRAITS; confirm against callers.
 */
479+
void tq_vulkan_override_traits(void) {
480+
extern tq_type_traits_t TQ_TRAITS[];
481+
for (int i = 0; i < TQ_TYPE_COUNT; i++) {
482+
void* vk_quant = tq_vulkan_get_quantize_fn(i);
483+
void* vk_attn = tq_vulkan_get_attention_fn(i);
484+
if (vk_quant) {
485+
void (*fn)(const float*, void*, int);
486+
memcpy(&fn, &vk_quant, sizeof(fn));
487+
TQ_TRAITS[i].quantize = fn;
488+
fprintf(stderr, " Vulkan: GPU-accelerated quantize for %s\n", TQ_TRAITS[i].name);
489+
}
490+
if (vk_attn) {
491+
void (*fn)(const float*, const void*, float*, int, int);
492+
memcpy(&fn, &vk_attn, sizeof(fn));
493+
TQ_TRAITS[i].attention = fn;
494+
fprintf(stderr, " Vulkan: GPU-accelerated attention for %s\n", TQ_TRAITS[i].name);
495+
}
496+
}
497+
}
498+
473499
void tq_shutdown_vulkan_backend(void) {
474500
if (!g_vk_state.initialized) return;
475501

src/core/tq_traits.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ extern void tq_turbo_kv_2b_dequantize_ref(const void* src, float* dst, int n);
5858
extern void tq_turbo_kv_2b_attention_ref(const float* query, const void* kv,
5959
float* scores, int seq_len, int head_dim);
6060

61-
const tq_type_traits_t TQ_TRAITS[TQ_TYPE_COUNT] = {
61+
/* Non-const to allow runtime GPU backend override (Vulkan/Metal) */
62+
tq_type_traits_t TQ_TRAITS[TQ_TYPE_COUNT] = {
6263
[TQ_TYPE_POLAR_3B] = {
6364
.name = "polar_3b",
6465
.block_size = TQ_BK,

tools/quant.c

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,16 @@ int main(int argc, char** argv) {
270270
tq_quantize_weights(model);
271271
}
272272

273+
/* GPU backend detection and initialization */
274+
#ifdef TQ_BUILD_VULKAN
275+
{
276+
extern int tq_init_vulkan_backend(void);
277+
if (tq_init_vulkan_backend() == 0) {
278+
fprintf(stderr, "Vulkan backend: ready (KV cache quantization on GPU)\n");
279+
}
280+
}
281+
#endif
282+
273283
if (info_only) {
274284
tq_free_model(model);
275285
return 0;

0 commit comments

Comments
 (0)