Skip to content

Commit ae09118

Browse files
authored
Merge pull request #1841 from andrjohns/read_autodiff
Eigen member function for reading values and derivatives from autodiff matrices
2 parents 89b43ef + d2a942b commit ae09118

8 files changed

Lines changed: 613 additions & 0 deletions

File tree

stan/math/fwd/fun.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
#include <stan/math/fwd/fun/proj.hpp>
9494
#include <stan/math/fwd/fun/quad_form.hpp>
9595
#include <stan/math/fwd/fun/quad_form_sym.hpp>
96+
#include <stan/math/fwd/fun/read_fvar.hpp>
9697
#include <stan/math/fwd/fun/rising_factorial.hpp>
9798
#include <stan/math/fwd/fun/round.hpp>
9899
#include <stan/math/fwd/fun/sin.hpp>

stan/math/fwd/fun/Eigen_NumTraits.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <stan/math/prim/fun/Eigen.hpp>
55
#include <stan/math/prim/core.hpp>
6+
#include <stan/math/fwd/fun/read_fvar.hpp>
67
#include <stan/math/fwd/core.hpp>
78
#include <stan/math/fwd/core/std_numeric_limits.hpp>
89
#include <limits>
@@ -163,5 +164,17 @@ struct ScalarBinaryOpTraits<std::complex<stan::math::fvar<T>>,
163164
using ReturnType = std::complex<stan::math::fvar<T>>;
164165
};
165166

167+
namespace internal {
168+
169+
/**
170+
* Enable linear access of inputs when using read_fvar.
171+
*/
172+
template <typename EigFvar, typename EigOut>
173+
struct functor_has_linear_access<
174+
stan::math::read_fvar_functor<EigFvar, EigOut>> {
175+
enum { ret = 1 };
176+
};
177+
178+
} // namespace internal
166179
} // namespace Eigen
167180
#endif

stan/math/fwd/fun/read_fvar.hpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#ifndef STAN_MATH_FWD_FUN_READ_FVAR_HPP
2+
#define STAN_MATH_FWD_FUN_READ_FVAR_HPP
3+
4+
#include <stan/math/fwd/meta.hpp>
5+
6+
namespace stan {
7+
namespace math {
8+
9+
/**
10+
* Functor for extracting the values and tangents from a matrix of fvar.
11+
* This functor is called using Eigen's NullaryExpr framework.
12+
*/
13+
template <typename EigFvar, typename EigOut>
14+
class read_fvar_functor {
15+
const EigFvar& var_mat;
16+
EigOut& val_mat;
17+
18+
public:
19+
read_fvar_functor(const EigFvar& arg1, EigOut& arg2)
20+
: var_mat(arg1), val_mat(arg2) {}
21+
22+
inline decltype(auto) operator()(Eigen::Index row, Eigen::Index col) const {
23+
val_mat.coeffRef(row, col) = var_mat.coeffRef(row, col).val_;
24+
return var_mat.coeffRef(row, col).d_;
25+
}
26+
27+
inline decltype(auto) operator()(Eigen::Index index) const {
28+
val_mat.coeffRef(index) = var_mat.coeffRef(index).val_;
29+
return var_mat.coeffRef(index).d_;
30+
}
31+
};
32+
33+
/**
34+
* Function applying the read_fvar_functor to extract the values
35+
* and tangets of a given fvar matrix into separate matrices.
36+
*
37+
* @tparam EigFvar type of the Eigen container of fvar.
38+
* @tparam EigOut type of the Eigen containers to copy to
39+
* @param[in] FvarMat Input Eigen container of fvar.
40+
* @param[in] ValMat Output Eigen container of values.
41+
* @param[in] DMat Output Eigen container of tangents.
42+
*/
43+
template <typename EigFvar, typename EigOut>
44+
inline void read_fvar(const EigFvar& FvarMat, EigOut& ValMat, EigOut& DMat) {
45+
DMat = EigOut::NullaryExpr(
46+
FvarMat.rows(), FvarMat.cols(),
47+
read_fvar_functor<const EigFvar, EigOut>(FvarMat, ValMat));
48+
}
49+
50+
} // namespace math
51+
} // namespace stan
52+
#endif

