Skip to content

Commit 2a02b08

Browse files
committed
implement extract_row on CSR matrices for OpenCL backend
1 parent b81d1e2 commit 2a02b08

4 files changed

Lines changed: 189 additions & 0 deletions

File tree

src/opencl/cl_algo_registry.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include <core/registry.hpp>
3131
#include <core/top.hpp>
3232

33+
#include <opencl/cl_m_extract_row.hpp>
3334
#include <opencl/cl_m_reduce.hpp>
3435
#include <opencl/cl_mxmT_masked.hpp>
3536
#include <opencl/cl_mxv.hpp>
@@ -93,6 +94,11 @@ namespace spla {
9394
g_registry->add(MAKE_KEY_CL_0("mxmT_masked", INT), std::make_shared<Algo_mxmT_masked_cl<T_INT>>());
9495
g_registry->add(MAKE_KEY_CL_0("mxmT_masked", UINT), std::make_shared<Algo_mxmT_masked_cl<T_UINT>>());
9596
g_registry->add(MAKE_KEY_CL_0("mxmT_masked", FLOAT), std::make_shared<Algo_mxmT_masked_cl<T_FLOAT>>());
97+
98+
// algorthm m_extract_row
99+
g_registry->add(MAKE_KEY_CL_0("m_extract_row", INT), std::make_shared<Algo_m_extract_row_cl<T_INT>>());
100+
g_registry->add(MAKE_KEY_CL_0("m_extract_row", UINT), std::make_shared<Algo_m_extract_row_cl<T_UINT>>());
101+
g_registry->add(MAKE_KEY_CL_0("m_extract_row", FLOAT), std::make_shared<Algo_m_extract_row_cl<T_FLOAT>>());
96102
}
97103

98104
}// namespace spla

