66#include < stan/math/prim/fun/constants.hpp>
77#include < stan/math/prim/fun/digamma.hpp>
88#include < stan/math/prim/fun/exp.hpp>
9- #include < stan/math/prim/fun/gamma_q.hpp>
9+ #include < stan/math/prim/fun/fma.hpp>
10+ #include < stan/math/prim/fun/gamma_p.hpp>
1011#include < stan/math/prim/fun/grad_reg_inc_gamma.hpp>
12+ #include < stan/math/prim/fun/grad_reg_lower_inc_gamma.hpp>
13+ #include < stan/math/prim/fun/lgamma.hpp>
1114#include < stan/math/prim/fun/log.hpp>
15+ #include < stan/math/prim/fun/log1m.hpp>
1216#include < stan/math/prim/fun/max_size.hpp>
1317#include < stan/math/prim/fun/scalar_seq_view.hpp>
1418#include < stan/math/prim/fun/size.hpp>
1519#include < stan/math/prim/fun/size_zero.hpp>
1620#include < stan/math/prim/fun/tgamma.hpp>
17- #include < stan/math/prim/fun/value_of.hpp>
21+ #include < stan/math/prim/fun/value_of_rec.hpp>
22+ #include < stan/math/prim/fun/log_gamma_q_dgamma.hpp>
1823#include < stan/math/prim/functor/partials_propagator.hpp>
1924#include < cmath>
25+ #include < optional>
2026
2127namespace stan {
2228namespace math {
29+ namespace internal {
30+
31+ /* *
32+ * Computes log q and d(log q) / d(alpha) using continued fraction.
 *
 * The result is a pair whose `first` member is the log q value and whose
 * `second` member is the derivative of that value with respect to `alpha`.
 * `second` is left at 0.0 when `T_shape` is not an autodiff type.
 *
 * @tparam any_fvar compile-time flag; when it is false and `T_shape` is an
 *   autodiff type, both members come from a single `log_gamma_q_dgamma` call
 *   on `value_of(alpha)` and `value_of(beta_y)`
 * @tparam partials_fvar compile-time flag; when true the derivative is read
 *   from the `.d_` member of a second `log_q_gamma_cf` evaluation in which
 *   `alpha.d_` is set to 1 and `beta_y.d_` is set to 0
 * @tparam T_shape type tested with `is_autodiff_v` to decide whether a
 *   derivative is computed at all
 * @tparam T1 type of `alpha`
 * @tparam T2 type of `beta_y`
 * @param alpha first argument forwarded to the continued-fraction helpers
 * @param beta_y second argument forwarded to the continued-fraction helpers
 * @return the (value, derivative) pair, or `std::nullopt` when the computed
 *   value is not finite
33+ */
34+ template <bool any_fvar, bool partials_fvar, typename T_shape, typename T1,
35+ typename T2>
36+ inline std::optional<std::pair<return_type_t <T1, T2>, return_type_t <T1, T2>>>
37+ eval_q_cf (const T1& alpha, const T2& beta_y) {
38+ using scalar_t = return_type_t <T1, T2>;
39+ using ret_t = std::pair<scalar_t , scalar_t >;
// Branch 1: no fvar inputs but the shape is autodiff. One helper call
// supplies the whole pair as plain doubles.
// NOTE(review): assumes log_gamma_q_dgamma's result converts to
// std::pair<double, double> -- confirm against its declaration.
40+ if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
41+ std::pair<double , double > log_q_result
42+ = log_gamma_q_dgamma (value_of (alpha), value_of (beta_y));
43+ if (likely (std::isfinite (log_q_result.first ))) {
44+ return std::optional{log_q_result};
45+ } else {
// A non-finite value is reported as "no result" so the caller can react.
46+ return std::optional<ret_t >{std::nullopt };
47+ }
48+ } else {
// Branch 2: evaluate the value first; the derivative slot starts at 0.0.
49+ ret_t out{internal::log_q_gamma_cf (alpha, beta_y), 0.0 };
50+ if (unlikely (!std::isfinite (value_of_rec (out.first )))) {
51+ return std::optional<ret_t >{std::nullopt };
52+ }
53+ if constexpr (is_autodiff_v<T_shape>) {
54+ if constexpr (!partials_fvar) {
// Dividing by exp(log q) turns a derivative of q into a derivative of log q.
// NOTE(review): presumably grad_reg_inc_gamma returns dq/dalpha -- confirm.
55+ out.second
56+ = grad_reg_inc_gamma (alpha, beta_y, tgamma (alpha), digamma (alpha))
57+ / exp (out.first );
58+ } else {
// Seed the tangent of alpha with 1 and of beta_y with 0, re-evaluate, and
// read the resulting tangent as the derivative with respect to alpha.
59+ auto alpha_unit = alpha;
60+ alpha_unit.d_ = 1 ;
61+ auto beta_y_unit = beta_y;
62+ beta_y_unit.d_ = 0 ;
63+ auto log_Q_fvar = internal::log_q_gamma_cf (alpha_unit, beta_y_unit);
64+ out.second = log_Q_fvar.d_ ;
65+ }
66+ }
67+ return std::optional{out};
68+ }
69+ }
70+
71+ /* *
72+ * Computes log q and d(log q) / d(alpha) using log1m.
 *
 * The value is `log1m(gamma_p(alpha, beta_y))`. The result is a pair whose
 * `first` member is that value and whose `second` member is its derivative
 * with respect to `alpha`. `second` is left at 0.0 when `T_shape` is not an
 * autodiff type.
 *
 * @tparam partials_fvar compile-time flag; when true the derivative is read
 *   from the `.d_` member of a second evaluation in which `alpha.d_` is set
 *   to 1 and `beta_y.d_` is set to 0; when false it is computed from
 *   `grad_reg_lower_inc_gamma`
 * @tparam T_shape type tested with `is_autodiff_v` to decide whether a
 *   derivative is computed at all
 * @tparam T1 type of `alpha`
 * @tparam T2 type of `beta_y`
 * @param alpha first argument passed to `gamma_p`
 * @param beta_y second argument passed to `gamma_p`
 * @return the (value, derivative) pair, or `std::nullopt` when the computed
 *   value is not finite
73+ */
74+ template <bool partials_fvar, typename T_shape, typename T1, typename T2>
75+ inline std::optional<std::pair<return_type_t <T1, T2>, return_type_t <T1, T2>>>
76+ eval_q_log1m (const T1& alpha, const T2& beta_y) {
77+ using scalar_t = return_type_t <T1, T2>;
78+ using ret_t = std::pair<scalar_t , scalar_t >;
// Value first; the derivative slot starts at 0.0 and is filled in below.
79+ ret_t out{log1m (gamma_p (alpha, beta_y)), 0.0 };
// A non-finite value is reported as "no result" so the caller can react.
80+ if (unlikely (!std::isfinite (value_of_rec (out.first )))) {
81+ return std::optional<ret_t >{std::nullopt };
82+ }
83+ if constexpr (is_autodiff_v<T_shape>) {
84+ if constexpr (partials_fvar) {
// Seed the tangent of alpha with 1 and of beta_y with 0, re-evaluate, and
// read the resulting tangent as the derivative with respect to alpha.
85+ auto alpha_unit = alpha;
86+ alpha_unit.d_ = 1 ;
87+ auto beta_unit = beta_y;
88+ beta_unit.d_ = 0 ;
89+ auto log_Q_fvar = log1m (gamma_p (alpha_unit, beta_unit));
90+ out.second = log_Q_fvar.d_ ;
91+ } else {
// The sign flips because the value is log(1 - p); dividing by exp(log q)
// converts a derivative of q into a derivative of log q.
// NOTE(review): presumably grad_reg_lower_inc_gamma returns dp/dalpha --
// confirm against its declaration.
92+ out.second = -grad_reg_lower_inc_gamma (alpha, beta_y) / exp (out.first );
93+ }
94+ }
95+ return std::optional{out};
96+ }
97+ } // namespace internal
2398
2499template <typename T_y, typename T_shape, typename T_inv_scale>
25100inline return_type_t <T_y, T_shape, T_inv_scale> gamma_lccdf (
26101 const T_y& y, const T_shape& alpha, const T_inv_scale& beta) {
27- using T_partials_return = partials_return_t <T_y, T_shape, T_inv_scale>;
28102 using std::exp;
29103 using std::log;
30- using std::pow ;
104+ using T_partials_return = partials_return_t <T_y, T_shape, T_inv_scale> ;
31105 using T_y_ref = ref_type_t <T_y>;
32106 using T_alpha_ref = ref_type_t <T_shape>;
33107 using T_beta_ref = ref_type_t <T_inv_scale>;
@@ -51,61 +125,70 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
51125 scalar_seq_view<T_y_ref> y_vec (y_ref);
52126 scalar_seq_view<T_alpha_ref> alpha_vec (alpha_ref);
53127 scalar_seq_view<T_beta_ref> beta_vec (beta_ref);
54- size_t N = max_size (y, alpha, beta);
55-
56- // Explicit return for extreme values
57- // The gradients are technically ill-defined, but treated as zero
58- for (size_t i = 0 ; i < stan::math::size (y); i++) {
59- if (y_vec.val (i) == 0 ) {
60- // LCCDF(0) = log(P(Y > 0)) = log(1) = 0
61- return ops_partials.build (0.0 );
62- }
63- }
128+ const size_t N = max_size (y, alpha, beta);
129+
130+ constexpr bool is_y_fvar = is_fvar_v<scalar_type_t <T_y>>;
131+ constexpr bool is_shape_fvar = is_fvar_v<scalar_type_t <T_shape>>;
132+ constexpr bool is_beta_fvar = is_fvar_v<scalar_type_t <T_inv_scale>>;
133+ constexpr bool any_fvar = is_y_fvar || is_shape_fvar || is_beta_fvar;
134+ constexpr bool partials_fvar = is_fvar_v<T_partials_return>;
64135
65136 for (size_t n = 0 ; n < N; n++) {
66137 // Explicit results for extreme values
67138 // The gradients are technically ill-defined, but treated as zero
68- if (y_vec.val (n) == INFTY) {
69- // LCCDF(∞) = log(P(Y > ∞)) = log(0) = -∞
139+ const T_partials_return y_val = y_vec.val (n);
140+ if (y_val == 0.0 ) {
141+ continue ;
142+ }
143+ if (y_val == INFTY) {
70144 return ops_partials.build (negative_infinity ());
71145 }
72146
73- const T_partials_return y_dbl = y_vec.val (n);
74- const T_partials_return alpha_dbl = alpha_vec.val (n);
75- const T_partials_return beta_dbl = beta_vec.val (n);
76- const T_partials_return beta_y_dbl = beta_dbl * y_dbl;
147+ const T_partials_return alpha_val = alpha_vec.val (n);
148+ const T_partials_return beta_val = beta_vec.val (n);
77149
78- // Qn = 1 - Pn
79- const T_partials_return Qn = gamma_q (alpha_dbl, beta_y_dbl);
80- const T_partials_return log_Qn = log (Qn);
150+ const T_partials_return beta_y = beta_val * y_val;
151+ if (beta_y == INFTY) {
152+ return ops_partials.build (negative_infinity ());
153+ }
154+ std::optional<std::pair<T_partials_return, T_partials_return>> result;
155+ if (beta_y > alpha_val + 1.0 ) {
156+ result = internal::eval_q_cf<any_fvar, partials_fvar, T_shape>(alpha_val,
157+ beta_y);
158+ } else {
159+ result
160+ = internal::eval_q_log1m<partials_fvar, T_shape>(alpha_val, beta_y);
161+ if (!result && beta_y > 0.0 ) {
162+ // Fallback to continued fraction if log1m fails
163+ result = internal::eval_q_cf<any_fvar, partials_fvar, T_shape>(
164+ alpha_val, beta_y);
165+ }
166+ }
167+ if (unlikely (!result)) {
168+ return ops_partials.build (negative_infinity ());
169+ }
81170
82- P += log_Qn ;
171+ P += result-> first ;
83172
84- if constexpr (is_any_autodiff_v<T_y, T_inv_scale>) {
85- const T_partials_return log_y_dbl = log (y_dbl);
86- const T_partials_return log_beta_dbl = log (beta_dbl);
87- const T_partials_return log_pdf
88- = alpha_dbl * log_beta_dbl - lgamma (alpha_dbl)
89- + (alpha_dbl - 1.0 ) * log_y_dbl - beta_y_dbl;
90- const T_partials_return common_term = exp (log_pdf - log_Qn);
173+ if constexpr (is_autodiff_v<T_y> || is_autodiff_v<T_inv_scale>) {
174+ const T_partials_return log_y = log (y_val);
175+ const T_partials_return alpha_minus_one = fma (alpha_val, log_y, -log_y);
176+
177+ const T_partials_return log_pdf = alpha_val * log (beta_val)
178+ - lgamma (alpha_val) + alpha_minus_one
179+ - beta_y;
180+
181+ const T_partials_return hazard = exp (log_pdf - result->first ); // f/Q
91182
92183 if constexpr (is_autodiff_v<T_y>) {
93- // d/dy log(1-F(y)) = -f(y)/(1-F(y))
94- partials<0 >(ops_partials)[n] -= common_term;
184+ partials<0 >(ops_partials)[n] -= hazard;
95185 }
96186 if constexpr (is_autodiff_v<T_inv_scale>) {
97- // d/dbeta log(1-F(y)) = -y*f(y)/(beta*(1-F(y)))
98- partials<2 >(ops_partials)[n] -= y_dbl / beta_dbl * common_term;
187+ partials<2 >(ops_partials)[n] -= (y_val / beta_val) * hazard;
99188 }
100189 }
101-
102190 if constexpr (is_autodiff_v<T_shape>) {
103- const T_partials_return digamma_val = digamma (alpha_dbl);
104- const T_partials_return gamma_val = tgamma (alpha_dbl);
105- // d/dalpha log(1-F(y)) = grad_upper_inc_gamma / (1-F(y))
106- partials<1 >(ops_partials)[n]
107- += grad_reg_inc_gamma (alpha_dbl, beta_y_dbl, gamma_val, digamma_val)
108- / Qn;
191+ partials<1 >(ops_partials)[n] += result->second ;
109192 }
110193 }
111194 return ops_partials.build (P);
0 commit comments