@@ -17,27 +17,16 @@ inline var falling_factorial(const var& a, int b) {
1717 });
1818}
1919
20- template <typename T, require_eigen_t <T>* = nullptr >
21- inline auto falling_factorial (const var_value<T>& a, int b) {
22- auto digamma_ab = to_arena (digamma (a.val ().array () + 1 )
23- - digamma (a.val ().array () - b + 1 ));
24- return make_callback_var (
25- falling_factorial (a.val (), b), [a, digamma_ab](auto & vi) mutable {
26- a.adj ().array () += vi.adj ().array () * vi.val ().array () * digamma_ab;
27- });
28- }
29-
30- template <typename T, typename StdVec, require_eigen_t <T>* = nullptr ,
31- require_vector_like_vt<std::is_integral, StdVec>* = nullptr >
32- inline auto falling_factorial (const var_value<T>& a, const StdVec& b) {
33- Eigen::Array<int , -1 , 1 > b_map
34- = Eigen::Map<const Eigen::Array<int , -1 , 1 >>(b.data (), b.size ());
20+ template <typename T1, typename T2, require_eigen_t <T1>* = nullptr ,
21+ require_st_integral<T2>* = nullptr >
22+ inline auto falling_factorial (const var_value<T1>& a, const T2& b) {
23+ auto b_map = as_array_or_scalar (b);
3524 auto digamma_ab = to_arena (digamma (a.val ().array () + 1 )
3625 - digamma (a.val ().array () - b_map + 1 ));
37- return make_callback_var (
38- falling_factorial (a.val (), b), [a, digamma_ab](auto & vi) mutable {
39- a.adj ().array () += vi.adj ().array () * vi.val ().array () * digamma_ab;
40- });
26+ return make_callback_var (
27+ falling_factorial (a.val (), b), [a, digamma_ab](auto & vi) mutable {
28+ a.adj ().array () += vi.adj ().array () * vi.val ().array () * digamma_ab;
29+ });
4130}
4231
4332} // namespace math
0 commit comments