Skip to content

Commit 3f190c1

Browse files
committed
opencl: flash attention for Adreno — fused kernels, quant KV, flash-decoding
Fused FA kernels for f16, native q8_0, and native q4_0 KV on Qualcomm Adreno. Covers prefill, n_q=1 decode, and a flash-decoding (K-split) path for long-context decode. - N_SPLIT splits DK/DV across threads in a WG; QK partials reduced via sub_group_shuffle_xor, with a 3-phase local-memory fallback when subgroup shuffle isn't available. - Flash-Decoding fires at n_kv >= 2048 on n_q <= 8 non-causal shapes, DK <= 128 (DK <= 64 for multi-q FD). - Asymmetric KV (-ctk X -ctv Y, X != Y) via on-GPU dequant of the quant side to F32. - Per-(dk,dv) lazy compilation. f16 q1 / q1_split share Q via __local to avoid private-array spill. - One-pass online softmax (FA-2 style) in q8_0 and q4_0 q1 / q1_split kernels: maintains per-thread (m_i, l_i, o_acc) on a single sweep, cross-thread merge rescales by alpha=exp(m_i - m_final). Eliminates the second K read of the original two-pass implementation. Measured on gpt-oss-20b at d=16k decode (Adreno X2-90): q8_0 +16.6%, q4_0 +17.5% end-to-end vs two-pass baseline. - Dim-table tuning across DK in {40, 64, 80, 96, 112, 128, 192, 256}, with a per-generation selector hook (X1, X2 stubs). Tested on Snapdragon X Elite (Adreno X1-85) and Snapdragon X2 Elite Extreme (Adreno X2-90), OpenCL 3.0.
1 parent 58e68df commit 3f190c1

12 files changed

Lines changed: 5612 additions & 319 deletions

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,10 @@ set(GGML_OPENCL_KERNELS
169169
mul_mm_f16_f32_kq_kqv
170170
conv2d
171171
conv2d_f16_f32
172+
flash_attn_pre_f16
172173
flash_attn_f32_f16
174+
flash_attn_f32_q8_0
175+
flash_attn_f32_q4_0
173176
flash_attn_f16
174177
flash_attn_f32
175178
)

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 1842 additions & 159 deletions
Large diffs are not rendered by default.

ggml/src/ggml-opencl/kernels/cvt.cl

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,196 @@ kernel void kernel_restore_block_q8_0(
583583
}
584584
}
585585

