Skip to content

Commit e917df0

Browse files
committed
combine definitions for vector of ints and single int for several of the unary functions
1 parent c73b899 commit e917df0

6 files changed

Lines changed: 26 additions & 70 deletions

File tree

stan/math/rev/fun/bessel_first_kind.hpp

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,25 +18,14 @@ inline var bessel_first_kind(int v, const var& a) {
1818
});
1919
}
2020

21-
template <typename T, require_eigen_t<T>* = nullptr>
22-
inline auto bessel_first_kind(int v, const var_value<T>& a) {
21+
template <typename T1, typename T2,
22+
require_st_integral<T1>* = nullptr,
23+
require_eigen_t<T2>* = nullptr>
24+
inline auto bessel_first_kind(const T1& v, const var_value<T2>& a) {
2325
auto ret_val = bessel_first_kind(v, a.val()).array().eval();
24-
auto precomp_bessel = to_arena(v * ret_val / a.val().array()
25-
- bessel_first_kind(v + 1, a.val()).array());
26-
return make_callback_var(
27-
ret_val.matrix(), [precomp_bessel, a](const auto& vi) mutable {
28-
a.adj().array() += vi.adj().array() * precomp_bessel;
29-
});
30-
}
31-
32-
template <typename T, require_eigen_t<T>* = nullptr>
33-
inline auto bessel_first_kind(const std::vector<int>& v,
34-
const var_value<T>& a) {
35-
auto ret_val = bessel_first_kind(v, a.val()).array().eval();
36-
Eigen::Map<const Eigen::Array<int, -1, 1>> v_map(v.data(), v.size());
37-
auto precomp_bessel
38-
= to_arena(v_map.template cast<double>() * ret_val / a.val().array()
39-
- bessel_first_kind(v_map + 1, a.val().array()));
26+
auto v_map = as_array_or_scalar(v);
27+
auto precomp_bessel = to_arena(v_map * ret_val / a.val().array()
28+
- bessel_first_kind(v_map + 1, a.val().array()));
4029
return make_callback_var(
4130
ret_val.matrix(), [precomp_bessel, a](const auto& vi) mutable {
4231
a.adj().array() += vi.adj().array() * precomp_bessel;

stan/math/rev/fun/bessel_second_kind.hpp

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,14 @@ inline var bessel_second_kind(int v, const var& a) {
1717
});
1818
}
1919

20-
template <typename T, require_eigen_t<T>* = nullptr>
21-
inline auto bessel_second_kind(int v, const var_value<T>& a) {
20+
template <typename T1, typename T2,
21+
require_st_integral<T1>* = nullptr,
22+
require_eigen_t<T2>* = nullptr>
23+
inline auto bessel_second_kind(const T1& v, const var_value<T2>& a) {
2224
auto ret_val = bessel_second_kind(v, a.val()).array().eval();
23-
auto precomp_bessel = to_arena(v * ret_val / a.val().array()
24-
- bessel_second_kind(v + 1, a.val()).array());
25-
return make_callback_var(
26-
ret_val.matrix(), [precomp_bessel, a](const auto& vi) mutable {
27-
a.adj().array() += vi.adj().array() * precomp_bessel;
28-
});
29-
}
30-
31-
template <typename T, require_eigen_t<T>* = nullptr>
32-
inline auto bessel_second_kind(const std::vector<int>& v,
33-
const var_value<T>& a) {
34-
auto ret_val = bessel_second_kind(v, a.val()).array().eval();
35-
Eigen::Map<const Eigen::Array<int, -1, 1>> v_map(v.data(), v.size());
36-
auto precomp_bessel
37-
= to_arena(v_map.template cast<double>() * ret_val / a.val().array()
38-
- bessel_second_kind(v_map + 1, a.val().array()));
25+
auto v_map = as_array_or_scalar(v);
26+
auto precomp_bessel = to_arena(v_map * ret_val / a.val().array()
27+
- bessel_second_kind(v_map + 1, a.val().array()));
3928
return make_callback_var(
4029
ret_val.matrix(), [precomp_bessel, a](const auto& vi) mutable {
4130
a.adj().array() += vi.adj().array() * precomp_bessel;

stan/math/rev/fun/binary_log_loss.hpp

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -85,17 +85,6 @@ inline auto binary_log_loss(const std::vector<int>& y,
8585
/ (arena_y == 0)
8686
.select((1.0 - y_hat.val().array()), -y_hat.val().array());
8787
});
88-
/*
89-
if (y == 0) {
90-
return make_callback_var(-log1p(-y_hat.val()), [y_hat](auto& vi) mutable {
91-
y_hat.adj().array() += vi.adj().array() / (1.0 - y_hat.val().array());
92-
});
93-
} else {
94-
return make_callback_var(-std::log(y_hat.val()), [y_hat](auto& vi) mutable {
95-
y_hat.adj().array() -= vi.adj().array() / y_hat.val().array();
96-
});
97-
}
98-
*/
9988
}
10089

10190
} // namespace math

stan/math/rev/fun/exp2.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <stan/math/rev/meta.hpp>
55
#include <stan/math/rev/core.hpp>
6+
#include <stan/math/prim/fun/exp2.hpp>
67
#include <stan/math/prim/fun/constants.hpp>
78
#include <cmath>
89

@@ -43,8 +44,7 @@ inline var exp2(const var& a) {
4344

4445
template <typename T, require_eigen_t<T>* = nullptr>
4546
inline auto exp2(const var_value<T>& a) {
46-
return make_callback_var(
47-
a.val().unaryExpr([](auto&& x) { return std::exp2(x); }),
47+
return make_callback_var(exp2(a.val()),
4848
[a](auto& vi) mutable {
4949
a.adj().array() += vi.adj().array() * vi.val().array() * LOG_TWO;
5050
});

stan/math/rev/fun/expm1.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ inline var expm1(const var& a) {
4141
});
4242
}
4343

44-
template <typename T, require_matrix_t<T>* = nullptr>
44+
template <typename T, require_eigen_t<T>* = nullptr>
4545
inline auto expm1(const var_value<T>& a) {
46-
return make_callback_var(expm1(a.val()).eval(), [a](auto& vi) mutable {
46+
return make_callback_var(expm1(a.val()), [a](auto& vi) mutable {
4747
a.adj().array() += vi.adj().array() * (vi.val().array() + 1.0);
4848
});
4949
}

stan/math/rev/fun/falling_factorial.hpp

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,27 +17,16 @@ inline var falling_factorial(const var& a, int b) {
1717
});
1818
}
1919

20-
template <typename T, require_eigen_t<T>* = nullptr>
21-
inline auto falling_factorial(const var_value<T>& a, int b) {
22-
auto digamma_ab = to_arena(digamma(a.val().array() + 1)
23-
- digamma(a.val().array() - b + 1));
24-
return make_callback_var(
25-
falling_factorial(a.val(), b), [a, digamma_ab](auto& vi) mutable {
26-
a.adj().array() += vi.adj().array() * vi.val().array() * digamma_ab;
27-
});
28-
}
29-
30-
template <typename T, typename StdVec, require_eigen_t<T>* = nullptr,
31-
require_vector_like_vt<std::is_integral, StdVec>* = nullptr>
32-
inline auto falling_factorial(const var_value<T>& a, const StdVec& b) {
33-
Eigen::Array<int, -1, 1> b_map
34-
= Eigen::Map<const Eigen::Array<int, -1, 1>>(b.data(), b.size());
20+
template <typename T1, typename T2, require_eigen_t<T1>* = nullptr,
21+
require_st_integral<T2>* = nullptr>
22+
inline auto falling_factorial(const var_value<T1>& a, const T2& b) {
23+
auto b_map = as_array_or_scalar(b);
3524
auto digamma_ab = to_arena(digamma(a.val().array() + 1)
3625
- digamma(a.val().array() - b_map + 1));
37-
return make_callback_var(
38-
falling_factorial(a.val(), b), [a, digamma_ab](auto& vi) mutable {
39-
a.adj().array() += vi.adj().array() * vi.val().array() * digamma_ab;
40-
});
26+
return make_callback_var(
27+
falling_factorial(a.val(), b), [a, digamma_ab](auto& vi) mutable {
28+
a.adj().array() += vi.adj().array() * vi.val().array() * digamma_ab;
29+
});
4130
}
4231

4332
} // namespace math

0 commit comments

Comments
 (0)