Skip to content

Commit 9553f64

Browse files
committed
Merge branch 'feature/parameter-pack-odes' of https://github.com/stan-dev/math into feature/parameter-pack-odes
2 parents 8bf8b87 + e3e2b9e commit 9553f64

13 files changed

Lines changed: 313 additions & 188 deletions

stan/math/prim/functor/coupled_ode_system.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,8 @@ struct coupled_ode_system
129129
: public coupled_ode_system_impl<
130130
std::is_arithmetic<return_type_t<T_initial, Args...>>::value, F,
131131
T_initial, Args...> {
132-
coupled_ode_system(const F& f, const Eigen::Matrix<T_initial, Eigen::Dynamic, 1>& y0,
132+
coupled_ode_system(const F& f,
133+
const Eigen::Matrix<T_initial, Eigen::Dynamic, 1>& y0,
133134
std::ostream* msgs, const Args&... args)
134135
: coupled_ode_system_impl<
135136
std::is_arithmetic<return_type_t<T_initial, Args...>>::value, F,

stan/math/prim/functor/integrate_ode_rk45.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@ integrate_ode_rk45(const F& f, const std::vector<T1>& y0, const T_t0& t0,
2424
double relative_tolerance = 1e-6,
2525
double absolute_tolerance = 1e-6, int max_num_steps = 1e6) {
2626
internal::integrate_ode_std_vector_interface_adapter<F> f_adapted(f);
27-
auto y = ode_rk45_tol(f_adapted, to_vector(y0), t0, ts,
28-
relative_tolerance, absolute_tolerance,
29-
max_num_steps, msgs, theta, x, x_int);
27+
auto y
28+
= ode_rk45_tol(f_adapted, to_vector(y0), t0, ts, relative_tolerance,
29+
absolute_tolerance, max_num_steps, msgs, theta, x, x_int);
3030

3131
std::vector<std::vector<return_type_t<T1, T_param, T_t0, T_ts>>> y_converted;
32-
for(size_t i = 0; i < y.size(); ++i)
32+
for (size_t i = 0; i < y.size(); ++i)
3333
y_converted.push_back(to_array_1d(y[i]));
3434

3535
return y_converted;

stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,18 @@ template <typename F>
2323
struct integrate_ode_std_vector_interface_adapter {
2424
const F f_;
2525

26-
integrate_ode_std_vector_interface_adapter(const F& f) : f_(f) {
27-
}
28-
29-
template<typename T0, typename T1, typename T2>
30-
auto operator()(const T0& t, const Eigen::Matrix<T1, Eigen::Dynamic, 1>& y, std::ostream* msgs,
31-
const std::vector<T2>& theta, const std::vector<double>& x,
32-
const std::vector<int>& x_int) const {
26+
integrate_ode_std_vector_interface_adapter(const F& f) : f_(f) {}
27+
28+
template <typename T0, typename T1, typename T2>
29+
auto operator()(const T0& t, const Eigen::Matrix<T1, Eigen::Dynamic, 1>& y,
30+
std::ostream* msgs, const std::vector<T2>& theta,
31+
const std::vector<double>& x,
32+
const std::vector<int>& x_int) const {
3333
return to_vector(f_(t, to_array_1d(y), msgs, theta, x, x_int));
3434
}
3535
};
3636

37-
}
37+
} // namespace internal
3838

3939
} // namespace math
4040
} // namespace stan