586+
// AoS q8_0 dequant → f16. One thread per 32-elem block.
587+
kernel void kernel_dequant_q8_0_f16_aos(
588+
global char * src,
589+
global half * dst,
590+
int n_blocks
591+
) {
592+
int blk = get_global_id(0);
593+
if (blk >= n_blocks) return;
594+
595+
global char * block = src + blk * (QK8_0 + 2);
596+
float d = vload_half(0, (global half *)block);
597+
global char * qs = block + 2;
598+
599+
global half * out = dst + blk * QK8_0;
600+
for (int i = 0; i < QK8_0; ++i) {
601+
out[i] = (half)(d * (float)qs[i]);
602+
}
603+
}
604+
605+
// View-aware AoS q8_0 → f32 dequant (f32/f32 FA path).
606+
kernel void kernel_dequant_q8_0_f32_view_aos(
607+
global char * src,
608+
ulong src_offset,
609+
ulong src_nb1,
610+
ulong src_nb2,
611+
ulong src_nb3,
612+
int nblk0,
613+
int ne1,
614+
int ne2,
615+
int ne3,
616+
global float * dst
617+
) {
618+
int blk_i0 = get_global_id(0);
619+
int i1 = get_global_id(1);
620+
int batch = get_global_id(2);
621+
622+
if (blk_i0 >= nblk0) return;
623+
if (i1 >= ne1) return;
624+
625+
int i2 = batch % ne2;
626+
int i3 = batch / ne2;
627+
if (i3 >= ne3) return;
628+
629+
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK8_0);
630+
float d = vload_half(0, (global half *)block);
631+
global char * qs = block + 2;
632+
633+
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
634+
global float * out = dst + (dst_row_base + blk_i0) * QK8_0;
635+
636+
for (int i = 0; i < QK8_0; ++i) {
637+
out[i] = d * (float)qs[i];
638+
}
639+
}
640+
641+
// View-aware AoS q8_0 → f16 dequant. Rows tight, batch strides may be gapped.
642+
kernel void kernel_dequant_q8_0_f16_view_aos(
643+
global char * src,
644+
ulong src_offset,
645+
ulong src_nb1,
646+
ulong src_nb2,
647+
ulong src_nb3,
648+
int nblk0,
649+
int ne1,
650+
int ne2,
651+
int ne3,
652+
global half * dst
653+
) {
654+
int blk_i0 = get_global_id(0);
655+
int i1 = get_global_id(1);
656+
int batch = get_global_id(2);
657+
658+
if (blk_i0 >= nblk0) return;
659+
if (i1 >= ne1) return;
660+
661+
int i2 = batch % ne2;
662+
int i3 = batch / ne2;
663+
if (i3 >= ne3) return;
664+
665+
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK8_0);
666+
float d = vload_half(0, (global half *)block);
667+
global char * qs = block + 2;
668+
669+
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
670+
global half * out = dst + (dst_row_base + blk_i0) * QK8_0;
671+
672+
for (int i = 0; i < QK8_0; ++i) {
673+
out[i] = (half)(d * (float)qs[i]);
674+
}
675+
}
676+
677+
// View-aware AoS q4_0 → f32 dequant (mirrors the q8_0 view variant).
678+
kernel void kernel_dequant_q4_0_f32_view_aos(
679+
global char * src,
680+
ulong src_offset,
681+
ulong src_nb1,
682+
ulong src_nb2,
683+
ulong src_nb3,
684+
int nblk0,
685+
int ne1,
686+
int ne2,
687+
int ne3,
688+
global float * dst
689+
) {
690+
int blk_i0 = get_global_id(0);
691+
int i1 = get_global_id(1);
692+
int batch = get_global_id(2);
693+
694+
if (blk_i0 >= nblk0) return;
695+
if (i1 >= ne1) return;
696+
697+
int i2 = batch % ne2;
698+
int i3 = batch / ne2;
699+
if (i3 >= ne3) return;
700+
701+
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK4_0/2);
702+
float d = vload_half(0, (global half *)block);
703+
global uchar * qs = (global uchar *)(block + 2);
704+
705+
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
706+
global float * out = dst + (dst_row_base + blk_i0) * QK4_0;
707+
708+
for (int i = 0; i < QK4_0/2; ++i) {
709+
uchar byte = qs[i];
710+
int q0 = (int)(byte & 0x0F) - 8;
711+
int q1 = (int)(byte >> 4) - 8;
712+
out[i] = d * (float)q0;
713+
out[i + QK4_0/2] = d * (float)q1;
714+
}
715+
}
716+
717+
// View-aware AoS q4_0 → f16 dequant (mirrors the q8_0 view variant).
718+
kernel void kernel_dequant_q4_0_f16_view_aos(
719+
global char * src,
720+
ulong src_offset,
721+
ulong src_nb1,
722+
ulong src_nb2,
723+
ulong src_nb3,
724+
int nblk0,
725+
int ne1,
726+
int ne2,
727+
int ne3,
728+
global half * dst
729+
) {
730+
int blk_i0 = get_global_id(0);
731+
int i1 = get_global_id(1);
732+
int batch = get_global_id(2);
733+
734+
if (blk_i0 >= nblk0) return;
735+
if (i1 >= ne1) return;
736+
737+
int i2 = batch % ne2;
738+
int i3 = batch / ne2;
739+
if (i3 >= ne3) return;
740+
741+
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK4_0/2);
742+
float d = vload_half(0, (global half *)block);
743+
global uchar * qs = (global uchar *)(block + 2);
744+
745+
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
746+
global half * out = dst + (dst_row_base + blk_i0) * QK4_0;
747+
748+
for (int i = 0; i < QK4_0/2; ++i) {
749+
uchar byte = qs[i];
750+
int q0 = (int)(byte & 0x0F) - 8;
751+
int q1 = (int)(byte >> 4) - 8;
752+
out[i] = (half)(d * (float)q0);
753+
out[i + QK4_0/2] = (half)(d * (float)q1);
754+
}
755+
}
756+
757+
// SoA q8_0 dequant; layout matches kernel_convert_block_q8_0.
758+
kernel void kernel_dequant_q8_0_f16_soa(
759+
global char * src_q,
760+
global char * src_d,
761+
global half * dst,
762+
int n_blocks
763+
) {
764+
int blk = get_global_id(0);
765+
if (blk >= n_blocks) return;
766+
767+
float d = vload_half(0, (global half *)src_d + blk);
768+
global char * qs = src_q + blk * QK8_0;
769+
770+
global half * out = dst + blk * QK8_0;
771+
for (int i = 0; i < QK8_0; ++i) {
772+
out[i] = (half)(d * (float)qs[i]);
773+
}
774+
}
775+
586776
kernel void kernel_restore_block_q8_0_trans(
587777
global uchar * src_q,
588778
global half * src_d,

ggml/src/ggml-opencl/kernels/exp.cl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ kernel void kernel_exp_f16(
4545
src0 = (global half*)((global char*)src0 + offset0);
4646
dst = (global half*)((global char*)dst + offsetd);
4747

48-
dst[get_global_id(0)] = exp(src0[get_global_id(0)]);
48+
dst[get_global_id(0)] = convert_half(exp(convert_float(src0[get_global_id(0)])));
4949
}
5050

5151
kernel void kernel_exp_f16_4(
@@ -61,7 +61,7 @@ kernel void kernel_exp_f16_4(
6161
src0 = (global half4*)((global char*)src0 + offset0);
6262
dst = (global half4*)((global char*)dst + offsetd);
6363

64-
dst[get_global_id(0)] = exp(src0[get_global_id(0)]);
64+
dst[get_global_id(0)] = convert_half4(exp(convert_float4(src0[get_global_id(0)])));
6565
}
6666

6767
kernel void kernel_exp_f32_nc(
@@ -120,6 +120,6 @@ kernel void kernel_exp_f16_nc(
120120
global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
121121
global half * y = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
122122

123-
*y = exp(*x);
123+
*y = convert_half(exp(convert_float(*x)));
124124
}
125125
}

ggml/src/ggml-opencl/kernels/expm1.cl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@ kernel void kernel_expm1_f16(
3737
src0 = (global half*)((global char*)src0 + offset0);
3838
dst = (global half*)((global char*)dst + offsetd);
3939

40-
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0h;
40+
{
41+
const float x = convert_float(src0[get_global_id(0)]);
42+
dst[get_global_id(0)] = convert_half(exp(x) - 1.0f);
43+
}
4144
}
4245

4346
kernel void kernel_expm1_f16_4(
@@ -49,7 +52,10 @@ kernel void kernel_expm1_f16_4(
4952
src0 = (global half4*)((global char*)src0 + offset0);
5053
dst = (global half4*)((global char*)dst + offsetd);
5154

52-
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0h;
55+
{
56+
const float4 x = convert_float4(src0[get_global_id(0)]);
57+
dst[get_global_id(0)] = convert_half4(exp(x) - 1.0f);
58+
}
5359
}
5460

5561
kernel void kernel_expm1_f32_nc(
@@ -108,6 +114,7 @@ kernel void kernel_expm1_f16_nc(
108114
global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
109115
global half * y = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
110116

111-
*y = exp(*x) - 1.0f;
117+
const float fx = convert_float(*x);
118+
*y = convert_half(exp(fx) - 1.0f);
112119
}
113120
}

0 commit comments

Comments
 (0)