Skip to content

Commit ac8c21a

Browse files
authored
Merge pull request #3275 from stan-dev/fix-gamma-lccdf-v3
Fix gamma lccdf
2 parents 524939b + 62393a1 commit ac8c21a

6 files changed

Lines changed: 428 additions & 45 deletions

File tree

stan/math/fwd/meta/is_fvar.hpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,5 @@ struct is_fvar<T,
2121
std::enable_if_t<internal::is_fvar_impl<std::decay_t<T>>::value>>
2222
: std::true_type {};
2323

24-
template <typename T>
25-
inline constexpr bool is_fvar_v = is_fvar<T>::value;
26-
2724
} // namespace stan
2825
#endif
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
#ifndef STAN_MATH_PRIM_FUN_LOG_GAMMA_Q_DGAMMA_HPP
2+
#define STAN_MATH_PRIM_FUN_LOG_GAMMA_Q_DGAMMA_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/fun/constants.hpp>
6+
#include <stan/math/prim/fun/digamma.hpp>
7+
#include <stan/math/prim/fun/exp.hpp>
8+
#include <stan/math/prim/fun/fabs.hpp>
9+
#include <stan/math/prim/fun/gamma_p.hpp>
10+
#include <stan/math/prim/fun/gamma_q.hpp>
11+
#include <stan/math/prim/fun/grad_reg_inc_gamma.hpp>
12+
#include <stan/math/prim/fun/inv.hpp>
13+
#include <stan/math/prim/fun/lgamma.hpp>
14+
#include <stan/math/prim/fun/log.hpp>
15+
#include <stan/math/prim/fun/log1m.hpp>
16+
#include <stan/math/prim/fun/tgamma.hpp>
17+
#include <stan/math/prim/fun/value_of.hpp>
18+
#include <stan/math/prim/fun/value_of_rec.hpp>
19+
#include <cmath>
20+
21+
namespace stan {
22+
namespace math {
23+
24+
namespace internal {
25+
26+
constexpr double LOG_Q_GAMMA_CF_PRECISION = 1.49012e-12;
27+
28+
/**
29+
* Compute log(Q(a,z)) using continued fraction expansion for upper incomplete
30+
* gamma function.
31+
*
32+
* @tparam T_a Type of shape parameter a (double or fvar types)
33+
* @tparam T_z Type of value parameter z (double or fvar types)
34+
* @param a Shape parameter
35+
* @param z Value at which to evaluate
36+
* @param precision Convergence threshold, default of sqrt(machine_epsilon)
37+
* @param max_steps Maximum number of continued fraction iterations
38+
* @return log(Q(a,z)) with the return type of T_a and T_z
39+
*/
40+
template <typename T_a, typename T_z>
41+
inline return_type_t<T_a, T_z> log_q_gamma_cf(const T_a& a, const T_z& z,
42+
double precision
43+
= LOG_Q_GAMMA_CF_PRECISION,
44+
int max_steps = 250) {
45+
using T_return = return_type_t<T_a, T_z>;
46+
const T_return log_prefactor = a * log(z) - z - lgamma(a);
47+
48+
T_return b_init = z + 1.0 - a;
49+
T_return C = (fabs(value_of_rec(b_init)) >= EPSILON)
50+
? b_init
51+
: std::decay_t<decltype(b_init)>(EPSILON);
52+
T_return D = 0.0;
53+
T_return f = C;
54+
for (int i = 1; i <= max_steps; ++i) {
55+
T_a an = -i * (i - a);
56+
const T_return b = b_init + 2.0 * i;
57+
D = b + an * D;
58+
D = (fabs(value_of_rec(D)) >= EPSILON) ? D
59+
: std::decay_t<decltype(D)>(EPSILON);
60+
C = b + an / C;
61+
C = (fabs(value_of_rec(C)) >= EPSILON) ? C
62+
: std::decay_t<decltype(C)>(EPSILON);
63+
D = inv(D);
64+
const T_return delta = C * D;
65+
f *= delta;
66+
const double delta_m1 = fabs(value_of_rec(delta) - 1.0);
67+
if (delta_m1 < precision) {
68+
break;
69+
}
70+
}
71+
return log_prefactor - log(f);
72+
}
73+
74+
} // namespace internal
75+
76+
/**
77+
* Compute log(Q(a,z)) and its gradient with respect to a using continued
78+
* fraction expansion, where Q(a,z) = Gamma(a,z) / Gamma(a) is the regularized
79+
* upper incomplete gamma function.
80+
*
81+
* This uses a continued fraction representation for numerical stability when
82+
* computing the upper incomplete gamma function in log space, along with
83+
* analytical gradient computation.
84+
*
85+
* @tparam T_a type of the shape parameter
86+
* @tparam T_z type of the value parameter
87+
* @param a shape parameter (must be positive)
88+
* @param z value parameter (must be non-negative)
89+
* @param precision convergence threshold, default of sqrt(machine_epsilon)
90+
* @param max_steps maximum iterations for continued fraction
91+
* @return structure containing log(Q(a,z)) and d/da log(Q(a,z))
92+
*/
93+
template <typename T_a, typename T_z>
94+
inline std::pair<return_type_t<T_a, T_z>, return_type_t<T_a, T_z>>
95+
log_gamma_q_dgamma(const T_a& a, const T_z& z,
96+
double precision = internal::LOG_Q_GAMMA_CF_PRECISION,
97+
int max_steps = 250) {
98+
using T_return = return_type_t<T_a, T_z>;
99+
const double a_val = value_of(a);
100+
const double z_val = value_of(z);
101+
// For z > a + 1, use continued fraction for better numerical stability
102+
if (z_val > a_val + 1.0) {
103+
std::pair<T_return, T_return> result{
104+
internal::log_q_gamma_cf(a_val, z_val, precision, max_steps), 0.0};
105+
// For gradient, use: d/da log(Q) = (1/Q) * dQ/da
106+
// grad_reg_inc_gamma computes dQ/da
107+
const T_return Q_val = exp(result.first);
108+
const double dQ_da
109+
= grad_reg_inc_gamma(a_val, z_val, tgamma(a_val), digamma(a_val));
110+
result.second = dQ_da / Q_val;
111+
return result;
112+
} else {
113+
// For z <= a + 1, use log1m(P(a,z)) for better numerical accuracy
114+
const double P_val = gamma_p(a_val, z_val);
115+
std::pair<T_return, T_return> result{log1m(P_val), 0.0};
116+
// Gradient: d/da log(Q) = (1/Q) * dQ/da
117+
// grad_reg_inc_gamma computes dQ/da
118+
const T_return Q_val = exp(result.first);
119+
if (Q_val > 0) {
120+
const double dQ_da
121+
= grad_reg_inc_gamma(a_val, z_val, tgamma(a_val), digamma(a_val));
122+
result.second = dQ_da / Q_val;
123+
} else {
124+
// Fallback if Q rounds to zero - use asymptotic approximation
125+
result.second = log(z_val) - digamma(a_val);
126+
}
127+
return result;
128+
}
129+
}
130+
131+
} // namespace math
132+
} // namespace stan
133+
134+
#endif

