Skip to content

Commit b994475

Browse files
authored
Merge pull request #2472 from bstatcomp/kg_cast
Add kernel generator cast operation
2 parents 1dee1dc + fcaff77 commit b994475

32 files changed

Lines changed: 250 additions & 88 deletions

stan/math/opencl/kernel_generator.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@
131131
#include <stan/math/opencl/kernel_generator/index.hpp>
132132
#include <stan/math/opencl/kernel_generator/indexing.hpp>
133133
#include <stan/math/opencl/kernel_generator/opencl_code.hpp>
134+
#include <stan/math/opencl/kernel_generator/cast.hpp>
134135

135136
#include <stan/math/opencl/kernel_generator/multi_result_kernel.hpp>
136137
#include <stan/math/opencl/kernel_generator/get_kernel_source_for_evaluating_into.hpp>
stan/math/opencl/kernel_generator/cast.hpp

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_CAST_HPP
2+
#define STAN_MATH_OPENCL_KERNEL_GENERATOR_CAST_HPP
3+
#ifdef STAN_OPENCL
4+
5+
#include <stan/math/prim/meta.hpp>
6+
#include <stan/math/opencl/matrix_cl_view.hpp>
7+
#include <stan/math/opencl/kernel_generator/common_return_scalar.hpp>
8+
#include <stan/math/opencl/kernel_generator/type_str.hpp>
9+
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
10+
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>
11+
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
12+
#include <array>
13+
#include <string>
14+
#include <type_traits>
15+
#include <set>
16+
#include <utility>
17+
18+
namespace stan {
19+
namespace math {
20+
21+
/** \addtogroup opencl_kernel_generator
22+
* @{
23+
*/
24+
25+
/**
26+
* Represents a typecast of the scalar in kernel generator expressions.
27+
* The generated kernel code applies a C-style cast to the argument's value.
28+
* @tparam Scal type of the scalar of the result
29+
* @tparam T type of argument
30+
*/
31+
template <typename Scal, typename T>
32+
class cast_ : public operation_cl<cast_<Scal, T>, Scal, T> {
33+
public:
34+
using Scalar = Scal;
35+
using base = operation_cl<cast_<Scal, T>, Scalar, T>;
36+
using base::var_name_;
37+
38+
/**
39+
* Constructor. Forwards the argument expression to the base class.
40+
* @param arg argument expression
41+
*/
42+
explicit cast_(T&& arg) : base(std::forward<T>(arg)) {}
43+
44+
/**
45+
* Generates kernel code for this expression.
46+
* @param row_index_name row index variable name (not used by this operation)
47+
* @param col_index_name column index variable name (not used by this operation)
48+
* @param view_handled whether caller already handled matrix view (not used here)
49+
* @param var_name_arg variable name of the nested expression
50+
* @return part of kernel with code for this expression
51+
*/
52+
inline kernel_parts generate(const std::string& row_index_name,
53+
const std::string& col_index_name,
54+
const bool view_handled,
55+
const std::string& var_name_arg) const {
56+
kernel_parts res{};
57+
// Emits: <Scalar type> <var_name_> = (<Scalar type>)<var_name_arg>;
58+
res.body = type_str<Scalar>() + " " + var_name_ + " = ("
59+
+ type_str<Scalar>() + ")" + var_name_arg + ";\n";
60+
return res;
61+
}
62+
/** Builds a new cast_ around the result of deep_copy() on the argument. */
63+
inline auto deep_copy() const {
64+
auto&& arg_copy = this->template get_arg<0>().deep_copy();
65+
return cast_<Scalar, std::remove_reference_t<decltype(arg_copy)>>{
66+
std::move(arg_copy)};
67+
}
68+
};
69+
70+
/**
71+
* Typecast a kernel generator expression scalar (the argument is deep-copied).
72+
* @tparam Scalar type of the scalar of the result
73+
* @tparam T type of argument
74+
* @param a input argument
75+
* @return Typecast of given expression
76+
*/
77+
template <typename Scalar, typename T,
78+
require_all_kernel_expressions_and_none_scalar_t<T>* = nullptr>
79+
inline auto cast(T&& a) {
80+
auto&& a_operation = as_operation_cl(std::forward<T>(a)).deep_copy();
81+
return cast_<Scalar, std::remove_reference_t<decltype(a_operation)>>(
82+
std::move(a_operation));
83+
}
84+
85+
/**
86+
* Typecast a scalar.
87+
* @tparam Scalar type to convert the argument to
88+
* @tparam T type of argument
89+
* @param a input argument
90+
* @return the argument converted to Scalar
91+
*/
92+
template <typename Scalar, typename T, require_stan_scalar_t<T>* = nullptr>
93+
inline Scalar cast(T a) {
94+
return a;  // implicit conversion from T to Scalar
95+
}
96+
97+
/** @}*/
98+
} // namespace math
99+
} // namespace stan
100+
#endif
101+
#endif

