Skip to content

Commit 17a67ad

Browse files
rok-cesnovarstan-buildbotyashiknot4c1
authored
Feature/1854 OpenCL /prim signatures part 2 (#1869)
* added dims * finished add, rows, cols * scalar + matrix * simplify with tadej's suggestion * add col and tests and add the block boundary checks * add row and row tests * inv_sqrt done * inv_square * inv logit * add inv_cloglog * [Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.04.1 (tags/RELEASE_600/final) * add inv() * [Jenkins] auto-formatting by clang-format version 6.0.0 * fix doxygen, cpplint * add header guards * revising the block checks, cpplint * [Jenkins] auto-formatting by clang-format version 6.0.0 * missing ifdef * newline * fix comments and use of x instead of a * expression is now kernel_expression * fix row/col * fix comments, change throw behaviour of block * fix block test * add row/col tests with expressions * [Jenkins] auto-formatting by clang-format version 6.0.0 * rename require * [Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2 (tags/RELEASE_600/final) * rename require * rename require * [Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2 (tags/RELEASE_600/final) * fix tests * add forward * is_valid_kernel_expression to is_kernel_expression * fix inv_square and inv_logit * cleanup after merge * replace with EXPECT_MATRIX_NEAR * [Jenkins] auto-formatting by clang-format version 6.0.0 * remove duplicate include * Update stan/math/opencl/prim/cols.hpp Co-authored-by: Tadej Ciglarič <tadej.c@gmail.com> * Update stan/math/opencl/prim/dims.hpp Co-authored-by: Tadej Ciglarič <tadej.c@gmail.com> * Update stan/math/opencl/prim/rows.hpp Co-authored-by: Tadej Ciglarič <tadej.c@gmail.com> * removed comment * fix rows/cols template * rename elewise functions * added crossprod and tcrossprod * adde fabs * use T_A of (t)crossprod * added logit * added log1m_inv_logit * add divide * add trunc, fix names of tests * [Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.04.1 (tags/RELEASE_600/final) * newlines * fix merge * missing includes * address review comments * [Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.04.1 (tags/RELEASE_600/final) * Apply suggestions from code review Co-authored-by: Tadej Ciglarič <tadej.c@gmail.com> Co-authored-by: Stan Jenkins <mc.stanislaw@gmail.com> Co-authored-by: Jenkins <nobody@nowhere> Co-authored-by: Tadej Ciglarič <tadej.c@gmail.com>
1 parent ae09118 commit 17a67ad

19 files changed

Lines changed: 399 additions & 21 deletions

stan/math/opencl/kernel_generator/binary_operation.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,7 @@ ADD_BINARY_OPERATION(addition_, operator+, common_scalar_t<T_a COMMA T_b>, "+");
187187
ADD_BINARY_OPERATION(subtraction_, operator-, common_scalar_t<T_a COMMA T_b>,
188188
"-");
189189
ADD_BINARY_OPERATION_WITH_CUSTOM_CODE(
190-
elewise_multiplication_, elewise_multiplication,
191-
common_scalar_t<T_a COMMA T_b>, "*",
190+
elt_multiply_, elt_multiply, common_scalar_t<T_a COMMA T_b>, "*",
192191
using view_transitivity = std::tuple<std::true_type, std::true_type>;
193192
inline std::pair<int, int> extreme_diagonals() const {
194193
std::pair<int, int> diags0
@@ -200,7 +199,7 @@ ADD_BINARY_OPERATION_WITH_CUSTOM_CODE(
200199
});
201200

202201
ADD_BINARY_OPERATION_WITH_CUSTOM_CODE(
203-
elewise_division_, elewise_division, common_scalar_t<T_a COMMA T_b>, "/",
202+
elt_divide_, elt_divide, common_scalar_t<T_a COMMA T_b>, "/",
204203
inline std::pair<int, int> extreme_diagonals() const {
205204
return {-rows() + 1, cols() - 1};
206205
});
@@ -246,7 +245,7 @@ ADD_BINARY_OPERATION_WITH_CUSTOM_CODE(
246245
*/
247246
template <typename T_a, typename T_b, typename = require_arithmetic_t<T_a>,
248247
typename = require_all_kernel_expressions_t<T_b>>
249-
inline elewise_multiplication_<scalar_<T_a>, as_operation_cl_t<T_b>> operator*(
248+
inline elt_multiply_<scalar_<T_a>, as_operation_cl_t<T_b>> operator*(
250249
T_a&& a, T_b&& b) { // NOLINT
251250
return {as_operation_cl(std::forward<T_a>(a)),
252251
as_operation_cl(std::forward<T_b>(b))};
@@ -263,7 +262,7 @@ inline elewise_multiplication_<scalar_<T_a>, as_operation_cl_t<T_b>> operator*(
263262
template <typename T_a, typename T_b,
264263
typename = require_all_kernel_expressions_t<T_a>,
265264
typename = require_arithmetic_t<T_b>>
266-
inline elewise_multiplication_<as_operation_cl_t<T_a>, scalar_<T_b>> operator*(
265+
inline elt_multiply_<as_operation_cl_t<T_a>, scalar_<T_b>> operator*(
267266
T_a&& a, const T_b b) { // NOLINT
268267
return {as_operation_cl(std::forward<T_a>(a)), as_operation_cl(b)};
269268
}

stan/math/opencl/kernel_generator/matrix_vector_multiply.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ namespace math {
1515
template <typename T_matrix, typename T_vector,
1616
typename = require_all_kernel_expressions_t<T_matrix, T_vector>>
1717
inline auto matrix_vector_multiply(T_matrix&& matrix, T_vector&& vector) {
18-
return rowwise_sum(elewise_multiplication(
18+
return rowwise_sum(elt_multiply(
1919
std::forward<T_matrix>(matrix),
2020
colwise_broadcast(transpose(std::forward<T_vector>(vector)))));
2121
}

stan/math/opencl/kernel_generator/rowwise_reduction.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,11 @@ struct matvec_mul_opt {
3939
};
4040

4141
template <typename Mat, typename VecT>
42-
struct matvec_mul_opt<
43-
elewise_multiplication_<Mat, broadcast_<VecT, true, false>>> {
42+
struct matvec_mul_opt<elt_multiply_<Mat, broadcast_<VecT, true, false>>> {
4443
// if the argument of rowwise reduction is multiplication with a broadcast
4544
// vector we can do the optimization
4645
enum { is_possible = 1 };
47-
using Arg = elewise_multiplication_<Mat, broadcast_<VecT, true, false>>;
46+
using Arg = elt_multiply_<Mat, broadcast_<VecT, true, false>>;
4847

4948
/**
5049
* Return view of the vector.

stan/math/opencl/kernel_generator/unary_function_cl.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
#include <stan/math/opencl/err.hpp>
77
#include <stan/math/opencl/kernels/device_functions/digamma.hpp>
88
#include <stan/math/opencl/kernels/device_functions/log1m_exp.hpp>
9+
#include <stan/math/opencl/kernels/device_functions/log1m_inv_logit.hpp>
910
#include <stan/math/opencl/kernels/device_functions/log1p_exp.hpp>
11+
#include <stan/math/opencl/kernels/device_functions/logit.hpp>
1012
#include <stan/math/opencl/kernels/device_functions/inv_logit.hpp>
1113
#include <stan/math/opencl/kernels/device_functions/inv_square.hpp>
1214
#include <stan/math/opencl/matrix_cl_view.hpp>
@@ -227,6 +229,8 @@ ADD_UNARY_FUNCTION(erfc)
227229
ADD_UNARY_FUNCTION_PASS_ZERO(floor)
228230
ADD_UNARY_FUNCTION_PASS_ZERO(round)
229231
ADD_UNARY_FUNCTION_PASS_ZERO(ceil)
232+
ADD_UNARY_FUNCTION_PASS_ZERO(fabs)
233+
ADD_UNARY_FUNCTION_PASS_ZERO(trunc)
230234

231235
ADD_UNARY_FUNCTION_WITH_INCLUDE(digamma,
232236
opencl_kernels::digamma_device_function)
@@ -238,6 +242,9 @@ ADD_UNARY_FUNCTION_WITH_INCLUDE(inv_square,
238242
opencl_kernels::inv_square_device_function)
239243
ADD_UNARY_FUNCTION_WITH_INCLUDE(inv_logit,
240244
opencl_kernels::inv_logit_device_function)
245+
ADD_UNARY_FUNCTION_WITH_INCLUDE(logit, opencl_kernels::logit_device_function)
246+
ADD_UNARY_FUNCTION_WITH_INCLUDE(log1m_inv_logit,
247+
opencl_kernels::log1m_inv_logit_device_function)
241248

242249
ADD_CLASSIFICATION_FUNCTION(isfinite, {-rows() + 1, cols() - 1})
243250
ADD_CLASSIFICATION_FUNCTION(isinf,
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#ifndef STAN_MATH_OPENCL_KERNELS_DEVICE_FUNCTIONS_LOG1M_INV_LOGIT_HPP
2+
#define STAN_MATH_OPENCL_KERNELS_DEVICE_FUNCTIONS_LOG1M_INV_LOGIT_HPP
3+
#ifdef STAN_OPENCL
4+
5+
#include <stan/math/opencl/stringify.hpp>
6+
#include <string>
7+
8+
namespace stan {
9+
namespace math {
10+
namespace opencl_kernels {
11+
12+
// \cond
13+
static const char* log1m_inv_logit_device_function
14+
= "\n"
15+
"#ifndef STAN_MATH_OPENCL_KERNELS_DEVICE_FUNCTIONS_LOG1M_INV_LOGIT\n"
16+
"#define "
17+
"STAN_MATH_OPENCL_KERNELS_DEVICE_FUNCTIONS_LOG1M_INV_LOGIT\n" STRINGIFY(
18+
// \endcond
19+
/** \ingroup opencl_kernels
20+
*
21+
* Return the the natural logarithm of 1 minus the inverse logit
22+
* applied to the kernel generator expression.
23+
*
24+
\f[
25+
\mbox{log1m\_inv\_logit}(x) =
26+
\begin{cases}
27+
-\ln(\exp(x)+1) & \mbox{if } -\infty\leq x \leq \infty \\[6pt]
28+
\textrm{NaN} & \mbox{if } x = \textrm{NaN}
29+
\end{cases}
30+
\f]
31+
32+
\f[
33+
\frac{\partial\, \mbox{log1m\_inv\_logit}(x)}{\partial x} =
34+
\begin{cases}
35+
-\frac{\exp(x)}{\exp(x)+1} & \mbox{if } -\infty\leq x\leq \infty
36+
\\[6pt] \textrm{NaN} & \mbox{if } x = \textrm{NaN} \end{cases} \f]
37+
*
38+
* @param x argument
39+
* @return log of one minus the inverse logit of the argument
40+
*/
41+
inline double log1m_inv_logit(double x) {
42+
if (x > 0.0) {
43+
return -x - log1p(exp(-x)); // prevent underflow
44+
}
45+
return -log1p(exp(x));
46+
}
47+
// \cond
48+
) "\n#endif\n"; // NOLINT
49+
// \endcond
50+
51+
} // namespace opencl_kernels
52+
} // namespace math
53+
} // namespace stan
54+
55+
#endif
56+
#endif
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
#ifndef STAN_MATH_OPENCL_KERNELS_DEVICE_FUNCTIONS_LOGIT_HPP
2+
#define STAN_MATH_OPENCL_KERNELS_DEVICE_FUNCTIONS_LOGIT_HPP
3+
#ifdef STAN_OPENCL
4+
5+
#include <stan/math/opencl/stringify.hpp>
6+
#include <string>
7+
8+
namespace stan {
9+
namespace math {
10+
namespace opencl_kernels {
11+
12+
// \cond
13+
static const char* logit_device_function
14+
= "\n"
15+
"#ifndef STAN_MATH_OPENCL_KERNELS_DEVICE_FUNCTIONS_LOGIT\n"
16+
"#define STAN_MATH_OPENCL_KERNELS_DEVICE_FUNCTIONS_LOGIT\n" STRINGIFY(
17+
// \endcond
18+
/** \ingroup opencl_kernels
19+
*
20+
* Return the log odds applied to the kernel generator
21+
* expression.
22+
*
23+
* The logit function is defined as for \f$x \in [0, 1]\f$ by
24+
* returning the log odds of \f$x\f$ treated as a probability,
25+
*
26+
* \f$\mbox{logit}(x) = \log \left( \frac{x}{1 - x} \right)\f$.
27+
*
28+
* The inverse to this function is <code>inv_logit</code>.
29+
*
30+
*
31+
\f[
32+
\mbox{logit}(x) =
33+
\begin{cases}
34+
\textrm{NaN}& \mbox{if } x < 0 \textrm{ or } x > 1\\
35+
\ln\frac{x}{1-x} & \mbox{if } 0\leq x \leq 1 \\[6pt]
36+
\textrm{NaN} & \mbox{if } x = \textrm{NaN}
37+
\end{cases}
38+
\f]
39+
40+
\f[
41+
\frac{\partial\, \mbox{logit}(x)}{\partial x} =
42+
\begin{cases}
43+
\textrm{NaN}& \mbox{if } x < 0 \textrm{ or } x > 1\\
44+
\frac{1}{x-x^2}& \mbox{if } 0\leq x\leq 1 \\[6pt]
45+
\textrm{NaN} & \mbox{if } x = \textrm{NaN}
46+
\end{cases}
47+
\f]
48+
*
49+
* @param x argument
50+
* @return log odds of argument
51+
*/
52+
double logit(double x) { return log(x / (1 - x)); }
53+
// \cond
54+
) "\n#endif\n"; // NOLINT
55+
// \endcond
56+
57+
} // namespace opencl_kernels
58+
} // namespace math
59+
} // namespace stan
60+
61+
#endif
62+
#endif

stan/math/opencl/opencl.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,9 @@
103103
#include <stan/math/opencl/prim/cholesky_decompose.hpp>
104104
#include <stan/math/opencl/prim/col.hpp>
105105
#include <stan/math/opencl/prim/cols.hpp>
106+
#include <stan/math/opencl/prim/crossprod.hpp>
106107
#include <stan/math/opencl/prim/dims.hpp>
108+
#include <stan/math/opencl/prim/divide.hpp>
107109
#include <stan/math/opencl/prim/divide_columns.hpp>
108110
#include <stan/math/opencl/prim/gp_exp_quad_cov.hpp>
109111
#include <stan/math/opencl/prim/inv.hpp>
@@ -120,6 +122,7 @@
120122
#include <stan/math/opencl/prim/rep_vector.hpp>
121123
#include <stan/math/opencl/prim/row.hpp>
122124
#include <stan/math/opencl/prim/rows.hpp>
125+
#include <stan/math/opencl/prim/tcrossprod.hpp>
123126

124127
#include <stan/math/opencl/err.hpp>
125128

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#ifndef STAN_MATH_OPENCL_PRIM_FUN_CROSSPROD_HPP
2+
#define STAN_MATH_OPENCL_PRIM_FUN_CROSSPROD_HPP
3+
#ifdef STAN_OPENCL
4+
5+
#include <stan/math/opencl/matrix_cl.hpp>
6+
#include <stan/math/opencl/multiply.hpp>
7+
#include <stan/math/opencl/kernel_generator.hpp>
8+
9+
namespace stan {
10+
namespace math {
11+
/**
12+
* Returns the result of pre-multiplying a matrix by its
13+
* own transpose.
14+
*
15+
* @tparam T type of elements in A
16+
* @param A input matrix
17+
* @return transpose(A) * A
18+
*/
19+
template <typename T_A,
20+
typename = require_all_kernel_expressions_and_none_scalar_t<T_A>>
21+
inline matrix_cl<typename std::decay_t<T_A>::Scalar> crossprod(T_A&& A) {
22+
const matrix_cl<typename std::decay_t<T_A>::Scalar>& A_eval
23+
= transpose(std::forward<T_A>(A));
24+
return multiply_transpose(A_eval);
25+
}
26+
} // namespace math
27+
} // namespace stan
28+
#endif
29+
#endif

stan/math/opencl/prim/divide.hpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#ifndef STAN_MATH_OPENCL_PRIM_DIVIDE_HPP
2+
#define STAN_MATH_OPENCL_PRIM_DIVIDE_HPP
3+
#ifdef STAN_OPENCL
4+
5+
#include <stan/math/opencl/matrix_cl.hpp>
6+
#include <stan/math/opencl/kernel_generator.hpp>
7+
8+
namespace stan {
9+
namespace math {
10+
/** \ingroup opencl
11+
* Returns the elementwise division of the kernel generator expression
12+
*
13+
* @tparam T_a type of input kernel generator expression a
14+
* @param a expression to divide
15+
* @param d scalar to divide by
16+
* @return the elements of expression a divided by d
17+
*/
18+
template <typename T_a,
19+
typename = require_all_kernel_expressions_and_none_scalar_t<T_a>>
20+
inline auto divide(T_a&& a, double d) { // NOLINT
21+
return elt_divide(std::forward<T_a>(a), d);
22+
}
23+
} // namespace math
24+
} // namespace stan
25+
#endif
26+
#endif

stan/math/opencl/prim/divide_columns.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ inline void divide_columns(const matrix_cl<T1>& A, const matrix_cl<T2>& B) {
5454
*/
5555
template <typename T1, typename T2, typename = require_all_arithmetic_t<T1, T2>>
5656
inline void divide_columns(const matrix_cl<T1>& A, const T2& divisor) {
57-
A = elewise_division(A, divisor);
57+
A = elt_divide(A, divisor);
5858
}
5959

6060
} // namespace math

0 commit comments

Comments
 (0)