Skip to content

Commit 1dee1dc

Browse files
authored
Merge pull request #2461 from stan-dev/feature/unary-var-matrix
Adds simple vectorized functions for var<matrix>
2 parents 6cf500e + 6f67a7f commit 1dee1dc

33 files changed

Lines changed: 462 additions & 19 deletions

stan/math/prim/fun/bessel_first_kind.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ inline T2 bessel_first_kind(int v, const T2 z) {
5151
* @param b Second input
5252
* @return Bessel first kind function applied to the two inputs.
5353
*/
54-
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr>
54+
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr,
55+
require_not_var_matrix_t<T2>* = nullptr>
5556
inline auto bessel_first_kind(const T1& a, const T2& b) {
5657
return apply_scalar_binary(a, b, [&](const auto& c, const auto& d) {
5758
return bessel_first_kind(c, d);

stan/math/prim/fun/beta.hpp

100755100644
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ inline return_type_t<T1, T2> beta(const T1 a, const T2 b) {
6565
* @param b Second input
6666
* @return Beta function applied to the two inputs.
6767
*/
68-
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr>
68+
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr,
69+
require_all_not_var_matrix_t<T1, T2>* = nullptr>
6970
inline auto beta(const T1& a, const T2& b) {
7071
return apply_scalar_binary(
7172
a, b, [&](const auto& c, const auto& d) { return beta(c, d); });

stan/math/prim/fun/binary_log_loss.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ inline T binary_log_loss(int y, const T& y_hat) {
4242
* @param b Second input
4343
* @return Binary log loss function applied to the two inputs.
4444
*/
45-
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr>
45+
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr,
46+
require_not_var_matrix_t<T2>* = nullptr>
4647
inline auto binary_log_loss(const T1& a, const T2& b) {
4748
return apply_scalar_binary(a, b, [&](const auto& c, const auto& d) {
4849
return binary_log_loss(c, d);

stan/math/prim/fun/ceil.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ inline auto ceil(const Container& x) {
5050
* @return Least integer >= each value in x.
5151
*/
5252
template <typename Container,
53-
require_container_st<std::is_arithmetic, Container>* = nullptr>
53+
require_container_st<std::is_arithmetic, Container>* = nullptr,
54+
require_not_var_matrix_t<Container>* = nullptr>
5455
inline auto ceil(const Container& x) {
5556
return apply_vector_unary<Container>::apply(
5657
x, [](const auto& v) { return v.array().ceil(); });

stan/math/prim/fun/erf.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ struct erf_fun {
3131
* @param x container
3232
* @return Error function applied to each value in x.
3333
*/
34-
template <typename T, require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
35-
T>* = nullptr>
34+
template <
35+
typename T,
36+
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr,
37+
require_not_var_matrix_t<T>* = nullptr>
3638
inline auto erf(const T& x) {
3739
return apply_scalar_unary<erf_fun, T>::apply(x);
3840
}

stan/math/prim/fun/erfc.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@ struct erfc_fun {
3232
* @param x container
3333
* @return Complementary error function applied to each value in x.
3434
*/
35-
template <typename T, require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
36-
T>* = nullptr>
35+
template <
36+
typename T,
37+
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr,
38+
require_not_var_matrix_t<T>* = nullptr>
3739
inline auto erfc(const T& x) {
3840
return apply_scalar_unary<erfc_fun, T>::apply(x);
3941
}

stan/math/prim/fun/exp2.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,10 @@ struct exp2_fun {
3535
* @param x container
3636
* @return Elementwise exp2 of members of container.
3737
*/
38-
template <typename T, require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
39-
T>* = nullptr>
38+
template <
39+
typename T,
40+
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr,
41+
require_not_var_matrix_t<T>* = nullptr>
4042
inline auto exp2(const T& x) {
4143
return apply_scalar_unary<exp2_fun, T>::apply(x);
4244
}

stan/math/prim/fun/expm1.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@ struct expm1_fun {
3232
* @param x container
3333
* @return Natural exponential of each value in x minus one.
3434
*/
35-
template <typename T, require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
36-
T>* = nullptr>
35+
template <
36+
typename T,
37+
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr,
38+
require_not_var_matrix_t<T>* = nullptr>
3739
inline auto expm1(const T& x) {
3840
return apply_scalar_unary<expm1_fun, T>::apply(x);
3941
}

stan/math/prim/fun/falling_factorial.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ namespace math {
6262
*/
6363
template <typename T, require_arithmetic_t<T>* = nullptr>
6464
inline return_type_t<T> falling_factorial(const T& x, int n) {
65-
static const char* function = "falling_factorial";
65+
constexpr const char* function = "falling_factorial";
6666
check_not_nan(function, "first argument", x);
6767
check_nonnegative(function, "second argument", n);
6868
return boost::math::falling_factorial(x, n, boost_policy_t<>());
@@ -78,7 +78,8 @@ inline return_type_t<T> falling_factorial(const T& x, int n) {
7878
* @param b Second input
7979
* @return Falling factorial function applied to the two inputs.
8080
*/
81-
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr>
81+
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr,
82+
require_all_not_var_matrix_t<T1, T2>* = nullptr>
8283
inline auto falling_factorial(const T1& a, const T2& b) {
8384
return apply_scalar_binary(a, b, [&](const auto& c, const auto& d) {
8485
return falling_factorial(c, d);

stan/math/prim/fun/floor.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ struct floor_fun {
3636
template <typename Container,
3737
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
3838
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
39-
Container>* = nullptr>
39+
Container>* = nullptr,
40+
require_not_var_matrix_t<Container>* = nullptr>
4041
inline auto floor(const Container& x) {
4142
return apply_scalar_unary<floor_fun, Container>::apply(x);
4243
}
@@ -50,7 +51,8 @@ inline auto floor(const Container& x) {
5051
* @return Greatest integer <= each value in x.
5152
*/
5253
template <typename Container,
53-
require_container_st<std::is_arithmetic, Container>* = nullptr>
54+
require_container_st<std::is_arithmetic, Container>* = nullptr,
55+
require_not_var_matrix_t<Container>* = nullptr>
5456
inline auto floor(const Container& x) {
5557
return apply_vector_unary<Container>::apply(
5658
x, [](const auto& v) { return v.array().floor(); });

0 commit comments

Comments
 (0)