stan/math/opencl/prim/bernoulli_cdf.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,13 @@ return_type_t<T_prob_cl> bernoulli_cdf(const T_n_cl& n,
4949
theta_val, "in the interval [0, 1]");
5050
auto theta_bounded_expr = 0.0 <= theta_val && theta_val <= 1.0;
5151

52-
auto any_n_negative = colwise_max(constant(0, N, 1) + (n < 0));
52+
auto any_n_negative = colwise_max(cast<char>(n < 0));
5353
auto cond = n >= 1;
5454
auto Pi_uncond = 1.0 - theta_val;
5555
auto Pi = select(cond, INFTY, Pi_uncond);
5656
auto P_expr = colwise_prod(select(cond, 1.0, Pi_uncond));
5757

58-
matrix_cl<double> any_n_negative_cl;
58+
matrix_cl<char> any_n_negative_cl;
5959
matrix_cl<double> Pi_cl;
6060
matrix_cl<double> P_cl;
6161

stan/math/opencl/prim/bernoulli_lccdf.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,13 @@ return_type_t<T_prob_cl> bernoulli_lccdf(const T_n_cl& n,
5050
theta_val, "in the interval [0, 1]");
5151
auto theta_bounded_expr = 0.0 <= theta_val && theta_val <= 1.0;
5252

53-
auto any_n_negative = colwise_max(0 + (n < 0));
54-
auto any_n_over_one = colwise_max(constant(0, N, 1) + (n >= 1));
53+
auto any_n_negative = colwise_max(cast<char>(n < 0));
54+
auto any_n_over_one = colwise_max(cast<char>(n >= 1));
5555
auto P_expr = colwise_sum(log(theta_val));
5656
auto deriv = elt_divide(1.0, theta_val);
5757

58-
matrix_cl<double> any_n_negative_cl;
59-
matrix_cl<double> any_n_over_one_cl;
58+
matrix_cl<char> any_n_negative_cl;
59+
matrix_cl<char> any_n_over_one_cl;
6060
matrix_cl<double> P_cl;
6161
matrix_cl<double> deriv_cl;
6262

stan/math/opencl/prim/bernoulli_lcdf.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,13 @@ return_type_t<T_prob_cl> bernoulli_lcdf(const T_n_cl& n,
5050
theta_val, "in the interval [0, 1]");
5151
auto theta_bounded_expr = 0.0 <= theta_val && theta_val <= 1.0;
5252

53-
auto any_n_negative = colwise_max(0 + (n < 0));
53+
auto any_n_negative = colwise_max(cast<char>(n < 0));
5454
auto Pi = 1.0 - theta_val;
5555
auto cond = n >= 1;
5656
auto P_expr = colwise_sum(select(cond, 0.0, log(Pi)));
5757
auto deriv = select(cond, 0.0, elt_divide(-1.0, Pi));
5858

59-
matrix_cl<double> any_n_negative_cl;
59+
matrix_cl<char> any_n_negative_cl;
6060
matrix_cl<double> P_cl;
6161
matrix_cl<double> deriv_cl;
6262