stan/math/prim/meta/is_fvar.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ namespace stan {
1414
template <typename T, typename = void>
1515
struct is_fvar : std::false_type {};
1616

17+
template <typename T>
18+
inline constexpr bool is_fvar_v = is_fvar<T>::value;
19+
1720
/** \ingroup type_trait
1821
* Specialization for pointers returns the underlying value the pointer is
1922
* pointing to.

stan/math/prim/prob/gamma_lccdf.hpp

Lines changed: 125 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,102 @@
66
#include <stan/math/prim/fun/constants.hpp>
77
#include <stan/math/prim/fun/digamma.hpp>
88
#include <stan/math/prim/fun/exp.hpp>
9-
#include <stan/math/prim/fun/gamma_q.hpp>
9+
#include <stan/math/prim/fun/fma.hpp>
10+
#include <stan/math/prim/fun/gamma_p.hpp>
1011
#include <stan/math/prim/fun/grad_reg_inc_gamma.hpp>
12+
#include <stan/math/prim/fun/grad_reg_lower_inc_gamma.hpp>
13+
#include <stan/math/prim/fun/lgamma.hpp>
1114
#include <stan/math/prim/fun/log.hpp>
15+
#include <stan/math/prim/fun/log1m.hpp>
1216
#include <stan/math/prim/fun/max_size.hpp>
1317
#include <stan/math/prim/fun/scalar_seq_view.hpp>
1418
#include <stan/math/prim/fun/size.hpp>
1519
#include <stan/math/prim/fun/size_zero.hpp>
1620
#include <stan/math/prim/fun/tgamma.hpp>
17-
#include <stan/math/prim/fun/value_of.hpp>
21+
#include <stan/math/prim/fun/value_of_rec.hpp>
22+
#include <stan/math/prim/fun/log_gamma_q_dgamma.hpp>
1823
#include <stan/math/prim/functor/partials_propagator.hpp>
1924
#include <cmath>
25+
#include <optional>
2026

2127
namespace stan {
2228
namespace math {
29+
namespace internal {
30+
31+
/**
32+
* Computes log q and d(log q) / d(alpha) using continued fraction.
33+
*/
34+
template <bool any_fvar, bool partials_fvar, typename T_shape, typename T1,
35+
typename T2>
36+
inline std::optional<std::pair<return_type_t<T1, T2>, return_type_t<T1, T2>>>
37+
eval_q_cf(const T1& alpha, const T2& beta_y) {
38+
using scalar_t = return_type_t<T1, T2>;
39+
using ret_t = std::pair<scalar_t, scalar_t>;
40+
if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
41+
std::pair<double, double> log_q_result
42+
= log_gamma_q_dgamma(value_of(alpha), value_of(beta_y));
43+
if (likely(std::isfinite(log_q_result.first))) {
44+
return std::optional{log_q_result};
45+
} else {
46+
return std::optional<ret_t>{std::nullopt};
47+
}
48+
} else {
49+
ret_t out{internal::log_q_gamma_cf(alpha, beta_y), 0.0};
50+
if (unlikely(!std::isfinite(value_of_rec(out.first)))) {
51+
return std::optional<ret_t>{std::nullopt};
52+
}
53+
if constexpr (is_autodiff_v<T_shape>) {
54+
if constexpr (!partials_fvar) {
55+
out.second
56+
= grad_reg_inc_gamma(alpha, beta_y, tgamma(alpha), digamma(alpha))
57+
/ exp(out.first);
58+
} else {
59+
auto alpha_unit = alpha;
60+
alpha_unit.d_ = 1;
61+
auto beta_y_unit = beta_y;
62+
beta_y_unit.d_ = 0;
63+
auto log_Q_fvar = internal::log_q_gamma_cf(alpha_unit, beta_y_unit);
64+
out.second = log_Q_fvar.d_;
65+
}
66+
}
67+
return std::optional{out};
68+
}
69+
}
70+
71+
/**
72+
* Computes log q and d(log q) / d(alpha) using log1m.
73+
*/
74+
template <bool partials_fvar, typename T_shape, typename T1, typename T2>
75+
inline std::optional<std::pair<return_type_t<T1, T2>, return_type_t<T1, T2>>>
76+
eval_q_log1m(const T1& alpha, const T2& beta_y) {
77+
using scalar_t = return_type_t<T1, T2>;
78+
using ret_t = std::pair<scalar_t, scalar_t>;
79+
ret_t out{log1m(gamma_p(alpha, beta_y)), 0.0};
80+
if (unlikely(!std::isfinite(value_of_rec(out.first)))) {
81+
return std::optional<ret_t>{std::nullopt};
82+
}
83+
if constexpr (is_autodiff_v<T_shape>) {
84+
if constexpr (partials_fvar) {
85+
auto alpha_unit = alpha;
86+
alpha_unit.d_ = 1;
87+
auto beta_unit = beta_y;
88+
beta_unit.d_ = 0;
89+
auto log_Q_fvar = log1m(gamma_p(alpha_unit, beta_unit));
90+
out.second = log_Q_fvar.d_;
91+
} else {
92+
out.second = -grad_reg_lower_inc_gamma(alpha, beta_y) / exp(out.first);
93+
}
94+
}
95+
return std::optional{out};
96+
}
97+
} // namespace internal
2398