stan/math/prim/functor/ode_rk45.hpp

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,11 @@ namespace math {
5050
* @return Solution to ODE at times \p ts
5151
*/
5252
template <typename F, typename T_initial, typename T_t0, typename T_ts,
53-
typename... Args>
54-
std::vector<Eigen::Matrix<stan::return_type_t<T_initial, Args...>, Eigen::Dynamic, 1>>
55-
ode_rk45(const F& f, const Eigen::Matrix<T_initial, Eigen::Dynamic, 1>& y0, double t0,
56-
const std::vector<double>& ts, std::ostream* msgs,
53+
typename... Args>
54+
std::vector<
55+
Eigen::Matrix<stan::return_type_t<T_initial, Args...>, Eigen::Dynamic, 1>>
56+
ode_rk45(const F& f, const Eigen::Matrix<T_initial, Eigen::Dynamic, 1>& y0,
57+
double t0, const std::vector<double>& ts, std::ostream* msgs,
5758
const Args&... args) {
5859
double relative_tolerance = 1e-10;
5960
double absolute_tolerance = 1e-10;
@@ -101,10 +102,11 @@ ode_rk45(const F& f, const Eigen::Matrix<T_initial, Eigen::Dynamic, 1>& y0, doub
101102
* @return Solution to ODE at times \p ts
102103
*/
103104
template <typename F, typename T_initial, typename T_t0, typename T_ts,
104-
typename... Args>
105-
std::vector<Eigen::Matrix<stan::return_type_t<T_initial, T_t0, T_ts, Args...>, Eigen::Dynamic, 1>>
106-
ode_rk45_tol(const F& f, const Eigen::Matrix<T_initial, Eigen::Dynamic, 1>& y0_arg,
107-
T_t0 t0,
105+
typename... Args>
106+
std::vector<Eigen::Matrix<stan::return_type_t<T_initial, T_t0, T_ts, Args...>,
107+
Eigen::Dynamic, 1>>
108+
ode_rk45_tol(const F& f,
109+
const Eigen::Matrix<T_initial, Eigen::Dynamic, 1>& y0_arg, T_t0 t0,
108110
const std::vector<T_ts>& ts, double relative_tolerance,
109111
double absolute_tolerance, long int max_num_steps,
110112
std::ostream* msgs, const Args&... args) {
@@ -115,10 +117,9 @@ ode_rk45_tol(const F& f, const Eigen::Matrix<T_initial, Eigen::Dynamic, 1>& y0_a
115117
using boost::numeric::odeint::vector_space_algebra;
116118

117119
using T_initial_or_t0 = return_type_t<T_initial, T_t0>;
118-
119-
Eigen::Matrix<T_initial_or_t0, Eigen::Dynamic, 1> y0 = y0_arg.unaryExpr([](const T_initial& val) {
120-
return T_initial_or_t0(val);
121-
});
120+
121+
Eigen::Matrix<T_initial_or_t0, Eigen::Dynamic, 1> y0 = y0_arg.unaryExpr(
122+
[](const T_initial& val) { return T_initial_or_t0(val); });
122123
const std::vector<double> ts_dbl = value_of(ts);
123124

124125
check_finite("integrate_ode_rk45", "initial state", y0);
@@ -151,7 +152,7 @@ ode_rk45_tol(const F& f, const Eigen::Matrix<T_initial, Eigen::Dynamic, 1>& y0_a
151152
using return_t = return_type_t<T_initial, T_t0, T_ts, Args...>;
152153
// creates basic or coupled system by template specializations
153154
coupled_ode_system<F, T_initial_or_t0, Args...> coupled_system(f, y0, msgs,
154-
args...);
155+
args...);
155156

156157
// first time in the vector must be time of initial state
157158
std::vector<double> ts_vec(ts.size() + 1);
@@ -167,7 +168,7 @@ ode_rk45_tol(const F& f, const Eigen::Matrix<T_initial, Eigen::Dynamic, 1>& y0_a
167168
// avoid recording of the initial state which is included by the
168169
// conventions of odeint in the output
169170
auto filtered_observer
170-
= [&](const Eigen::VectorXd& coupled_state, double t) -> void {
171+
= [&](const Eigen::VectorXd& coupled_state, double t) -> void {
171172
if (!observer_initial_recorded) {
172173
observer_initial_recorded = true;
173174
return;

stan/math/rev/functor/coupled_ode_system.hpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ struct coupled_ode_system_impl<false, F, T_initial, Args...> {
8585
* @param[in] x_int integer data
8686
* @param[in, out] msgs stream for messages
8787
*/
88-
coupled_ode_system_impl(const F& f, const Eigen::Matrix<T_initial, Eigen::Dynamic, 1>& y0,
88+
coupled_ode_system_impl(const F& f,
89+
const Eigen::Matrix<T_initial, Eigen::Dynamic, 1>& y0,
8990
std::ostream* msgs, const Args&... args)
9091
: f_(f),
9192
y0_(y0),
@@ -114,12 +115,12 @@ struct coupled_ode_system_impl<false, F, T_initial, Args...> {
114115
using std::vector;
115116

116117
dz_dt.resize(size());
117-
118+
118119
// Run nested autodiff in this scope
119120
nested_rev_autodiff nested;
120121

121122
Eigen::Matrix<var, Eigen::Dynamic, 1> y_vars(N_);
122-
for(size_t n = 0; n < N_; ++n)
123+
for (size_t n = 0; n < N_; ++n)
123124
y_vars(n) = z(n);
124125

125126
auto local_args_tuple = apply(
@@ -130,8 +131,8 @@ struct coupled_ode_system_impl<false, F, T_initial, Args...> {
130131
args_tuple_);
131132

132133
Eigen::Matrix<var, Eigen::Dynamic, 1> f_y_t_vars
133-
= apply([&](auto&&... args) { return f_(t, y_vars, msgs_, args...); },
134-
local_args_tuple);
134+
= apply([&](auto&&... args) { return f_(t, y_vars, msgs_, args...); },
135+
local_args_tuple);
135136

136137
check_size_match("coupled_ode_system", "dy_dt", f_y_t_vars.size(), "states",
137138
N_);

stan/math/rev/functor/cvodes_integrator.hpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,8 @@ class cvodes_integrator {
102102
const Eigen::VectorXd y_vec = Eigen::Map<const Eigen::VectorXd>(y, N_);
103103

104104
Eigen::VectorXd dy_dt_vec
105-
= apply([&](auto&&... args) { return f_(t, y_vec, msgs_, args...); },
106-
value_of_args_tuple_);
105+
= apply([&](auto&&... args) { return f_(t, y_vec, msgs_, args...); },
106+
value_of_args_tuple_);
107107

108108
check_size_match("cvodes_integrator::rhs", "dy_dt", dy_dt_vec.size(),
109109
"states", N_);
@@ -121,7 +121,7 @@ class cvodes_integrator {
121121

122122
auto f_wrapped = [&](const Eigen::Matrix<var, Eigen::Dynamic, 1>& y) {
123123
return apply([&](auto&&... args) { return f_(t, y, msgs_, args...); },
124-
value_of_args_tuple_);
124+
value_of_args_tuple_);
125125
};
126126

127127
jacobian(f_wrapped, Eigen::Map<const Eigen::VectorXd>(y, N_), fy, Jfy);
@@ -191,7 +191,7 @@ class cvodes_integrator {
191191
relative_tolerance_(relative_tolerance),
192192
absolute_tolerance_(absolute_tolerance),
193193
max_num_steps_(max_num_steps),
194-
y0_vars_(count_vars(y0_)),
194+
y0_vars_(count_vars(y0_)),
195195
args_vars_(count_vars(args...)),
196196
coupled_ode_(f, y0_, msgs, args...),
197197
coupled_state_(coupled_ode_.initial_state()) {
@@ -249,7 +249,8 @@ class cvodes_integrator {
249249
* @return std::vector of Eigen::Matrix of the states of the ODE, one for each
250250
* solution time (excluding the initial state)
251251
*/
252-
std::vector<Eigen::Matrix<T_Return, Eigen::Dynamic, 1>> integrate() { // NOLINT(runtime/int)
252+
std::vector<Eigen::Matrix<T_Return, Eigen::Dynamic, 1>>
253+
integrate() { // NOLINT(runtime/int)
253254
std::vector<Eigen::Matrix<T_Return, Eigen::Dynamic, 1>> y;
254255

255256
const double t0_dbl = value_of(t0_);
@@ -312,9 +313,8 @@ class cvodes_integrator {
312313

313314
y.emplace_back(apply(
314315
[&](auto&&... args) {
315-
return ode_store_sensitivities(f_, coupled_state_, y0_,
316-
t0_, ts_[n], msgs_,
317-
args...);
316+
return ode_store_sensitivities(f_, coupled_state_, y0_, t0_,
317+
ts_[n], msgs_, args...);
318318
},
319319
args_tuple_));
320320

stan/math/rev/functor/integrate_ode_bdf.hpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,13 @@ integrate_ode_bdf(const F& f, const std::vector<T_initial>& y0, const T_t0& t0,
2525
double absolute_tolerance = 1e-10,
2626
long int max_num_steps = 1e8) { // NOLINT(runtime/int)
2727
internal::integrate_ode_std_vector_interface_adapter<F> f_adapted(f);
28-
auto y = ode_bdf_tol(f_adapted, to_vector(y0), t0, ts,
29-
relative_tolerance, absolute_tolerance,
30-
max_num_steps, msgs, theta, x, x_int);
28+
auto y
29+
= ode_bdf_tol(f_adapted, to_vector(y0), t0, ts, relative_tolerance,
30+
absolute_tolerance, max_num_steps, msgs, theta, x, x_int);
3131

32-
std::vector<std::vector<return_type_t<T_initial, T_param, T_t0, T_ts>>> y_converted;
33-
for(size_t i = 0; i < y.size(); ++i)
32+
std::vector<std::vector<return_type_t<T_initial, T_param, T_t0, T_ts>>>
33+
y_converted;
34+
for (size_t i = 0; i < y.size(); ++i)
3435
y_converted.push_back(to_array_1d(y[i]));
3536

3637
return y_converted;

stan/math/rev/functor/ode_bdf.hpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ namespace math {
4545
*/
4646
template <typename F, typename T_initial, typename T_t0, typename T_ts,
4747
typename... T_Args>
48-
std::vector<Eigen::Matrix<
49-
stan::return_type_t<T_initial, T_t0, T_ts, T_Args...>, Eigen::Dynamic, 1>>
50-
ode_bdf(const F& f, const Eigen::Matrix<T_initial, Eigen::Dynamic, 1>& y0, const T_t0& t0,
51-
const std::vector<T_ts>& ts, std::ostream* msgs,
48+
std::vector<Eigen::Matrix<stan::return_type_t<T_initial, T_t0, T_ts, T_Args...>,
49+
Eigen::Dynamic, 1>>
50+
ode_bdf(const F& f, const Eigen::Matrix<T_initial, Eigen::Dynamic, 1>& y0,
51+
const T_t0& t0, const std::vector<T_ts>& ts, std::ostream* msgs,
5252
const T_Args&... args) {
5353
double relative_tolerance = 1e-10;
5454
double absolute_tolerance = 1e-10;
@@ -93,15 +93,15 @@ ode_bdf(const F& f, const Eigen::Matrix<T_initial, Eigen::Dynamic, 1>& y0, const
9393
*/
9494
template <typename F, typename T_initial, typename T_t0, typename T_ts,
9595
typename... T_Args>
96-
std::vector<Eigen::Matrix<
97-
stan::return_type_t<T_initial, T_t0, T_ts, T_Args...>, Eigen::Dynamic, 1>>
98-
ode_bdf_tol(const F& f, const Eigen::Matrix<T_initial, Eigen::Dynamic, 1>& y0, const T_t0& t0,
99-
const std::vector<T_ts>& ts, double relative_tolerance,
100-
double absolute_tolerance, long int max_num_steps,
101-
std::ostream* msgs, const T_Args&... args) {
102-
stan::math::cvodes_integrator<CV_BDF, F, T_initial, T_t0, T_ts, T_Args...> integrator(
103-
f, y0, t0, ts, relative_tolerance, absolute_tolerance,
104-
max_num_steps, msgs, args...);
96+
std::vector<Eigen::Matrix<stan::return_type_t<T_initial, T_t0, T_ts, T_Args...>,
97+
Eigen::Dynamic, 1>>
98+
ode_bdf_tol(const F& f, const Eigen::Matrix<T_initial, Eigen::Dynamic, 1>& y0,
99+
const T_t0& t0, const std::vector<T_ts>& ts,
100+
double relative_tolerance, double absolute_tolerance,
101+
long int max_num_steps, std::ostream* msgs, const T_Args&... args) {
102+
stan::math::cvodes_integrator<CV_BDF, F, T_initial, T_t0, T_ts, T_Args...>
103+
integrator(f, y0, t0, ts, relative_tolerance, absolute_tolerance,
104+
max_num_steps, msgs, args...);
105105

106106
return integrator.integrate();
107107
}

stan/math/rev/functor/ode_store_sensitivities.hpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ Eigen::Matrix<var, Eigen::Dynamic, 1> ode_store_sensitivities(const F& f,
4646
Eigen::Matrix<var, Eigen::Dynamic, 1> yt(N);
4747

4848
Eigen::VectorXd y = coupled_state.head(N);
49-
49+
5050
Eigen::VectorXd f_y_t;
5151
if (is_var<T_t>::value)
5252
f_y_t = f(value_of(t), y, msgs, value_of(args)...);
@@ -83,25 +83,24 @@ Eigen::Matrix<var, Eigen::Dynamic, 1> ode_store_sensitivities(const F& f,
8383
}
8484

8585
varis_ptr = save_varis(varis_ptr, t0);
86-
if(t0_vars > 0) {
86+
if (t0_vars > 0) {
8787
double dyt_dt0 = 0.0;
8888
for (std::size_t k = 0; k < y0_vars; ++k) {
89-
// dy[j]_dtheta[k]
90-
// theta[k].vi_
91-
dyt_dt0 += -f_y0_t0[k] * coupled_state(N + y0_vars * k + j);
89+
// dy[j]_dtheta[k]
90+
// theta[k].vi_
91+
dyt_dt0 += -f_y0_t0[k] * coupled_state(N + y0_vars * k + j);
9292
}
9393
*partials_ptr = dyt_dt0;
9494
partials_ptr++;
9595
}
9696

9797
varis_ptr = save_varis(varis_ptr, t);
98-
if(t_vars > 0) {
98+
if (t_vars > 0) {
9999
*partials_ptr = f_y_t[j];
100100
partials_ptr++;
101101
}
102102

103-
yt(j) = new precomputed_gradients_vari(y(j), total_vars,
104-
varis, partials);
103+
yt(j) = new precomputed_gradients_vari(y(j), total_vars, varis, partials);
105104
}
106105

107106
return yt;

test/unit/math/rev/functor/integrate_ode_bdf_rev_test.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ TEST(StanAgradRevOde_integrate_ode_bdf, t0_as_param_AD) {
270270

271271
std::vector<double> theta{0.15};
272272
std::vector<double> y0{1.0, 0.0};
273-
std::vector<double> ts = { 5.0, 10.0 };
273+
std::vector<double> ts = {5.0, 10.0};
274274

275275
std::vector<double> x;
276276
std::vector<int> x_int;
@@ -293,12 +293,16 @@ TEST(StanAgradRevOde_integrate_ode_bdf, t0_as_param_AD) {
293293
EXPECT_FLOAT_EQ(t0v.adj(), -0.38494826636037426937);
294294
stan::math::set_zero_all_adjoints();
295295
};
296-
res = integrate_ode_bdf(ode, y0, t0v, ts, theta, x, x_int, nullptr, 1e-10, 1e-10, 1e6);
296+
res = integrate_ode_bdf(ode, y0, t0v, ts, theta, x, x_int, nullptr, 1e-10,
297+
1e-10, 1e6);
297298
test_ad();
298-
res = integrate_ode_bdf(ode, y0v, t0v, ts, theta, x, x_int, nullptr, 1e-10, 1e-10, 1e6);
299+
res = integrate_ode_bdf(ode, y0v, t0v, ts, theta, x, x_int, nullptr, 1e-10,
300+
1e-10, 1e6);
299301
test_ad();
300-
res = integrate_ode_bdf(ode, y0, t0v, ts, thetav, x, x_int, nullptr, 1e-10, 1e-10, 1e6);
302+
res = integrate_ode_bdf(ode, y0, t0v, ts, thetav, x, x_int, nullptr, 1e-10,
303+
1e-10, 1e6);
301304
test_ad();
302-
res = integrate_ode_bdf(ode, y0v, t0v, ts, thetav, x, x_int, nullptr, 1e-10, 1e-10, 1e6);
305+
res = integrate_ode_bdf(ode, y0v, t0v, ts, thetav, x, x_int, nullptr, 1e-10,
306+
1e-10, 1e6);
303307
test_ad();
304308
}

0 commit comments

Comments
 (0)