stan/math/opencl/prim/cauchy_cdf.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ return_type_t<T_y_cl, T_loc_cl, T_scale_cl> cauchy_cdf(
6363
= check_cl(function, "Scale parameter", sigma_val, "positive finite");
6464
auto sigma_positive_finite_expr = 0 < sigma_val && isfinite(sigma_val);
6565

66-
auto any_y_neg_inf
67-
= colwise_max(constant(0, N, 1) + (y_val == NEGATIVE_INFTY));
66+
auto any_y_neg_inf = colwise_max(cast<char>(y_val == NEGATIVE_INFTY));
6867
auto cond = y_val == INFTY;
6968
auto sigma_inv = elt_divide(1.0, sigma_val);
7069
auto z = elt_multiply(y_val - mu_val, sigma_inv);
@@ -76,7 +75,7 @@ return_type_t<T_y_cl, T_loc_cl, T_scale_cl> cauchy_cdf(
7675
elt_divide(sigma_inv, -pi() * elt_multiply(1.0 + square(z), Pn)));
7776
auto sigma_deriv_tmp = elt_multiply(z, mu_deriv_tmp);
7877

79-
matrix_cl<double> any_y_neg_inf_cl;
78+
matrix_cl<char> any_y_neg_inf_cl;
8079
matrix_cl<double> P_cl;
8180
matrix_cl<double> mu_deriv_cl;
8281
matrix_cl<double> y_deriv_cl;

stan/math/opencl/prim/exp_mod_normal_cdf.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,7 @@ return_type_t<T_y_cl, T_loc_cl, T_scale_cl, T_inv_scale_cl> exp_mod_normal_cdf(
7272
= check_cl(function, "Inv_cale parameter", lambda_val, "positive finite");
7373
auto lambda_positive_finite_expr = 0 < lambda_val && isfinite(lambda_val);
7474

75-
auto any_y_neg_inf
76-
= colwise_max(constant(0, N, 1) + (y_val == NEGATIVE_INFTY));
75+
auto any_y_neg_inf = colwise_max(cast<char>(y_val == NEGATIVE_INFTY));
7776
auto inv_sigma = elt_divide(1.0, sigma_val);
7877
auto diff = y_val - mu_val;
7978
auto v = elt_multiply(lambda_val, sigma_val);
@@ -102,7 +101,7 @@ return_type_t<T_y_cl, T_loc_cl, T_scale_cl, T_inv_scale_cl> exp_mod_normal_cdf(
102101
- elt_multiply(elt_multiply(v, sigma_val) - diff, erf_calc)),
103102
cdf_n);
104103

105-
matrix_cl<double> any_y_neg_inf_cl;
104+
matrix_cl<char> any_y_neg_inf_cl;
106105
matrix_cl<double> cdf_cl;
107106
matrix_cl<double> y_deriv_cl;
108107
matrix_cl<double> mu_deriv_cl;

stan/math/opencl/prim/exp_mod_normal_lccdf.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,8 @@ exp_mod_normal_lccdf(const T_y_cl& y, const T_loc_cl& mu,
7373
= check_cl(function, "Inv_cale parameter", lambda_val, "positive finite");
7474
auto lambda_positive_finite_expr = 0 < lambda_val && isfinite(lambda_val);
7575

76-
auto any_y_neg_inf
77-
= colwise_max(constant(0, N, 1) + (y_val == NEGATIVE_INFTY));
78-
auto any_y_pos_inf = colwise_max(constant(0, N, 1) + (y_val == INFTY));
76+
auto any_y_neg_inf = colwise_max(cast<char>(y_val == NEGATIVE_INFTY));
77+
auto any_y_pos_inf = colwise_max(cast<char>(y_val == INFTY));
7978
auto inv_sigma = elt_divide(1.0, sigma_val);
8079
auto diff = y_val - mu_val;
8180
auto scaled_diff = elt_multiply(diff, inv_sigma * INV_SQRT_TWO);
@@ -104,8 +103,8 @@ exp_mod_normal_lccdf(const T_y_cl& y, const T_loc_cl& mu,
104103
- INV_SQRT_TWO_PI * elt_multiply(sigma_val, exp_term_2)),
105104
ccdf_n);
106105

107-
matrix_cl<double> any_y_neg_inf_cl;
108-
matrix_cl<double> any_y_pos_inf_cl;
106+
matrix_cl<char> any_y_neg_inf_cl;
107+
matrix_cl<char> any_y_pos_inf_cl;
109108
matrix_cl<double> ccdf_log_cl;
110109
matrix_cl<double> mu_deriv_cl;
111110
matrix_cl<double> y_deriv_cl;

stan/math/opencl/prim/exp_mod_normal_lcdf.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,8 @@ return_type_t<T_y_cl, T_loc_cl, T_scale_cl, T_inv_scale_cl> exp_mod_normal_lcdf(
7373
= check_cl(function, "Inv_cale parameter", lambda_val, "positive finite");
7474
auto lambda_positive_finite_expr = 0 < lambda_val && isfinite(lambda_val);
7575

76-
auto any_y_neg_inf
77-
= colwise_max(constant(0, N, 1) + (y_val == NEGATIVE_INFTY));
78-
auto any_y_pos_inf = colwise_max(constant(0, N, 1) + (y_val == INFTY));
76+
auto any_y_neg_inf = colwise_max(cast<char>(y_val == NEGATIVE_INFTY));
77+
auto any_y_pos_inf = colwise_max(cast<char>(y_val == INFTY));
7978
auto sigma_inv = elt_divide(1.0, sigma_val);
8079
auto diff = y_val - mu_val;
8180
auto scaled_diff = elt_multiply(diff * INV_SQRT_TWO, sigma_inv);
@@ -105,8 +104,8 @@ return_type_t<T_y_cl, T_loc_cl, T_scale_cl, T_inv_scale_cl> exp_mod_normal_lcdf(
105104
- elt_multiply(elt_multiply(v, sigma_val) - diff, erf_calc)),
106105
cdf_n);
107106

108-
matrix_cl<double> any_y_neg_inf_cl;
109-
matrix_cl<double> any_y_pos_inf_cl;
107+
matrix_cl<char> any_y_neg_inf_cl;
108+
matrix_cl<char> any_y_pos_inf_cl;
110109
matrix_cl<double> cdf_log_cl;
111110
matrix_cl<double> mu_deriv_cl;
112111
matrix_cl<double> y_deriv_cl;

stan/math/opencl/prim/gamma_lpdf.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ return_type_t<T_y_cl, T_shape_cl, T_inv_scale_cl> gamma_lpdf(
8080
beta_val, "positive finite");
8181
auto beta_pos_finite_expr = beta_val > 0 && isfinite(beta_val);
8282

83-
auto any_y_negative_expr = colwise_max(constant(0, N, 1) + (y_val < 0));
83+
auto any_y_negative_expr = colwise_max(cast<char>(y_val < 0));
8484
auto log_y_expr = log(y_val);
8585
auto log_beta_expr = log(beta_val);
8686
auto logp1_expr = static_select<include_summand<propto, T_shape_cl>::value>(
@@ -99,7 +99,7 @@ return_type_t<T_y_cl, T_shape_cl, T_inv_scale_cl> gamma_lpdf(
9999
auto alpha_deriv_expr = log_beta_expr + log_y_expr - digamma(alpha_val);
100100
auto beta_deriv_expr = elt_divide(alpha_val, beta_val) - y_val;
101101

102-
matrix_cl<int> any_y_negative_cl;
102+
matrix_cl<char> any_y_negative_cl;
103103
matrix_cl<double> logp_cl;
104104
matrix_cl<double> y_deriv_cl;
105105
matrix_cl<double> alpha_deriv_cl;

0 commit comments

Comments
 (0)