2499
template <typename T_y, typename T_shape, typename T_inv_scale>
25100
inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
26101
const T_y& y, const T_shape& alpha, const T_inv_scale& beta) {
27-
using T_partials_return = partials_return_t<T_y, T_shape, T_inv_scale>;
28102
using std::exp;
29103
using std::log;
30-
using std::pow;
104+
using T_partials_return = partials_return_t<T_y, T_shape, T_inv_scale>;
31105
using T_y_ref = ref_type_t<T_y>;
32106
using T_alpha_ref = ref_type_t<T_shape>;
33107
using T_beta_ref = ref_type_t<T_inv_scale>;
@@ -51,61 +125,70 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
51125
scalar_seq_view<T_y_ref> y_vec(y_ref);
52126
scalar_seq_view<T_alpha_ref> alpha_vec(alpha_ref);
53127
scalar_seq_view<T_beta_ref> beta_vec(beta_ref);
54-
size_t N = max_size(y, alpha, beta);
55-
56-
// Explicit return for extreme values
57-
// The gradients are technically ill-defined, but treated as zero
58-
for (size_t i = 0; i < stan::math::size(y); i++) {
59-
if (y_vec.val(i) == 0) {
60-
// LCCDF(0) = log(P(Y > 0)) = log(1) = 0
61-
return ops_partials.build(0.0);
62-
}
63-
}
128+
const size_t N = max_size(y, alpha, beta);
129+
130+
constexpr bool is_y_fvar = is_fvar_v<scalar_type_t<T_y>>;
131+
constexpr bool is_shape_fvar = is_fvar_v<scalar_type_t<T_shape>>;
132+
constexpr bool is_beta_fvar = is_fvar_v<scalar_type_t<T_inv_scale>>;
133+
constexpr bool any_fvar = is_y_fvar || is_shape_fvar || is_beta_fvar;
134+
constexpr bool partials_fvar = is_fvar_v<T_partials_return>;
64135

65136
for (size_t n = 0; n < N; n++) {
66137
// Explicit results for extreme values
67138
// The gradients are technically ill-defined, but treated as zero
68-
if (y_vec.val(n) == INFTY) {
69-
// LCCDF(∞) = log(P(Y > ∞)) = log(0) = -∞
139+
const T_partials_return y_val = y_vec.val(n);
140+
if (y_val == 0.0) {
141+
continue;
142+
}
143+
if (y_val == INFTY) {
70144
return ops_partials.build(negative_infinity());
71145
}
72146

73-
const T_partials_return y_dbl = y_vec.val(n);
74-
const T_partials_return alpha_dbl = alpha_vec.val(n);
75-
const T_partials_return beta_dbl = beta_vec.val(n);
76-
const T_partials_return beta_y_dbl = beta_dbl * y_dbl;
147+
const T_partials_return alpha_val = alpha_vec.val(n);
148+
const T_partials_return beta_val = beta_vec.val(n);
77149

78-
// Qn = 1 - Pn
79-
const T_partials_return Qn = gamma_q(alpha_dbl, beta_y_dbl);
80-
const T_partials_return log_Qn = log(Qn);
150+
const T_partials_return beta_y = beta_val * y_val;
151+
if (beta_y == INFTY) {
152+
return ops_partials.build(negative_infinity());
153+
}
154+
std::optional<std::pair<T_partials_return, T_partials_return>> result;
155+
if (beta_y > alpha_val + 1.0) {
156+
result = internal::eval_q_cf<any_fvar, partials_fvar, T_shape>(alpha_val,
157+
beta_y);
158+
} else {
159+
result
160+
= internal::eval_q_log1m<partials_fvar, T_shape>(alpha_val, beta_y);
161+
if (!result && beta_y > 0.0) {
162+
// Fallback to continued fraction if log1m fails
163+
result = internal::eval_q_cf<any_fvar, partials_fvar, T_shape>(
164+
alpha_val, beta_y);
165+
}
166+
}
167+
if (unlikely(!result)) {
168+
return ops_partials.build(negative_infinity());
169+
}
81170

82-
P += log_Qn;
171+
P += result->first;
83172

84-
if constexpr (is_any_autodiff_v<T_y, T_inv_scale>) {
85-
const T_partials_return log_y_dbl = log(y_dbl);
86-
const T_partials_return log_beta_dbl = log(beta_dbl);
87-
const T_partials_return log_pdf
88-
= alpha_dbl * log_beta_dbl - lgamma(alpha_dbl)
89-
+ (alpha_dbl - 1.0) * log_y_dbl - beta_y_dbl;
90-
const T_partials_return common_term = exp(log_pdf - log_Qn);
173+
if constexpr (is_autodiff_v<T_y> || is_autodiff_v<T_inv_scale>) {
174+
const T_partials_return log_y = log(y_val);
175+
const T_partials_return alpha_minus_one = fma(alpha_val, log_y, -log_y);
176+
177+
const T_partials_return log_pdf = alpha_val * log(beta_val)
178+
- lgamma(alpha_val) + alpha_minus_one
179+
- beta_y;
180+
181+
const T_partials_return hazard = exp(log_pdf - result->first); // f/Q
91182

92183
if constexpr (is_autodiff_v<T_y>) {
93-
// d/dy log(1-F(y)) = -f(y)/(1-F(y))
94-
partials<0>(ops_partials)[n] -= common_term;
184+
partials<0>(ops_partials)[n] -= hazard;
95185
}
96186
if constexpr (is_autodiff_v<T_inv_scale>) {
97-
// d/dbeta log(1-F(y)) = -y*f(y)/(beta*(1-F(y)))
98-
partials<2>(ops_partials)[n] -= y_dbl / beta_dbl * common_term;
187+
partials<2>(ops_partials)[n] -= (y_val / beta_val) * hazard;
99188
}
100189
}
101-
102190
if constexpr (is_autodiff_v<T_shape>) {
103-
const T_partials_return digamma_val = digamma(alpha_dbl);
104-
const T_partials_return gamma_val = tgamma(alpha_dbl);
105-
// d/dalpha log(1-F(y)) = grad_upper_inc_gamma / (1-F(y))
106-
partials<1>(ops_partials)[n]
107-
+= grad_reg_inc_gamma(alpha_dbl, beta_y_dbl, gamma_val, digamma_val)
108-
/ Qn;
191+
partials<1>(ops_partials)[n] += result->second;
109192
}
110193
}
111194
return ops_partials.build(P);

0 commit comments

Comments
 (0)