@@ -489,6 +489,7 @@ struct ggml_backend_opencl_context {
489489 cl_kernel kernel_gelu_quick, kernel_gelu_quick_4;
490490 cl_kernel kernel_relu;
491491 cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
492+ cl_kernel kernel_tri;
492493 cl_kernel kernel_fill;
493494 cl_kernel kernel_clamp;
494495 cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,
@@ -793,6 +794,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
793794 GGML_LOG_CONT (" ." );
794795 }
795796
797+ // tri
798+ {
799+ #ifdef GGML_OPENCL_EMBED_KERNELS
800+ const std::string kernel_src {
801+ #include " tri.cl.h"
802+ };
803+ #else
804+ const std::string kernel_src = read_file (" tri.cl" );
805+ #endif
806+ cl_program prog =
807+ build_program_from_source (backend_ctx->context , backend_ctx->device , kernel_src.c_str (), compile_opts);
808+
809+ CL_CHECK ((backend_ctx->kernel_tri = clCreateKernel (prog, " kernel_tri_f32" , &err), err));
810+ GGML_LOG_CONT (" ." );
811+
812+ CL_CHECK (clReleaseProgram (prog));
813+ }
814+
796815 // fill
797816 {
798817#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -3205,6 +3224,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
32053224 default :
32063225 return false ;
32073226 }
3227+ case GGML_OP_TRI:
3228+ return op->type == GGML_TYPE_F32 && ggml_is_contiguous (op);
32083229 case GGML_OP_FILL:
32093230 return op->type == GGML_TYPE_F32 && ggml_is_contiguous (op);
32103231 case GGML_OP_CLAMP:
@@ -5965,6 +5986,44 @@ static void ggml_cl_sigmoid(ggml_backend_t backend, const ggml_tensor * src0, co
59655986 backend_ctx->enqueue_ndrange_kernel (kernel, 3 , global_work_size, local_work_size_ptr, dst);
59665987}
59675988
5989+ static void ggml_cl_tri (ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
5990+ GGML_ASSERT (src0);
5991+ GGML_ASSERT (src0->extra );
5992+ GGML_ASSERT (dst);
5993+ GGML_ASSERT (dst->extra );
5994+
5995+ UNUSED (src1);
5996+
5997+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context ;
5998+
5999+ ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra ;
6000+ ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra ;
6001+
6002+ cl_ulong offset0 = extra0->offset + src0->view_offs ;
6003+ cl_ulong offsetd = extrad->offset + dst->view_offs ;
6004+
6005+ const int tri_type = ggml_get_op_params_i32 (dst, 0 );
6006+ const int64_t n = ggml_nelements (dst);
6007+ const int ne0 = dst->ne [0 ];
6008+ const int ne1 = dst->ne [1 ];
6009+
6010+ cl_kernel kernel = backend_ctx->kernel_tri ;
6011+
6012+ CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
6013+ CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
6014+ CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extrad->data_device ));
6015+ CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offsetd));
6016+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (int ), &n));
6017+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (int ), &ne0));
6018+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne1));
6019+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &tri_type));
6020+
6021+ size_t local_work_size[1 ] = { 256 };
6022+ size_t global_work_size[1 ] = { ((size_t )n + local_work_size[0 ] - 1 ) / local_work_size[0 ] * local_work_size[0 ] };
6023+
6024+ backend_ctx->enqueue_ndrange_kernel (kernel, 1 , global_work_size, local_work_size, dst);
6025+ }
6026+
59686027static void ggml_cl_fill (ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
59696028 GGML_ASSERT (dst);
59706029 GGML_ASSERT (dst->extra );
@@ -10012,6 +10071,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
1001210071 }
1001310072 func = ggml_cl_glu;
1001410073 break ;
10074+ case GGML_OP_TRI:
10075+ if (!any_on_device) {
10076+ return false ;
10077+ }
10078+ func = ggml_cl_tri;
10079+ break ;
1001510080 case GGML_OP_FILL:
1001610081 if (!any_on_device) {
1001710082 return false ;
0 commit comments