src/opencl/cl_m_extract_row.hpp

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
/**********************************************************************************/
2+
/* This file is part of spla project */
3+
/* https://github.com/SparseLinearAlgebra/spla */
4+
/**********************************************************************************/
5+
/* MIT License */
6+
/* */
7+
/* Copyright (c) 2025 SparseLinearAlgebra */
8+
/* */
9+
/* Permission is hereby granted, free of charge, to any person obtaining a copy */
10+
/* of this software and associated documentation files (the "Software"), to deal */
11+
/* in the Software without restriction, including without limitation the rights */
12+
/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */
13+
/* copies of the Software, and to permit persons to whom the Software is */
14+
/* furnished to do so, subject to the following conditions: */
15+
/* */
16+
/* The above copyright notice and this permission notice shall be included in all */
17+
/* copies or substantial portions of the Software. */
18+
/* */
19+
/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */
20+
/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */
21+
/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */
22+
/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */
23+
/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */
24+
/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */
25+
/* SOFTWARE. */
26+
/**********************************************************************************/
27+
28+
#ifndef SPLA_CL_M_EXTRACT_ROW_HPP
29+
#define SPLA_CL_M_EXTRACT_ROW_HPP
30+
31+
#include <schedule/schedule_tasks.hpp>
32+
33+
#include <core/dispatcher.hpp>
34+
#include <core/registry.hpp>
35+
#include <core/tmatrix.hpp>
36+
#include <core/top.hpp>
37+
#include <core/tscalar.hpp>
38+
#include <core/ttype.hpp>
39+
#include <core/tvector.hpp>
40+
41+
#include <opencl/cl_counter.hpp>
42+
#include <opencl/cl_debug.hpp>
43+
#include <opencl/cl_formats.hpp>
44+
#include <opencl/cl_program_builder.hpp>
45+
#include <opencl/generated/auto_m_extract_row.hpp>
46+
47+
namespace spla {
48+
49+
template<typename T>
50+
class Algo_m_extract_row_cl final : public RegistryAlgo {
51+
public:
52+
~Algo_m_extract_row_cl() override = default;
53+
54+
std::string get_name() override {
55+
return "m_extract_row";
56+
}
57+
58+
std::string get_description() override {
59+
return "opencl extract row from matrix";
60+
}
61+
62+
Status execute(const DispatchContext& ctx) override {
63+
auto t = ctx.task.template cast_safe<ScheduleTask_m_extract_row>();
64+
65+
ref_ptr<TVector<T>> r = t->r.template cast_safe<TVector<T>>();
66+
ref_ptr<TMatrix<T>> M = t->M.template cast_safe<TMatrix<T>>();
67+
auto op_apply = t->op_apply.template cast_safe<TOpUnary<T, T>>();
68+
69+
r->validate_wd(FormatVector::AccDense);
70+
M->validate_rw(FormatMatrix::AccCsr);
71+
72+
auto* p_cl_r = r->template get<CLDenseVec<T>>();
73+
auto* p_cl_M = M->template get<CLCsr<T>>();
74+
auto* p_cl_acc = get_acc_cl();
75+
auto& queue = p_cl_acc->get_queue_default();
76+
77+
// get the row boundaries from M->Ap
78+
uint row_bounds[2];
79+
cl::Buffer cl_row_bounds(p_cl_acc->get_context(),
80+
CL_MEM_READ_ONLY | CL_MEM_HOST_READ_ONLY | CL_MEM_USE_HOST_PTR,
81+
sizeof(row_bounds), row_bounds);
82+
83+
queue.enqueueCopyBuffer(p_cl_M->Ap, cl_row_bounds, t->index * sizeof(uint), 0, sizeof(row_bounds));
84+
queue.finish();
85+
86+
std::shared_ptr<CLProgram> program;
87+
ensure_kernel(op_apply, program);
88+
89+
auto kernel = program->make_kernel("extract_row");
90+
kernel.setArg(0, p_cl_r->Ax);
91+
kernel.setArg(1, p_cl_M->Ax);
92+
kernel.setArg(2, p_cl_M->Aj);
93+
kernel.setArg(3, row_bounds[1]);
94+
95+
// amount of elements in the row
96+
const uint n = row_bounds[1] - row_bounds[0] - 1;
97+
98+
cl::NDRange global(p_cl_acc->get_default_wgs() * div_up_clamp(n, p_cl_acc->get_default_wgs(), 1u, 1024u));
99+
cl::NDRange local(p_cl_acc->get_default_wgs());
100+
queue.enqueueNDRangeKernel(kernel, cl::NDRange(row_bounds[0]), global, local);
101+
102+
return Status::Ok;
103+
}
104+
105+
private:
106+
void ensure_kernel(const ref_ptr<TOpUnary<T, T>>& op_apply, std::shared_ptr<CLProgram>& program) {
107+
CLProgramBuilder program_builder;
108+
program_builder
109+
.set_name("m_extract_row")
110+
.add_type("TYPE", get_ttype<T>().template as<Type>())
111+
.add_op("OP_APPLY", op_apply.template as<OpUnary>())
112+
.set_source(source_m_extract_row)
113+
.acquire();
114+
program = program_builder.get_program();
115+
}
116+
};
117+
118+
}// namespace spla
119+
120+
#endif//SPLA_CL_M_EXTRACT_ROW_HPP
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
////////////////////////////////////////////////////////////////////
2+
// Copyright (c) 2021 - 2025 SparseLinearAlgebra
3+
// Autogenerated file, do not modify
4+
////////////////////////////////////////////////////////////////////
5+
6+
#pragma once
7+
8+
static const char source_m_extract_row[] = R"(
9+
10+
11+
__kernel void extract_row(__global TYPE* g_rx,
12+
__global const TYPE* g_Ax,
13+
__global const uint* g_Aj,
14+
const uint n) {
15+
const uint gid = get_global_id(0);
16+
const uint gsize = get_global_size(0);
17+
18+
for (uint i = gid; i < n; i += gsize) {
19+
g_rx[g_Aj[i]] = OP_APPLY(g_Ax[i]);
20+
}
21+
}
22+
23+
)";
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/**********************************************************************************/
2+
/* This file is part of spla project */
3+
/* https://github.com/SparseLinearAlgebra/spla */
4+
/**********************************************************************************/
5+
/* MIT License */
6+
/* */
7+
/* Copyright (c) 2025 SparseLinearAlgebra */
8+
/* */
9+
/* Permission is hereby granted, free of charge, to any person obtaining a copy */
10+
/* of this software and associated documentation files (the "Software"), to deal */
11+
/* in the Software without restriction, including without limitation the rights */
12+
/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */
13+
/* copies of the Software, and to permit persons to whom the Software is */
14+
/* furnished to do so, subject to the following conditions: */
15+
/* */
16+
/* The above copyright notice and this permission notice shall be included in all */
17+
/* copies or substantial portions of the Software. */
18+
/* */
19+
/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */
20+
/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */
21+
/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */
22+
/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */
23+
/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */
24+
/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */
25+
/* SOFTWARE. */
26+
/**********************************************************************************/
27+
28+
#include "common_def.cl"
29+
30+
__kernel void extract_row(__global TYPE* g_rx,
31+
__global const TYPE* g_Ax,
32+
__global const uint* g_Aj,
33+
const uint n) {
34+
const uint gid = get_global_id(0);
35+
const uint gsize = get_global_size(0);
36+
37+
for (uint i = gid; i < n; i += gsize) {
38+
g_rx[g_Aj[i]] = OP_APPLY(g_Ax[i]);
39+
}
40+
}

0 commit comments

Comments
 (0)