Skip to content

Commit bf46e93

Browse files
authored
Add matrix concatenation operations to kernel generator (#1867)
* implemented append_row and append_col * fixed doxygen
1 parent 1ad25a0 commit bf46e93

3 files changed

Lines changed: 516 additions & 0 deletions

File tree

stan/math/opencl/kernel_generator.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@
115115

116116
#include <stan/math/opencl/kernel_generator/load.hpp>
117117
#include <stan/math/opencl/kernel_generator/scalar.hpp>
118+
#include <stan/math/opencl/kernel_generator/append.hpp>
118119
#include <stan/math/opencl/kernel_generator/binary_operation.hpp>
119120
#include <stan/math/opencl/kernel_generator/unary_function_cl.hpp>
120121
#include <stan/math/opencl/kernel_generator/unary_operation_cl.hpp>
Lines changed: 323 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,323 @@
1+
#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_APPEND_HPP
2+
#define STAN_MATH_OPENCL_KERNEL_GENERATOR_APPEND_HPP
3+
#ifdef STAN_OPENCL
4+
5+
#include <stan/math/opencl/matrix_cl_view.hpp>
6+
#include <stan/math/opencl/err.hpp>
7+
#include <stan/math/prim/meta.hpp>
8+
#include <stan/math/opencl/kernel_generator/type_str.hpp>
9+
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
10+
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>
11+
#include <stan/math/opencl/kernel_generator/scalar.hpp>
12+
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
13+
#include <stan/math/opencl/kernel_generator/is_valid_expression.hpp>
14+
#include <stan/math/opencl/kernel_generator/common_return_scalar.hpp>
15+
#include <algorithm>
16+
#include <string>
17+
#include <tuple>
18+
#include <type_traits>
19+
#include <set>
20+
#include <utility>
21+
22+
namespace stan {
23+
namespace math {
24+
25+
/** \addtogroup opencl_kernel_generator
26+
* @{
27+
*/
28+
29+
/**
30+
* Represents appending of rows in kernel generator expressions.
31+
* @tparam T_a type of first argument
32+
* @tparam T_b type of second argument
33+
*/
34+
template <typename T_a, typename T_b>
35+
class append_row_ : public operation_cl<append_row_<T_a, T_b>,
36+
common_scalar_t<T_a, T_b>, T_a, T_b> {
37+
public:
38+
using Scalar = common_scalar_t<T_a, T_b>;
39+
using base = operation_cl<append_row_<T_a, T_b>, Scalar, T_a, T_b>;
40+
using base::var_name;
41+
42+
protected:
43+
using base::arguments_;
44+
45+
public:
46+
/**
47+
* Constructor
48+
* @param a first argument
49+
* @param b second argument
50+
*/
51+
append_row_(T_a&& a, T_b&& b) // NOLINT
52+
: base(std::forward<T_a>(a), std::forward<T_b>(b)) {
53+
if (a.cols() != base::dynamic && b.cols() != base::dynamic) {
54+
check_size_match("append_row", "Columns of ", "a", a.cols(),
55+
"columns of ", "b", b.cols());
56+
}
57+
if (a.rows() < 0) {
58+
invalid_argument("append_row", "Rows of a", a.rows(),
59+
"should be nonnegative!");
60+
}
61+
if (b.rows() < 0) {
62+
invalid_argument("append_row", "Rows of b", b.rows(),
63+
"should be nonnegative!");
64+
}
65+
}
66+
67+
/**
68+
* Creates a deep copy of this expression.
69+
* @return copy of \c *this
70+
*/
71+
inline auto deep_copy() const {
72+
auto&& a_copy = this->template get_arg<0>().deep_copy();
73+
auto&& b_copy = this->template get_arg<1>().deep_copy();
74+
return append_row_<std::remove_reference_t<decltype(a_copy)>,
75+
std::remove_reference_t<decltype(b_copy)>>{
76+
std::move(a_copy), std::move(b_copy)};
77+
}
78+
79+
/**
80+
* Generates kernel code for this and nested expressions.
81+
* @param[in,out] generated set of (pointer to) already generated operations
82+
* @param name_gen name generator for this kernel
83+
* @param i row index variable name
84+
* @param j column index variable name
85+
* @param view_handled whether caller already handled matrix view
86+
* @return part of kernel with code for this and nested expressions
87+
*/
88+
inline kernel_parts get_kernel_parts(
89+
std::set<const operation_cl_base*>& generated, name_generator& name_gen,
90+
const std::string& i, const std::string& j, bool view_handled) const {
91+
kernel_parts res{};
92+
if (generated.count(this) == 0) {
93+
var_name = name_gen.generate();
94+
generated.insert(this);
95+
std::string i_b = "(" + i + " - " + var_name + "_first_rows)";
96+
kernel_parts parts_a = this->template get_arg<0>().get_kernel_parts(
97+
generated, name_gen, i, j, true);
98+
kernel_parts parts_b = this->template get_arg<1>().get_kernel_parts(
99+
generated, name_gen, i_b, j, true);
100+
res = parts_a + parts_b;
101+
res.body = type_str<Scalar>() + " " + var_name + ";\n"
102+
"if("+ i +" < " + var_name + "_first_rows){\n"
103+
+ parts_a.body +
104+
var_name + " = " + this->template get_arg<0>().var_name + ";\n"
105+
"} else{\n"
106+
+ parts_b.body +
107+
var_name + " = " + this->template get_arg<1>().var_name + ";\n"
108+
"}\n";
109+
res.args += "int " + var_name + "_first_rows, ";
110+
}
111+
return res;
112+
}
113+
114+
/**
115+
* Sets kernel arguments for this and nested expressions.
116+
* @param[in,out] generated set of expressions that already set their kernel
117+
* arguments
118+
* @param kernel kernel to set arguments on
119+
* @param[in,out] arg_num consecutive number of the first argument to set.
120+
* This is incremented for each argument set by this function.
121+
*/
122+
inline void set_args(std::set<const operation_cl_base*>& generated,
123+
cl::Kernel& kernel, int& arg_num) const {
124+
if (generated.count(this) == 0) {
125+
generated.insert(this);
126+
this->template get_arg<0>().set_args(generated, kernel, arg_num);
127+
this->template get_arg<1>().set_args(generated, kernel, arg_num);
128+
kernel.setArg(arg_num++, this->template get_arg<0>().rows());
129+
}
130+
}
131+
132+
/**
133+
* Number of rows of a matrix that would be the result of evaluating this
134+
* expression.
135+
* @return number of rows
136+
*/
137+
inline int rows() const {
138+
return this->template get_arg<0>().rows()
139+
+ this->template get_arg<1>().rows();
140+
}
141+
142+
/**
143+
* Determine indices of extreme sub- and superdiagonals written.
144+
* @return pair of indices - bottom and top diagonal
145+
*/
146+
inline std::pair<int, int> extreme_diagonals() const {
147+
std::pair<int, int> a_diags
148+
= this->template get_arg<0>().extreme_diagonals();
149+
std::pair<int, int> b_diags
150+
= this->template get_arg<1>().extreme_diagonals();
151+
return {b_diags.first - this->template get_arg<0>().rows(), a_diags.second};
152+
}
153+
};
154+
155+
/**
156+
* Stack the rows of the first argument on top of the second argument.
157+
*
158+
* @tparam Ta type of first argument
159+
* @tparam Ta type of second argument
160+
* @param a First argument
161+
* @param b Second argument
162+
* @return Stacked arguments
163+
*/
164+
template <typename Ta, typename Tb,
165+
typename = require_all_valid_expressions_and_none_scalar_t<Ta, Tb>>
166+
inline auto append_row(Ta&& a, Tb&& b) {
167+
auto&& a_operation = as_operation_cl(std::forward<Ta>(a)).deep_copy();
168+
auto&& b_operation = as_operation_cl(std::forward<Tb>(b)).deep_copy();
169+
return append_row_<std::remove_reference_t<decltype(a_operation)>,
170+
std::remove_reference_t<decltype(b_operation)>>(
171+
std::move(a_operation), std::move(b_operation));
172+
}
173+
174+
/**
175+
* Represents appending of cols in kernel generator expressions.
176+
* @tparam T_a type of first argument
177+
* @tparam T_b type of second argument
178+
*/
179+
template <typename T_a, typename T_b>
180+
class append_col_ : public operation_cl<append_col_<T_a, T_b>,
181+
common_scalar_t<T_a, T_b>, T_a, T_b> {
182+
public:
183+
using Scalar = common_scalar_t<T_a, T_b>;
184+
using base = operation_cl<append_col_<T_a, T_b>, Scalar, T_a, T_b>;
185+
using base::var_name;
186+
187+
protected:
188+
using base::arguments_;
189+
190+
public:
191+
/**
192+
* Constructor
193+
* @param a first argument
194+
* @param b second argument
195+
*/
196+
append_col_(T_a&& a, T_b&& b) // NOLINT
197+
: base(std::forward<T_a>(a), std::forward<T_b>(b)) {
198+
if (a.rows() != base::dynamic && b.rows() != base::dynamic) {
199+
check_size_match("append_col", "Rows of ", "a", a.rows(), "rows of ", "b",
200+
b.rows());
201+
}
202+
if (a.cols() < 0) {
203+
invalid_argument("append_col", "Columns of a", a.cols(),
204+
"should be nonnegative!");
205+
}
206+
if (b.cols() < 0) {
207+
invalid_argument("append_col", "Columns of b", b.cols(),
208+
"should be nonnegative!");
209+
}
210+
}
211+
212+
/**
213+
* Creates a deep copy of this expression.
214+
* @return copy of \c *this
215+
*/
216+
inline auto deep_copy() const {
217+
auto&& a_copy = this->template get_arg<0>().deep_copy();
218+
auto&& b_copy = this->template get_arg<1>().deep_copy();
219+
return append_col_<std::remove_reference_t<decltype(a_copy)>,
220+
std::remove_reference_t<decltype(b_copy)>>{
221+
std::move(a_copy), std::move(b_copy)};
222+
}
223+
224+
/**
225+
* Generates kernel code for this and nested expressions.
226+
* @param[in,out] generated set of (pointer to) already generated operations
227+
* @param name_gen name generator for this kernel
228+
* @param i row index variable name
229+
* @param j column index variable name
230+
* @param view_handled whether caller already handled matrix view
231+
* @return part of kernel with code for this and nested expressions
232+
*/
233+
inline kernel_parts get_kernel_parts(
234+
std::set<const operation_cl_base*>& generated, name_generator& name_gen,
235+
const std::string& i, const std::string& j, bool view_handled) const {
236+
kernel_parts res{};
237+
if (generated.count(this) == 0) {
238+
var_name = name_gen.generate();
239+
generated.insert(this);
240+
std::string j_b = "(" + j + " - " + var_name + "_first_cols)";
241+
kernel_parts parts_a = this->template get_arg<0>().get_kernel_parts(
242+
generated, name_gen, i, j, true);
243+
kernel_parts parts_b = this->template get_arg<1>().get_kernel_parts(
244+
generated, name_gen, i, j_b, true);
245+
res = parts_a + parts_b;
246+
res.body = type_str<Scalar>() + " " + var_name + ";\n"
247+
"if("+ j +" < " + var_name + "_first_cols){\n"
248+
+ parts_a.body +
249+
var_name + " = " + this->template get_arg<0>().var_name + ";\n"
250+
"} else{\n"
251+
+ parts_b.body +
252+
var_name + " = " + this->template get_arg<1>().var_name + ";\n"
253+
"}\n";
254+
res.args += "int " + var_name + "_first_cols, ";
255+
}
256+
return res;
257+
}
258+
259+
/**
260+
* Sets kernel arguments for this and nested expressions.
261+
* @param[in,out] generated set of expressions that already set their kernel
262+
* arguments
263+
* @param kernel kernel to set arguments on
264+
* @param[in,out] arg_num consecutive number of the first argument to set.
265+
* This is incremented for each argument set by this function.
266+
*/
267+
inline void set_args(std::set<const operation_cl_base*>& generated,
268+
cl::Kernel& kernel, int& arg_num) const {
269+
if (generated.count(this) == 0) {
270+
generated.insert(this);
271+
this->template get_arg<0>().set_args(generated, kernel, arg_num);
272+
this->template get_arg<1>().set_args(generated, kernel, arg_num);
273+
kernel.setArg(arg_num++, this->template get_arg<0>().cols());
274+
}
275+
}
276+
277+
/**
278+
* Number of rows of a matrix that would be the result of evaluating this
279+
* expression.
280+
* @return number of rows
281+
*/
282+
inline int cols() const {
283+
return this->template get_arg<0>().cols()
284+
+ this->template get_arg<1>().cols();
285+
}
286+
287+
/**
288+
* Determine indices of extreme sub- and superdiagonals written.
289+
* @return pair of indices - bottom and top diagonal
290+
*/
291+
inline std::pair<int, int> extreme_diagonals() const {
292+
std::pair<int, int> a_diags
293+
= this->template get_arg<0>().extreme_diagonals();
294+
std::pair<int, int> b_diags
295+
= this->template get_arg<1>().extreme_diagonals();
296+
return {a_diags.first, b_diags.second - this->template get_arg<0>().cols()};
297+
}
298+
};
299+
300+
/**
301+
* Stack the cols of the arguments.
302+
*
303+
* @tparam Ta type of first argument
304+
* @tparam Ta type of second argument
305+
* @param a First argument
306+
* @param b Second argument
307+
* @return Stacked arguments
308+
*/
309+
template <typename Ta, typename Tb,
310+
typename = require_all_valid_expressions_and_none_scalar_t<Ta, Tb>>
311+
inline auto append_col(Ta&& a, Tb&& b) {
312+
auto&& a_operation = as_operation_cl(std::forward<Ta>(a)).deep_copy();
313+
auto&& b_operation = as_operation_cl(std::forward<Tb>(b)).deep_copy();
314+
return append_col_<std::remove_reference_t<decltype(a_operation)>,
315+
std::remove_reference_t<decltype(b_operation)>>(
316+
std::move(a_operation), std::move(b_operation));
317+
}
318+
319+
} // namespace math
320+
} // namespace stan
321+
322+
#endif
323+
#endif

0 commit comments

Comments
 (0)