Skip to content

Commit 3f3fa3d

Browse files
rok-cesnovarstan-buildbotyashiknot4c1
authored
Feature/1854 OpenCL /prim signatures part 1 (#1859)
* added dims * finished add, rows, cols * scalar + matrix * simplify with tadej's suggestion * add col and tests and add the block boundary checks * add row and row tests * inv_sqrt done * inv_square * inv logit * add inv_cloglog * [Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.04.1 (tags/RELEASE_600/final) * add inv() * [Jenkins] auto-formatting by clang-format version 6.0.0 * fix doxygen, cpplint * add header guards * revising the block checks, cpplint * [Jenkins] auto-formatting by clang-format version 6.0.0 * missing ifdef * newline * fix comments and use of x instead of a * expression is now kernel_expression * fix row/col * fix comments, change throw behaviour of block * fix block test * add row/col tests with expressions * [Jenkins] auto-formatting by clang-format version 6.0.0 * rename require * [Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2 (tags/RELEASE_600/final) * rename require * [Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2 (tags/RELEASE_600/final) * fix tests * add forward * is_valid_kernel_expression to is_kernel_expression * fix inv_square and inv_logit * cleanup after merge * replace with EXPECT_MATRIX_NEAR * [Jenkins] auto-formatting by clang-format version 6.0.0 * remove duplicate include * Update stan/math/opencl/prim/cols.hpp Co-authored-by: Tadej Ciglarič <tadej.c@gmail.com> * Update stan/math/opencl/prim/dims.hpp Co-authored-by: Tadej Ciglarič <tadej.c@gmail.com> * Update stan/math/opencl/prim/rows.hpp Co-authored-by: Tadej Ciglarič <tadej.c@gmail.com> * removed comment * fix rows/cols template Co-authored-by: Stan Jenkins <mc.stanislaw@gmail.com> Co-authored-by: Jenkins <nobody@nowhere> Co-authored-by: Tadej Ciglarič <tadej.c@gmail.com>
1 parent fdf7f70 commit 3f3fa3d

54 files changed

Lines changed: 1149 additions & 215 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

stan/math/opencl/kernel_generator.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@
109109
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>
110110
#include <stan/math/opencl/kernel_generator/operation_cl_lhs.hpp>
111111
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
112-
#include <stan/math/opencl/kernel_generator/is_valid_expression.hpp>
112+
#include <stan/math/opencl/kernel_generator/is_kernel_expression.hpp>
113113
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
114114
#include <stan/math/opencl/kernel_generator/type_str.hpp>
115115

stan/math/opencl/kernel_generator/append.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>
1111
#include <stan/math/opencl/kernel_generator/scalar.hpp>
1212
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
13-
#include <stan/math/opencl/kernel_generator/is_valid_expression.hpp>
13+
#include <stan/math/opencl/kernel_generator/is_kernel_expression.hpp>
1414
#include <stan/math/opencl/kernel_generator/common_return_scalar.hpp>
1515
#include <algorithm>
1616
#include <string>
@@ -162,7 +162,7 @@ class append_row_ : public operation_cl<append_row_<T_a, T_b>,
162162
* @return Stacked arguments
163163
*/
164164
template <typename Ta, typename Tb,
165-
typename = require_all_valid_expressions_and_none_scalar_t<Ta, Tb>>
165+
typename = require_all_kernel_expressions_and_none_scalar_t<Ta, Tb>>
166166
inline auto append_row(Ta&& a, Tb&& b) {
167167
auto&& a_operation = as_operation_cl(std::forward<Ta>(a)).deep_copy();
168168
auto&& b_operation = as_operation_cl(std::forward<Tb>(b)).deep_copy();
@@ -307,7 +307,7 @@ class append_col_ : public operation_cl<append_col_<T_a, T_b>,
307307
* @return Stacked arguments
308308
*/
309309
template <typename Ta, typename Tb,
310-
typename = require_all_valid_expressions_and_none_scalar_t<Ta, Tb>>
310+
typename = require_all_kernel_expressions_and_none_scalar_t<Ta, Tb>>
311311
inline auto append_col(Ta&& a, Tb&& b) {
312312
auto&& a_operation = as_operation_cl(std::forward<Ta>(a)).deep_copy();
313313
auto&& b_operation = as_operation_cl(std::forward<Tb>(b)).deep_copy();

stan/math/opencl/kernel_generator/binary_operation.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>
1111
#include <stan/math/opencl/kernel_generator/scalar.hpp>
1212
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
13-
#include <stan/math/opencl/kernel_generator/is_valid_expression.hpp>
13+
#include <stan/math/opencl/kernel_generator/is_kernel_expression.hpp>
1414
#include <stan/math/opencl/kernel_generator/common_return_scalar.hpp>
1515
#include <algorithm>
1616
#include <string>
@@ -126,7 +126,7 @@ class binary_operation : public operation_cl<Derived, T_res, T_a, T_b> {
126126
}; \
127127
\
128128
template <typename T_a, typename T_b, \
129-
typename = require_all_valid_expressions_t<T_a, T_b>> \
129+
typename = require_all_kernel_expressions_t<T_a, T_b>> \
130130
inline class_name<as_operation_cl_t<T_a>, as_operation_cl_t<T_b>> \
131131
function_name(T_a&& a, T_b&& b) { /* NOLINT */ \
132132
return {as_operation_cl(std::forward<T_a>(a)), \
@@ -176,7 +176,7 @@ class binary_operation : public operation_cl<Derived, T_res, T_a, T_b> {
176176
}; \
177177
\
178178
template <typename T_a, typename T_b, \
179-
typename = require_all_valid_expressions_t<T_a, T_b>> \
179+
typename = require_all_kernel_expressions_t<T_a, T_b>> \
180180
inline class_name<as_operation_cl_t<T_a>, as_operation_cl_t<T_b>> \
181181
function_name(T_a&& a, T_b&& b) { /* NOLINT */ \
182182
return {as_operation_cl(std::forward<T_a>(a)), \
@@ -245,7 +245,7 @@ ADD_BINARY_OPERATION_WITH_CUSTOM_CODE(
245245
* @return Multiplication of given arguments
246246
*/
247247
template <typename T_a, typename T_b, typename = require_arithmetic_t<T_a>,
248-
typename = require_all_valid_expressions_t<T_b>>
248+
typename = require_all_kernel_expressions_t<T_b>>
249249
inline elewise_multiplication_<scalar_<T_a>, as_operation_cl_t<T_b>> operator*(
250250
T_a&& a, T_b&& b) { // NOLINT
251251
return {as_operation_cl(std::forward<T_a>(a)),
@@ -261,7 +261,7 @@ inline elewise_multiplication_<scalar_<T_a>, as_operation_cl_t<T_b>> operator*(
261261
* @return Multiplication of given arguments
262262
*/
263263
template <typename T_a, typename T_b,
264-
typename = require_all_valid_expressions_t<T_a>,
264+
typename = require_all_kernel_expressions_t<T_a>,
265265
typename = require_arithmetic_t<T_b>>
266266
inline elewise_multiplication_<as_operation_cl_t<T_a>, scalar_<T_b>> operator*(
267267
T_a&& a, const T_b b) { // NOLINT

stan/math/opencl/kernel_generator/block.hpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
1010
#include <stan/math/opencl/kernel_generator/operation_cl_lhs.hpp>
1111
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
12-
#include <stan/math/opencl/kernel_generator/is_valid_expression.hpp>
12+
#include <stan/math/opencl/kernel_generator/is_kernel_expression.hpp>
1313
#include <set>
1414
#include <string>
1515
#include <tuple>
@@ -57,9 +57,25 @@ class block_
5757
start_col_(start_col),
5858
rows_(rows),
5959
cols_(cols) {
60+
if (start_col < 0) {
61+
invalid_argument("block", "start_col", start_col,
62+
" should be non-negative, but is ");
63+
}
64+
if (start_row < 0) {
65+
invalid_argument("block", "start_row", start_row,
66+
" should be non-negative, but is ");
67+
}
68+
if (rows < 0) {
69+
invalid_argument("block", "rows", rows,
70+
" should be non-negative, but is ");
71+
}
72+
if (cols < 0) {
73+
invalid_argument("block", "cols", cols,
74+
" should be non-negative, but is ");
75+
}
6076
if ((a.rows() != base::dynamic && (start_row + rows) > a.rows())
6177
|| (a.cols() != base::dynamic && (start_col + cols) > a.cols())) {
62-
throw_domain_error("block", "block of \"a\"", " is out of bounds", "");
78+
invalid_argument("block", "block of \"a\"", " is out of bounds", "");
6379
}
6480
}
6581

@@ -216,7 +232,7 @@ class block_
216232
* @return Block of given expression
217233
*/
218234
template <typename T,
219-
typename = require_all_valid_expressions_and_none_scalar_t<T>>
235+
typename = require_all_kernel_expressions_and_none_scalar_t<T>>
220236
inline auto block(T&& a, int start_row, int start_col, int rows, int cols) {
221237
auto&& a_operation = as_operation_cl(std::forward<T>(a)).deep_copy();
222238
return block_<std::remove_reference_t<decltype(a_operation)>>(

stan/math/opencl/kernel_generator/broadcast.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
99
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>
1010
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
11-
#include <stan/math/opencl/kernel_generator/is_valid_expression.hpp>
11+
#include <stan/math/opencl/kernel_generator/is_kernel_expression.hpp>
1212
#include <limits>
1313
#include <string>
1414
#include <type_traits>
@@ -138,7 +138,7 @@ class broadcast_
138138
* @return broadcast expression
139139
*/
140140
template <bool Colwise, bool Rowwise, typename T,
141-
typename = require_all_valid_expressions_and_none_scalar_t<T>>
141+
typename = require_all_kernel_expressions_and_none_scalar_t<T>>
142142
inline broadcast_<as_operation_cl_t<T>, Colwise, Rowwise> broadcast(T&& a) {
143143
auto&& a_operation = as_operation_cl(std::forward<T>(a)).deep_copy();
144144
return broadcast_<as_operation_cl_t<T>, Colwise, Rowwise>(
@@ -158,7 +158,7 @@ inline broadcast_<as_operation_cl_t<T>, Colwise, Rowwise> broadcast(T&& a) {
158158
* @return broadcast expression
159159
*/
160160
template <typename T,
161-
typename = require_all_valid_expressions_and_none_scalar_t<T>>
161+
typename = require_all_kernel_expressions_and_none_scalar_t<T>>
162162
inline auto rowwise_broadcast(T&& a) {
163163
return broadcast<false, true>(std::forward<T>(a));
164164
}
@@ -176,7 +176,7 @@ inline auto rowwise_broadcast(T&& a) {
176176
* @return broadcast expression
177177
*/
178178
template <typename T,
179-
typename = require_all_valid_expressions_and_none_scalar_t<T>>
179+
typename = require_all_kernel_expressions_and_none_scalar_t<T>>
180180
inline auto colwise_broadcast(T&& a) {
181181
return broadcast<true, false>(std::forward<T>(a));
182182
}

stan/math/opencl/kernel_generator/calc_if.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
99
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>
1010
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
11-
#include <stan/math/opencl/kernel_generator/is_valid_expression.hpp>
11+
#include <stan/math/opencl/kernel_generator/is_kernel_expression.hpp>
1212
#include <string>
1313
#include <type_traits>
1414
#include <set>
@@ -91,7 +91,7 @@ class calc_if_
9191
};
9292

9393
template <bool Do_Calculate, typename T,
94-
typename = require_all_valid_expressions_and_none_scalar_t<T>>
94+
typename = require_all_kernel_expressions_and_none_scalar_t<T>>
9595
inline calc_if_<Do_Calculate, as_operation_cl_t<T>> calc_if(T&& a) {
9696
return calc_if_<Do_Calculate, as_operation_cl_t<T>>(
9797
as_operation_cl(std::forward<T>(a)));

stan/math/opencl/kernel_generator/colwise_reduction.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
1010
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>
1111
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
12-
#include <stan/math/opencl/kernel_generator/is_valid_expression.hpp>
12+
#include <stan/math/opencl/kernel_generator/is_kernel_expression.hpp>
1313
#include <stan/math/opencl/kernel_generator/rowwise_reduction.hpp>
1414
#include <set>
1515
#include <string>
@@ -181,7 +181,7 @@ class colwise_sum_ : public colwise_reduction<colwise_sum_<T>, T, sum_op> {
181181
* @return sum
182182
*/
183183
template <typename T,
184-
typename = require_all_valid_expressions_and_none_scalar_t<T>>
184+
typename = require_all_kernel_expressions_and_none_scalar_t<T>>
185185
inline auto colwise_sum(T&& a) {
186186
auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
187187
return colwise_sum_<std::remove_reference_t<decltype(arg_copy)>>(
@@ -229,7 +229,7 @@ class colwise_max_ : public colwise_reduction<
229229
* @return max
230230
*/
231231
template <typename T,
232-
typename = require_all_valid_expressions_and_none_scalar_t<T>>
232+
typename = require_all_kernel_expressions_and_none_scalar_t<T>>
233233
inline auto colwise_max(T&& a) {
234234
auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
235235
return colwise_max_<std::remove_reference_t<decltype(arg_copy)>>(
@@ -277,7 +277,7 @@ class colwise_min_ : public colwise_reduction<
277277
* @return min
278278
*/
279279
template <typename T,
280-
typename = require_all_valid_expressions_and_none_scalar_t<T>>
280+
typename = require_all_kernel_expressions_and_none_scalar_t<T>>
281281
inline auto colwise_min(T&& a) {
282282
auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
283283
return colwise_min_<std::remove_reference_t<decltype(arg_copy)>>(

stan/math/opencl/kernel_generator/diagonal.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
99
#include <stan/math/opencl/kernel_generator/operation_cl_lhs.hpp>
1010
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
11-
#include <stan/math/opencl/kernel_generator/is_valid_expression.hpp>
11+
#include <stan/math/opencl/kernel_generator/is_kernel_expression.hpp>
1212
#include <algorithm>
1313
#include <set>
1414
#include <string>
@@ -166,7 +166,7 @@ class diagonal_
166166
* @return Diagonal of given expression
167167
*/
168168
template <typename T,
169-
typename = require_all_valid_expressions_and_none_scalar_t<T>>
169+
typename = require_all_kernel_expressions_and_none_scalar_t<T>>
170170
inline auto diagonal(T&& a) {
171171
auto&& a_operation = as_operation_cl(std::forward<T>(a)).deep_copy();
172172
return diagonal_<std::remove_reference_t<decltype(a_operation)>>(

stan/math/opencl/kernel_generator/evaluate_into.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include <stan/math/prim/err.hpp>
66
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>
77
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
8-
#include <stan/math/opencl/kernel_generator/is_valid_expression.hpp>
8+
#include <stan/math/opencl/kernel_generator/is_kernel_expression.hpp>
99
#include <stan/math/opencl/kernel_generator/multi_result_kernel.hpp>
1010
#include <CL/cl2.hpp>
1111
#include <string>
@@ -21,7 +21,7 @@ template <typename Derived, typename Scalar, typename... Args>
2121
template <typename T_lhs>
2222
void operation_cl<Derived, Scalar, Args...>::evaluate_into(T_lhs& lhs) const {
2323
static_assert(
24-
is_valid_expression<T_lhs>::value,
24+
is_kernel_expression<T_lhs>::value,
2525
"operation_cl::evaluate_into: left hand side is not a valid expression!");
2626
results(lhs) = expressions(derived());
2727
}

stan/math/opencl/kernel_generator/get_kernel_source_for_evaluating_into.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ std::string
2121
operation_cl<Derived, Scalar, Args...>::get_kernel_source_for_evaluating_into(
2222
const T_lhs& lhs) const {
2323
static_assert(
24-
is_valid_expression<T_lhs>::value,
24+
is_kernel_expression<T_lhs>::value,
2525
"operation_cl::get_kernel_source_for_evaluating_into: left hand "
2626
"side is not a valid expression!");
2727
return results(lhs).get_kernel_source_for_evaluating(expressions(derived()));

0 commit comments

Comments
 (0)