Skip to content

Commit d38ba4c

Browse files
committed
OpenCL: add TRI op support
1 parent 1c7cf94 commit d38ba4c

3 files changed

Lines changed: 98 additions & 0 deletions

File tree

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ set(GGML_OPENCL_KERNELS
5757
add
5858
add_id
5959
argsort
60+
tri
6061
fill
6162
clamp
6263
cpy

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,7 @@ struct ggml_backend_opencl_context {
489489
cl_kernel kernel_gelu_quick, kernel_gelu_quick_4;
490490
cl_kernel kernel_relu;
491491
cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
492+
cl_kernel kernel_tri;
492493
cl_kernel kernel_fill;
493494
cl_kernel kernel_clamp;
494495
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,
@@ -793,6 +794,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
793794
GGML_LOG_CONT(".");
794795
}
795796

797+
// tri
798+
{
799+
#ifdef GGML_OPENCL_EMBED_KERNELS
800+
const std::string kernel_src {
801+
#include "tri.cl.h"
802+
};
803+
#else
804+
const std::string kernel_src = read_file("tri.cl");
805+
#endif
806+
cl_program prog =
807+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
808+
809+
CL_CHECK((backend_ctx->kernel_tri = clCreateKernel(prog, "kernel_tri_f32", &err), err));
810+
GGML_LOG_CONT(".");
811+
812+
CL_CHECK(clReleaseProgram(prog));
813+
}
814+
796815
// fill
797816
{
798817
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -3205,6 +3224,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
32053224
default:
32063225
return false;
32073226
}
3227+
case GGML_OP_TRI:
3228+
return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op);
32083229
case GGML_OP_FILL:
32093230
return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op);
32103231
case GGML_OP_CLAMP:
@@ -5965,6 +5986,44 @@ static void ggml_cl_sigmoid(ggml_backend_t backend, const ggml_tensor * src0, co
59655986
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
59665987
}
59675988

5989+
// Dispatch the GGML_OP_TRI operator on the OpenCL backend: copies src0 into
// dst, zeroing elements outside the triangle selected by op_params[0]
// (tri_type). Flat 1-D launch over all elements; the kernel derives the
// row/column from the flat index. src1 is unused for this op.
static void ggml_cl_tri(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    GGML_ASSERT(src0);
    GGML_ASSERT(src0->extra);
    GGML_ASSERT(dst);
    GGML_ASSERT(dst->extra);

    UNUSED(src1);

    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;

    // Device buffer offsets must account for view tensors.
    cl_ulong offset0 = extra0->offset + src0->view_offs;
    cl_ulong offsetd = extrad->offset + dst->view_offs;

    const int tri_type = ggml_get_op_params_i32(dst, 0);
    // The kernel declares these arguments as 32-bit int. Narrow explicitly:
    // passing &int64_t with sizeof(int) to clSetKernelArg reads only 4 of the
    // 8 bytes, which is endian-dependent and violates the OpenCL requirement
    // that arg_size match the kernel argument's size.
    const int n   = (int) ggml_nelements(dst);
    const int ne0 = dst->ne[0];
    const int ne1 = dst->ne[1];

    cl_kernel kernel = backend_ctx->kernel_tri;

    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),      &n));
    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int),      &ne0));
    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &ne1));
    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int),      &tri_type));

    // Round the global size up to a multiple of the work-group size; the
    // kernel bounds-checks idx against n, so over-provisioned work items are
    // harmless.
    size_t local_work_size[1]  = { 256 };
    size_t global_work_size[1] = { ((size_t)n + local_work_size[0] - 1) / local_work_size[0] * local_work_size[0] };

    backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, dst);
}
6026+
59686027
static void ggml_cl_fill(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
59696028
GGML_ASSERT(dst);
59706029
GGML_ASSERT(dst->extra);
@@ -10012,6 +10071,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
1001210071
}
1001310072
func = ggml_cl_glu;
1001410073
break;
10074+
case GGML_OP_TRI:
10075+
if (!any_on_device) {
10076+
return false;
10077+
}
10078+
func = ggml_cl_tri;
10079+
break;
1001510080
case GGML_OP_FILL:
1001610081
if (!any_on_device) {
1001710082
return false;
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2+
3+
//------------------------------------------------------------------------------
4+
// tri
5+
//------------------------------------------------------------------------------
6+
__kernel void kernel_tri_f32(
7+
global float * src0,
8+
ulong offset0,
9+
global float * dst,
10+
ulong offsetd,
11+
int n,
12+
int ne0,
13+
int ne1,
14+
int tri_type
15+
) {
16+
src0 = (global float*)((global char*)src0 + offset0);
17+
dst = (global float*)((global char*)dst + offsetd);
18+
19+
int idx = get_global_id(0);
20+
if (idx >= n) return;
21+
22+
int i0 = idx % ne0;
23+
int i1 = (idx / ne0) % ne1;
24+
25+
int keep = 0;
26+
if (tri_type == 0) keep = (i0 >= i1);
27+
else if (tri_type == 1) keep = (i0 > i1);
28+
else if (tri_type == 2) keep = (i0 <= i1);
29+
else keep = (i0 < i1);
30+
31+
dst[idx] = keep ? src0[idx] : 0.0f;
32+
}

0 commit comments

Comments
 (0)