Skip to content

Commit 2d02bb0

Browse files
committed
opencl: add basic q4_1 mm
1 parent abe9f01 commit 2d02bb0

3 files changed

Lines changed: 226 additions & 0 deletions

File tree

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ set(GGML_OPENCL_KERNELS
103103
gemv_moe_mxfp4_f32
104104
mul_mm_f32_f32_l4_lm
105105
mul_mm_f16_f32_l4_lm
106+
mul_mm_q4_1_f32_l4_lm
106107
mul_mm_q8_0_f32_l4_lm
107108
mul_mm_q6_k_f32_l4_lm
108109
mul_mm_q8_0_f32_8x4

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,7 @@ struct ggml_backend_opencl_context {
567567
cl_kernel kernel_mul_mv_id_mxfp4_f32_flat;
568568
cl_kernel kernel_mul_mm_f32_f32_l4_lm;
569569
cl_kernel kernel_mul_mm_f16_f32_l4_lm;
570+
cl_kernel kernel_mul_mm_q4_1_f32_l4_lm;
570571
cl_kernel kernel_mul_mm_q8_0_f32_l4_lm;
571572
cl_kernel kernel_mul_mm_q6_k_f32_l4_lm;
572573

@@ -1400,6 +1401,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
14001401
GGML_LOG_CONT(".");
14011402
}
14021403

1404+
// mul_mm_q4_1_f32_l4_lm
// Builds the OpenCL program for the Q4_1 x F32 tiled matrix-multiplication kernel and
// stores the resulting kernel handle in backend_ctx->kernel_mul_mm_q4_1_f32_l4_lm.
1405+
{
1406+
// The kernel source is either compiled into the binary (generated .cl.h header)
// or read from the .cl file at run time, depending on GGML_OPENCL_EMBED_KERNELS.
#ifdef GGML_OPENCL_EMBED_KERNELS
1407+
const std::string kernel_src {
1408+
#include "mul_mm_q4_1_f32_l4_lm.cl.h"
1409+
};
1410+
#else
1411+
const std::string kernel_src = read_file("mul_mm_q4_1_f32_l4_lm.cl");
1412+
#endif
1413+
// Build the program for this backend's context/device with the shared compile options.
cl_program prog =
1414+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1415+
1416+
// The name passed to clCreateKernel must match the kernel function name in the .cl source;
// CL_CHECK is applied to the error code written to err.
CL_CHECK((backend_ctx->kernel_mul_mm_q4_1_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q4_1_f32_l4_lm", &err), err));
1417+
// Progress indicator: one dot is logged per loaded kernel block.
GGML_LOG_CONT(".");
1418+
}
1419+
14031420
// mul_mm_q8_0_f32_l4_lm
14041421
{
14051422
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -9135,6 +9152,49 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
91359152
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
91369153
return;
91379154
}
9155+
// Q4_1 weights (src0) times F32 activations (src1) using the tiled local-memory kernel
// kernel_mul_mm_q4_1_f32_l4_lm. On success the kernel is enqueued and the function returns;
// otherwise the case breaks out of the switch (presumably so a different path handles the
// multiplication - confirm against the code following the switch).
case GGML_TYPE_Q4_1: {
9156+
// Do not use this kernel when src1 has fewer than 32 rows.
if (ne11 < 32) {
9157+
break;
9158+
}
9159+
// The kernel derives all addresses from dense row/batch strides, so both operands
// must be contiguous.
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
9160+
break;
9161+
}
9162+
9163+
kernel = backend_ctx->kernel_mul_mm_q4_1_f32_l4_lm;
9164+
// Work-group size; must stay in sync with the BM/BN/TM/TN macros of the kernel.
nth0 = 128; // calculated as (BM*BN)/(TM*TN)
9165+
9166+
// Distance between consecutive batches (2D matrices), counted in elements.
int batch_stride_a = ne00*ne01;
9167+
int batch_stride_b = ne10*ne11;
9168+
int batch_stride_d = ne0*ne1;
9169+
9170+
// Arguments 0-2: the three src0 buffers the kernel receives as src0_q (packed 4-bit
// values), src0_d (per-block scale) and src0_m (per-block additive term).
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_1->q));
9171+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_1->d));
9172+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_1->m));
9173+
// Arguments 3-6: src1 and dst buffers, each followed by the byte offset the kernel
// adds to the buffer base.
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra1->data_device));
9174+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset1));
9175+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device));
9176+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd));
9177+
// Arguments 7-11: tensor dimensions (kernel parameters ne00, ne01, ne02, ne11, ne12).
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne00));
9178+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne01));
9179+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne02));
9180+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne11));
9181+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12));
9182+
// Arguments 12-14: row strides in elements. ne10 is passed for stride_a as well as
// stride_b (NOTE(review): this relies on ne00 == ne10 for a valid mat-mul - confirm).
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_a
9183+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_b
9184+
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne01)); // stride_d
9185+
// Arguments 15-17: batch strides computed above.
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_a));
9186+
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_b));
9187+
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_d));
9188+
// Arguments 18-19: the kernel divides the src1 batch indices by r2/r3 to obtain the
// src0 batch indices.
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r2));
9189+
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r3));
9190+
9191+
// 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
// Dimension 0: one work-group of nth0 work-items per 64-row tile of ne01.
// Dimension 1: one work-group per 64-row tile of ne11.
// Dimension 2: one slice per batch (ne12*ne13).
size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
9192+
size_t local_work_size[] = {(size_t)nth0, 1, 1};
9193+
9194+
9195+
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
9196+
return;
9197+
}
91389198
case GGML_TYPE_Q8_0: {
91399199
if (ne11 < 32) {
91409200
break;
mul_mm_q4_1_f32_l4_lm.cl (new OpenCL kernel source file)

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Number of matrix elements one work-item fetches per load step:
// one uchar4 of src0_q holds 8 packed 4-bit values, one float4 of src1 holds 4 floats.
#define LOAD_VEC_A 8
#define LOAD_VEC_B 4

// Work-group tile: BM rows of src0 by BN rows of src1, consuming BK elements of the
// shared (K) dimension per iteration.
#define BM 64
#define BN 64
#define BK 32
// Per-work-item output tile: TM x TN accumulators.
#define TM 4
#define TN 8

// Tiled matrix multiplication dst = src0 * src1 with 4-bit quantized src0 and float src1.
//
// src0 is supplied as three separate buffers:
//   src0_q - packed 4-bit values, 32 values (four uchar4) per quantization block; within a
//            block the low nibbles of the 16 bytes are values 0..15, the high nibbles 16..31
//   src0_d - one half scale per block
//   src0_m - one half additive term per block
// Each value is reconstructed as nibble * d + m.
//
// src1/dst are advanced by the byte offsets offset1/offsetd before use.
// ne00 is the length of the shared dimension, ne01 the number of src0 rows, ne11 the number
// of src1 rows; ne02/ne12 are batch dimensions. stride_* are row strides and batch_stride_*
// are per-batch strides, all counted in elements. The src1 batch indices are divided by
// r2/r3 to select the src0 batch.
//
// Each work-group produces a BM x BN tile of dst (written as dst[col * stride_d + row]);
// each work-item accumulates TM x TN outputs.
// NOTE(review): the thread-to-output mapping covers the tile only with
// (BM*BN)/(TM*TN) = 128 work-items per group, and no bound check is done along K, so ne00
// is presumably a multiple of BK - both are assumed to be guaranteed by the host code.
kernel void kernel_mul_mm_q4_1_f32_l4_lm(
    global uchar4 * src0_q,
    global half   * src0_d,
    global half   * src0_m,
    global float4 * src1,
    ulong           offset1,
    global float  * dst,
    ulong           offsetd,

    int ne00,
    int ne01,
    int ne02,
    int ne11,
    int ne12,

    int stride_a,
    int stride_b,
    int stride_d,

    int batch_stride_a,
    int batch_stride_b,
    int batch_stride_d,

    int r2,
    int r3
) {
    src1 = (global float4 *)((global char *)src1 + offset1);
    dst  = (global float  *)((global char *)dst  + offsetd);

    // Staging tiles in local memory, stored K-major: element (k, row) lives at k * BM + row.
    local float tile_a[BM * BK];
    local float tile_b[BN * BK];

    // Batch decomposition: the src1/dst batch index is split into (i13, i12) and mapped
    // to the src0 batch through the r3/r2 ratios.
    const int batch   = get_global_id(2);
    const int i13     = batch / ne12;
    const int i12     = batch % ne12;
    const int i03     = i13 / r3;
    const int i02     = i12 / r2;
    const int batch_a = i03 * ne02 + i02;

    const int tile_r = get_group_id(0);
    const int tile_c = get_group_id(1);

    const int lid = get_local_id(0);
    const int lsz = get_local_size(0);

    // Position of this work-item's TM x TN output tile inside the work-group tile.
    const int out_r = lid % (BM / TM);
    const int out_c = lid / (BM / TM);

    // Position of this work-item's load slot: *_k selects the vector along K,
    // *_row the matrix row inside the tile.
    const int a_k   = lid % (BK / LOAD_VEC_A);
    const int a_row = lid / (BK / LOAD_VEC_A);
    const int b_k   = lid % (BK / LOAD_VEC_B);
    const int b_row = lid / (BK / LOAD_VEC_B);

    // Rows covered by the whole work-group in one load step.
    const int a_step = lsz * LOAD_VEC_A / BK;
    const int b_step = lsz * LOAD_VEC_B / BK;

    // Running positions in units of load vectors (uchar4 for A, float4 for B).
    int pos_a = (batch_a * batch_stride_a + tile_r * BM * stride_a) / LOAD_VEC_A;
    int pos_b = (batch   * batch_stride_b + tile_c * BN * stride_b) / LOAD_VEC_B;

    const uchar4 nibble_mask = (uchar4)(0x0F);

    float acc[TM * TN];
    float reg_a[TM];
    float reg_b[TN];

    for (int i = 0; i < TM * TN; i++) {
        acc[i] = 0.0f;
    }

    for (int k0 = 0; k0 < ne00; k0 += BK) {
        // Stage and dequantize the A tile; rows beyond ne01 are filled with zeros.
        for (int l = 0; l < BM; l += a_step) {
            const int row = a_row + l;

            float4 lo = (float4)(0.0f);
            float4 hi = (float4)(0.0f);

            if (tile_r * BM + row < ne01) {
                const int idx = pos_a + row * stride_a / LOAD_VEC_A + a_k;
                const int ib  = idx / 4;   // four uchar4 per quantization block

                const float  d = (float)src0_d[ib];
                const float  m = (float)src0_m[ib];
                const uchar4 q = src0_q[idx];

                lo = convert_float4(q & nibble_mask) * d + m;
                hi = convert_float4((q >> (uchar4)(4)) & nibble_mask) * d + m;
            }

            // Low nibbles are block elements a_k*4 .. a_k*4+3, high nibbles follow 16 later.
            const int k_lo = a_k * 4;
            const int k_hi = k_lo + 16;

            tile_a[(k_lo + 0) * BM + row] = lo.s0;
            tile_a[(k_lo + 1) * BM + row] = lo.s1;
            tile_a[(k_lo + 2) * BM + row] = lo.s2;
            tile_a[(k_lo + 3) * BM + row] = lo.s3;
            tile_a[(k_hi + 0) * BM + row] = hi.s0;
            tile_a[(k_hi + 1) * BM + row] = hi.s1;
            tile_a[(k_hi + 2) * BM + row] = hi.s2;
            tile_a[(k_hi + 3) * BM + row] = hi.s3;
        }

        // Stage the B tile; rows beyond ne11 are filled with zeros.
        for (int l = 0; l < BN; l += b_step) {
            const int row = b_row + l;

            float4 v = (float4)(0.0f);
            if (tile_c * BN + row < ne11) {
                v = src1[pos_b + row * stride_b / LOAD_VEC_B + b_k];
            }

            const int k = b_k * LOAD_VEC_B;
            tile_b[(k + 0) * BN + row] = v.s0;
            tile_b[(k + 1) * BN + row] = v.s1;
            tile_b[(k + 2) * BN + row] = v.s2;
            tile_b[(k + 3) * BN + row] = v.s3;
        }

        // Both tiles must be completely written before any work-item reads them.
        barrier(CLK_LOCAL_MEM_FENCE);

        pos_a += BK / LOAD_VEC_A;
        pos_b += BK / LOAD_VEC_B;

        for (int k = 0; k < BK; k++) {
            for (int r = 0; r < TM; r++) {
                reg_a[r] = tile_a[k * BM + out_r * TM + r];
            }

            for (int c = 0; c < TN; c++) {
                reg_b[c] = tile_b[k * BN + out_c * TN + c];
            }

            for (int c = 0; c < TN; c++) {
                for (int r = 0; r < TM; r++) {
                    const int a = c * TM + r;
                    acc[a] = mad(reg_a[r], reg_b[c], acc[a]);
                }
            }
        }

        // All reads of the tiles must finish before the next iteration overwrites them.
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Write back the accumulators, skipping outputs outside the matrix bounds.
    const int row0     = tile_r * BM + out_r * TM;
    const int col0     = tile_c * BN + out_c * TN;
    const int dst_base = batch * batch_stride_d;

    for (int c = 0; c < TN; c++) {
        if (col0 + c >= ne11) {
            continue;
        }
        for (int r = 0; r < TM; r++) {
            if (row0 + r < ne01) {
                dst[dst_base + (col0 + c) * stride_d + row0 + r] = acc[c * TM + r];
            }
        }
    }
}

0 commit comments

Comments
 (0)