stan/math/rev/fun.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@
115115
#include <stan/math/rev/fun/proj.hpp>
116116
#include <stan/math/rev/fun/quad_form.hpp>
117117
#include <stan/math/rev/fun/quad_form_sym.hpp>
118+
#include <stan/math/rev/fun/read_var.hpp>
118119
#include <stan/math/rev/fun/rising_factorial.hpp>
119120
#include <stan/math/rev/fun/round.hpp>
120121
#include <stan/math/rev/fun/rows_dot_product.hpp>

stan/math/rev/fun/Eigen_NumTraits.hpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <stan/math/prim/fun/Eigen.hpp>
55
#include <stan/math/prim/core.hpp>
6+
#include <stan/math/rev/fun/read_var.hpp>
67
#include <stan/math/rev/meta.hpp>
78
#include <stan/math/rev/core.hpp>
89
#include <stan/math/rev/core/std_numeric_limits.hpp>
@@ -210,6 +211,40 @@ struct ScalarBinaryOpTraits<std::complex<stan::math::var>,
210211
};
211212

212213
namespace internal {
214+
215+
/**
216+
* Enable linear access of inputs when using read_vi_val_adj.
217+
*/
218+
template <typename EigVar, typename EigVari, typename EigDbl>
219+
struct functor_has_linear_access<
220+
stan::math::vi_val_adj_functor<EigVar, EigVari, EigDbl>> {
221+
enum { ret = 1 };
222+
};
223+
224+
/**
225+
* Enable linear access of inputs when using read_val_adj.
226+
*/
227+
template <typename EigVar, typename EigDbl>
228+
struct functor_has_linear_access<stan::math::val_adj_functor<EigVar, EigDbl>> {
229+
enum { ret = 1 };
230+
};
231+
232+
/**
233+
* Enable linear access of inputs when using read_vi_val.
234+
*/
235+
template <typename EigVar, typename EigVari>
236+
struct functor_has_linear_access<stan::math::vi_val_functor<EigVar, EigVari>> {
237+
enum { ret = 1 };
238+
};
239+
240+
/**
241+
* Enable linear access of inputs when using read_vi_adj.
242+
*/
243+
template <typename EigVar, typename EigVari>
244+
struct functor_has_linear_access<stan::math::vi_adj_functor<EigVar, EigVari>> {
245+
enum { ret = 1 };
246+
};
247+
213248
/**
214249
* Partial specialization of Eigen's remove_all struct to stop
215250
* Eigen removing pointer from vari* variables

stan/math/rev/fun/read_var.hpp

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
#ifndef STAN_MATH_REV_FUN_READ_VAR_HPP
2+
#define STAN_MATH_REV_FUN_READ_VAR_HPP
3+
4+
#include <stan/math/prim/meta/require_generics.hpp>
5+
#include <stan/math/rev/core/var.hpp>
6+
#include <stan/math/rev/core/vari.hpp>
7+
8+
namespace stan {
9+
namespace math {
10+
11+
/**
12+
* Functor for extracting the vari*, values, and adjoints from a matrix of var.
13+
* This functor is called using Eigen's NullaryExpr framework, which takes care
14+
* of the indexing. This removes the need to programmatically account for
15+
* whether the input is row- or column-major.
16+
*/
17+
template <typename EigRev, typename EigVari, typename EigDbl>
18+
class vi_val_adj_functor {
19+
const EigRev& var_mat;
20+
EigVari& vi_mat;
21+
EigDbl& val_mat;
22+
23+
public:
24+
vi_val_adj_functor(const EigRev& arg1, EigVari& arg2, EigDbl& arg3)
25+
: var_mat(arg1), vi_mat(arg2), val_mat(arg3) {}
26+
27+
inline decltype(auto) operator()(Eigen::Index row, Eigen::Index col) const {
28+
vi_mat.coeffRef(row, col) = var_mat.coeffRef(row, col).vi_;
29+
val_mat.coeffRef(row, col) = var_mat.coeffRef(row, col).vi_->val_;
30+
return var_mat.coeffRef(row, col).vi_->adj_;
31+
}
32+
33+
inline decltype(auto) operator()(Eigen::Index index) const {
34+
vi_mat.coeffRef(index) = var_mat.coeffRef(index).vi_;
35+
val_mat.coeffRef(index) = var_mat.coeffRef(index).vi_->val_;
36+
return var_mat.coeffRef(index).vi_->adj_;
37+
}
38+
};
39+
40+
/**
41+
* Functor for extracting the values and adjoints from a matrix of var or vari.
42+
* This functor is called using Eigen's NullaryExpr framework.
43+
*/
44+
template <typename EigRev, typename EigDbl>
45+
class val_adj_functor {
46+
const EigRev& var_mat;
47+
EigDbl& val_mat;
48+
49+
public:
50+
val_adj_functor(const EigRev& arg1, EigDbl& arg2)
51+
: var_mat(arg1), val_mat(arg2) {}
52+
53+
template <typename T = EigRev, require_st_same<T, var>* = nullptr>
54+
inline decltype(auto) operator()(Eigen::Index row, Eigen::Index col) const {
55+
val_mat.coeffRef(row, col) = var_mat.coeffRef(row, col).vi_->val_;
56+
return var_mat.coeffRef(row, col).vi_->adj_;
57+
}
58+
59+
template <typename T = EigRev, require_st_same<T, var>* = nullptr>
60+
inline decltype(auto) operator()(Eigen::Index index) const {
61+
val_mat.coeffRef(index) = var_mat.coeffRef(index).vi_->val_;
62+
return var_mat.coeffRef(index).vi_->adj_;
63+
}
64+
65+
template <typename T = EigRev, require_st_same<T, vari*>* = nullptr>
66+
inline decltype(auto) operator()(Eigen::Index row, Eigen::Index col) const {
67+
val_mat.coeffRef(row, col) = var_mat.coeffRef(row, col)->val_;
68+
return var_mat.coeffRef(row, col)->adj_;
69+
}
70+
71+
template <typename T = EigRev, require_st_same<T, vari*>* = nullptr>
72+
inline decltype(auto) operator()(Eigen::Index index) const {
73+
val_mat.coeffRef(index) = var_mat.coeffRef(index)->val_;
74+
return var_mat.coeffRef(index)->adj_;
75+
}
76+
};
77+
78+
/**
79+
* Functor for extracting the varis and values from a matrix of var.
80+
* This functor is called using Eigen's NullaryExpr framework.
81+
*/
82+
template <typename EigVar, typename EigVari>
83+
class vi_val_functor {
84+
const EigVar& var_mat;
85+
EigVari& vi_mat;
86+
87+
public:
88+
vi_val_functor(const EigVar& arg1, EigVari& arg2)
89+
: var_mat(arg1), vi_mat(arg2) {}
90+
91+
inline decltype(auto) operator()(Eigen::Index row, Eigen::Index col) const {
92+
vi_mat.coeffRef(row, col) = var_mat.coeffRef(row, col).vi_;
93+
return var_mat.coeffRef(row, col).vi_->val_;
94+
}
95+
96+
inline decltype(auto) operator()(Eigen::Index index) const {
97+
vi_mat.coeffRef(index) = var_mat.coeffRef(index).vi_;
98+
return var_mat.coeffRef(index).vi_->val_;
99+
}
100+
};
101+
102+
/**
103+
* Functor for extracting the varis and adjoints from a matrix of var.
104+
* This functor is called using Eigen's NullaryExpr framework.
105+
*/
106+
template <typename EigVar, typename EigVari>
107+
class vi_adj_functor {
108+
const EigVar& var_mat;
109+
EigVari& vi_mat;
110+
111+
public:
112+
vi_adj_functor(const EigVar& arg1, EigVari& arg2)
113+
: var_mat(arg1), vi_mat(arg2) {}
114+
115+
inline decltype(auto) operator()(Eigen::Index row, Eigen::Index col) const {
116+
vi_mat.coeffRef(row, col) = var_mat.coeffRef(row, col).vi_;
117+
return var_mat.coeffRef(row, col).vi_->adj_;
118+
}
119+
120+
inline decltype(auto) operator()(Eigen::Index index) const {
121+
vi_mat.coeffRef(index) = var_mat.coeffRef(index).vi_;
122+
return var_mat.coeffRef(index).vi_->adj_;
123+
}
124+
};
125+
126+
/**
127+
* Function applying the vi_val_adj_functor to extract the vari*, values,
128+
* and adjoints of a given var matrix into separate matrices.
129+
*
130+
* @tparam EigVar type of the Eigen container of var.
131+
* @tparam EigVari type of the Eigen container of vari to be copied to.
132+
* @tparam EigDbl type of the Eigen container of doubles to be copied to.
133+
* @param[in] VarMat Input Eigen container of var.
134+
* @param[in] VariMat Output Eigen container of vari.
135+
* @param[in] ValMat Output Eigen container of values.
136+
* @param[in] AdjMat Output Eigen container of tangents.
137+
*/
138+
template <typename EigVar, typename EigVari, typename EigDbl>
139+
inline void read_vi_val_adj(const EigVar& VarMat, EigVari& VariMat,
140+
EigDbl& ValMat, EigDbl& AdjMat) {
141+
AdjMat
142+
= EigDbl::NullaryExpr(VarMat.rows(), VarMat.cols(),
143+
vi_val_adj_functor<const EigVar, EigVari, EigDbl>(
144+
VarMat, VariMat, ValMat));
145+
}
146+
147+
/**
148+
* Function applying the val_adj_functor to extract the values
149+
* and adjoints of a given var or vari matrix into separate matrices.
150+
*
151+
* @tparam EigRev type of the Eigen container of var or vari.
152+
* @tparam EigDbl type of the Eigen container of doubles to be copied to.
153+
* @param[in] VarMat Input Eigen container of var.
154+
* @param[in] ValMat Output Eigen container of values.
155+
* @param[in] AdjMat Output Eigen container of adjoints.
156+
*/
157+
template <typename EigRev, typename EigDbl>
158+
inline void read_val_adj(const EigRev& VarMat, EigDbl& ValMat, EigDbl& AdjMat) {
159+
AdjMat = EigDbl::NullaryExpr(
160+
VarMat.rows(), VarMat.cols(),
161+
val_adj_functor<const EigRev, EigDbl>(VarMat, ValMat));
162+
}
163+
164+
/**
165+
* Function applying the vi_val_functor to extract the varis and
166+
* and values of a given var matrix into separate matrices.
167+
*
168+
* @tparam EigVar type of the Eigen container of var.
169+
* @tparam EigDbl type of the Eigen container of doubles to be copied to.
170+
* @param[in] VarMat Input Eigen container of var.
171+
* @param[in] VariMat Output Eigen container of vari.
172+
* @param[in] ValMat Output Eigen container of values.
173+
*/
174+
template <typename EigVar, typename EigVari, typename EigDbl>
175+
inline void read_vi_val(const EigVar& VarMat, EigVari& VariMat,
176+
EigDbl& ValMat) {
177+
ValMat = EigDbl::NullaryExpr(
178+
VarMat.rows(), VarMat.cols(),
179+
vi_val_functor<const EigVar, EigVari>(VarMat, VariMat));
180+
}
181+
182+
/**
183+
* Function applying the vi_adj_functor to extract the varis and
184+
* and adjoints of a given var matrix into separate matrices.
185+
*
186+
* @tparam EigVar type of the Eigen container of var.
187+
* @tparam EigDbl type of the Eigen container of doubles to be copied to.
188+
* @param[in] VarMat Input Eigen container of var.
189+
* @param[in] VariMat Output Eigen container of vari.
190+
* @param[in] AdjMat Output Eigen container of adjoints.
191+
*/
192+
template <typename EigVar, typename EigVari, typename EigDbl>
193+
inline void read_vi_adj(const EigVar& VarMat, EigVari& VariMat,
194+
EigDbl& AdjMat) {
195+
AdjMat = EigDbl::NullaryExpr(
196+
VarMat.rows(), VarMat.cols(),
197+
vi_adj_functor<const EigVar, EigVari>(VarMat, VariMat));
198+
}
199+
200+
} // namespace math
201+
} // namespace stan
202+
#endif

0 commit comments

Comments
 (0)