Skip to content

Commit fdf7f70

Browse files
authored
Merge pull request #1845 from bstatcomp/generalize_prim_fun_m_p
Generalize */fun from m to p
2 parents bf46e93 + 6d10019 commit fdf7f70

25 files changed

Lines changed: 360 additions & 416 deletions

stan/math/fwd/fun/mdivide_left_ldlt.hpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,29 +15,29 @@ namespace math {
1515
*
1616
* @tparam R1 number of rows in the LDLT_factor, can be Eigen::Dynamic
1717
* @tparam C1 number of columns in the LDLT_factor, can be Eigen::Dynamic
18-
* @tparam R2 number of rows in the right-hand side matrix, can be
19-
* Eigen::Dynamic
20-
* @tparam C2 number of columns in the right-hand side matrix, can be
21-
* Eigen::Dynamic
22-
* @tparam T2 type of elements in the right-hand side matrix or vector
18+
* @tparam EigMat type of the right-hand side matrix or vector
2319
*
2420
* @param A LDLT_factor
2521
* @param b right-hand side matrix or vector
2622
* @return x = b A^-1, solution of the linear system.
2723
* @throws std::domain_error if rows of b don't match the size of A.
2824
*/
29-
template <int R1, int C1, int R2, int C2, typename T2>
30-
inline Eigen::Matrix<fvar<T2>, R1, C2> mdivide_left_ldlt(
31-
const LDLT_factor<double, R1, C1> &A,
32-
const Eigen::Matrix<fvar<T2>, R2, C2> &b) {
25+
template <int R1, int C1, typename EigMat,
26+
require_eigen_vt<is_fvar, EigMat>* = nullptr>
27+
inline Eigen::Matrix<value_type_t<EigMat>, R1, EigMat::ColsAtCompileTime>
28+
mdivide_left_ldlt(const LDLT_factor<double, R1, C1>& A, const EigMat& b) {
29+
using T = typename value_type_t<EigMat>::Scalar;
30+
constexpr int R2 = EigMat::RowsAtCompileTime;
31+
constexpr int C2 = EigMat::ColsAtCompileTime;
3332
check_multiplicable("mdivide_left_ldlt", "A", A, "b", b);
3433

35-
Eigen::Matrix<T2, R2, C2> b_val(b.rows(), b.cols());
36-
Eigen::Matrix<T2, R2, C2> b_der(b.rows(), b.cols());
34+
const Eigen::Ref<const plain_type_t<EigMat>>& b_ref = b;
35+
Eigen::Matrix<T, R2, C2> b_val(b.rows(), b.cols());
36+
Eigen::Matrix<T, R2, C2> b_der(b.rows(), b.cols());
3737
for (int j = 0; j < b.cols(); j++) {
3838
for (int i = 0; i < b.rows(); i++) {
39-
b_val(i, j) = b(i, j).val_;
40-
b_der(i, j) = b(i, j).d_;
39+
b_val.coeffRef(i, j) = b_ref.coeff(i, j).val_;
40+
b_der.coeffRef(i, j) = b_ref.coeff(i, j).d_;
4141
}
4242
}
4343

stan/math/fwd/fun/mdivide_right.hpp

Lines changed: 52 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,101 +13,114 @@
1313
namespace stan {
1414
namespace math {
1515

16-
template <typename T, int R1, int C1, int R2, int C2>
17-
inline Eigen::Matrix<fvar<T>, R1, C2> mdivide_right(
18-
const Eigen::Matrix<fvar<T>, R1, C1> &A,
19-
const Eigen::Matrix<fvar<T>, R2, C2> &b) {
16+
template <typename EigMat1, typename EigMat2,
17+
require_all_eigen_vt<is_fvar, EigMat1, EigMat2>* = nullptr,
18+
require_vt_same<EigMat1, EigMat2>* = nullptr>
19+
inline Eigen::Matrix<value_type_t<EigMat1>, EigMat1::RowsAtCompileTime,
20+
EigMat2::ColsAtCompileTime>
21+
mdivide_right(const EigMat1& A, const EigMat2& b) {
22+
using T = typename value_type_t<EigMat1>::Scalar;
23+
constexpr int R1 = EigMat1::RowsAtCompileTime;
24+
constexpr int C1 = EigMat1::ColsAtCompileTime;
25+
constexpr int R2 = EigMat2::RowsAtCompileTime;
26+
constexpr int C2 = EigMat2::ColsAtCompileTime;
27+
2028
check_square("mdivide_right", "b", b);
2129
check_multiplicable("mdivide_right", "A", A, "b", b);
2230
if (b.size() == 0) {
2331
return {A.rows(), 0};
2432
}
2533

26-
Eigen::Matrix<T, R1, C2> A_mult_inv_b(A.rows(), b.cols());
27-
Eigen::Matrix<T, R1, C2> deriv_A_mult_inv_b(A.rows(), b.cols());
28-
Eigen::Matrix<T, R2, C2> deriv_b_mult_inv_b(b.rows(), b.cols());
2934
Eigen::Matrix<T, R1, C1> val_A(A.rows(), A.cols());
3035
Eigen::Matrix<T, R1, C1> deriv_A(A.rows(), A.cols());
3136
Eigen::Matrix<T, R2, C2> val_b(b.rows(), b.cols());
3237
Eigen::Matrix<T, R2, C2> deriv_b(b.rows(), b.cols());
3338

39+
const Eigen::Ref<const plain_type_t<EigMat1>>& A_ref = A;
3440
for (int j = 0; j < A.cols(); j++) {
3541
for (int i = 0; i < A.rows(); i++) {
36-
val_A(i, j) = A(i, j).val_;
37-
deriv_A(i, j) = A(i, j).d_;
42+
val_A.coeffRef(i, j) = A_ref.coeff(i, j).val_;
43+
deriv_A.coeffRef(i, j) = A_ref.coeff(i, j).d_;
3844
}
3945
}
4046

47+
const Eigen::Ref<const plain_type_t<EigMat2>>& b_ref = b;
4148
for (int j = 0; j < b.cols(); j++) {
4249
for (int i = 0; i < b.rows(); i++) {
43-
val_b(i, j) = b(i, j).val_;
44-
deriv_b(i, j) = b(i, j).d_;
50+
val_b.coeffRef(i, j) = b_ref.coeff(i, j).val_;
51+
deriv_b.coeffRef(i, j) = b_ref.coeff(i, j).d_;
4552
}
4653
}
4754

48-
A_mult_inv_b = mdivide_right(val_A, val_b);
49-
deriv_A_mult_inv_b = mdivide_right(deriv_A, val_b);
50-
deriv_b_mult_inv_b = mdivide_right(deriv_b, val_b);
51-
52-
Eigen::Matrix<T, R1, C2> deriv(A.rows(), b.cols());
53-
deriv = deriv_A_mult_inv_b - multiply(A_mult_inv_b, deriv_b_mult_inv_b);
55+
Eigen::Matrix<T, R1, C2> A_mult_inv_b = mdivide_right(val_A, val_b);
5456

55-
return to_fvar(A_mult_inv_b, deriv);
57+
return to_fvar(A_mult_inv_b,
58+
mdivide_right(deriv_A, val_b)
59+
- A_mult_inv_b * mdivide_right(deriv_b, val_b));
5660
}
5761

58-
template <typename T, int R1, int C1, int R2, int C2>
59-
inline Eigen::Matrix<fvar<T>, R1, C2> mdivide_right(
60-
const Eigen::Matrix<fvar<T>, R1, C1> &A,
61-
const Eigen::Matrix<double, R2, C2> &b) {
62+
template <typename EigMat1, typename EigMat2,
63+
require_eigen_vt<is_fvar, EigMat1>* = nullptr,
64+
require_eigen_vt<std::is_arithmetic, EigMat2>* = nullptr>
65+
inline Eigen::Matrix<value_type_t<EigMat1>, EigMat1::RowsAtCompileTime,
66+
EigMat2::ColsAtCompileTime>
67+
mdivide_right(const EigMat1& A, const EigMat2& b) {
68+
using T = typename value_type_t<EigMat1>::Scalar;
69+
constexpr int R1 = EigMat1::RowsAtCompileTime;
70+
constexpr int C1 = EigMat1::ColsAtCompileTime;
71+
6272
check_square("mdivide_right", "b", b);
6373
check_multiplicable("mdivide_right", "A", A, "b", b);
6474
if (b.size() == 0) {
6575
return {A.rows(), 0};
6676
}
6777

68-
Eigen::Matrix<T, R2, C2> deriv_b_mult_inv_b(b.rows(), b.cols());
6978
Eigen::Matrix<T, R1, C1> val_A(A.rows(), A.cols());
7079
Eigen::Matrix<T, R1, C1> deriv_A(A.rows(), A.cols());
7180

81+
const Eigen::Ref<const plain_type_t<EigMat1>>& A_ref = A;
7282
for (int j = 0; j < A.cols(); j++) {
7383
for (int i = 0; i < A.rows(); i++) {
74-
val_A(i, j) = A(i, j).val_;
75-
deriv_A(i, j) = A(i, j).d_;
84+
val_A.coeffRef(i, j) = A_ref.coeff(i, j).val_;
85+
deriv_A.coeffRef(i, j) = A_ref.coeff(i, j).d_;
7686
}
7787
}
7888

7989
return to_fvar(mdivide_right(val_A, b), mdivide_right(deriv_A, b));
8090
}
8191

82-
template <typename T, int R1, int C1, int R2, int C2>
83-
inline Eigen::Matrix<fvar<T>, R1, C2> mdivide_right(
84-
const Eigen::Matrix<double, R1, C1> &A,
85-
const Eigen::Matrix<fvar<T>, R2, C2> &b) {
92+
template <typename EigMat1, typename EigMat2,
93+
require_eigen_vt<std::is_arithmetic, EigMat1>* = nullptr,
94+
require_eigen_vt<is_fvar, EigMat2>* = nullptr>
95+
inline Eigen::Matrix<value_type_t<EigMat2>, EigMat1::RowsAtCompileTime,
96+
EigMat2::ColsAtCompileTime>
97+
mdivide_right(const EigMat1& A, const EigMat2& b) {
98+
using T = typename value_type_t<EigMat2>::Scalar;
99+
constexpr int R1 = EigMat1::RowsAtCompileTime;
100+
constexpr int C1 = EigMat1::ColsAtCompileTime;
101+
constexpr int R2 = EigMat2::RowsAtCompileTime;
102+
constexpr int C2 = EigMat2::ColsAtCompileTime;
103+
86104
check_square("mdivide_right", "b", b);
87105
check_multiplicable("mdivide_right", "A", A, "b", b);
88106
if (b.size() == 0) {
89107
return {A.rows(), 0};
90108
}
91109

92-
Eigen::Matrix<T, R1, C2> A_mult_inv_b(A.rows(), b.cols());
93-
Eigen::Matrix<T, R2, C2> deriv_b_mult_inv_b(b.rows(), b.cols());
94110
Eigen::Matrix<T, R2, C2> val_b(b.rows(), b.cols());
95111
Eigen::Matrix<T, R2, C2> deriv_b(b.rows(), b.cols());
96112

113+
const Eigen::Ref<const plain_type_t<EigMat2>>& b_ref = b;
97114
for (int j = 0; j < b.cols(); j++) {
98115
for (int i = 0; i < b.rows(); i++) {
99-
val_b(i, j) = b(i, j).val_;
100-
deriv_b(i, j) = b(i, j).d_;
116+
val_b.coeffRef(i, j) = b_ref.coeff(i, j).val_;
117+
deriv_b.coeffRef(i, j) = b_ref.coeff(i, j).d_;
101118
}
102119
}
103120

104-
A_mult_inv_b = mdivide_right(A, val_b);
105-
deriv_b_mult_inv_b = mdivide_right(deriv_b, val_b);
106-
107-
Eigen::Matrix<T, R1, C2> deriv(A.rows(), b.cols());
108-
deriv = -multiply(A_mult_inv_b, deriv_b_mult_inv_b);
121+
Eigen::Matrix<T, R1, C2> A_mult_inv_b = mdivide_right(A, val_b);
109122

110-
return to_fvar(A_mult_inv_b, deriv);
123+
return to_fvar(A_mult_inv_b, -A_mult_inv_b * mdivide_right(deriv_b, val_b));
111124
}
112125

113126
} // namespace math

stan/math/fwd/fun/mdivide_right_tri_low.hpp

Lines changed: 53 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -11,94 +11,103 @@
1111
namespace stan {
1212
namespace math {
1313

14-
template <typename T, int R1, int C1, int R2, int C2>
15-
inline Eigen::Matrix<fvar<T>, R1, C1> mdivide_right_tri_low(
16-
const Eigen::Matrix<fvar<T>, R1, C1> &A,
17-
const Eigen::Matrix<fvar<T>, R2, C2> &b) {
14+
template <typename EigMat1, typename EigMat2,
15+
require_all_eigen_vt<is_fvar, EigMat1, EigMat2>* = nullptr,
16+
require_vt_same<EigMat1, EigMat2>* = nullptr>
17+
inline Eigen::Matrix<value_type_t<EigMat1>, EigMat1::RowsAtCompileTime,
18+
EigMat2::ColsAtCompileTime>
19+
mdivide_right_tri_low(const EigMat1& A, const EigMat2& b) {
20+
using T = typename value_type_t<EigMat1>::Scalar;
21+
constexpr int R1 = EigMat1::RowsAtCompileTime;
22+
constexpr int C1 = EigMat1::ColsAtCompileTime;
23+
constexpr int R2 = EigMat2::RowsAtCompileTime;
24+
constexpr int C2 = EigMat2::ColsAtCompileTime;
25+
1826
check_square("mdivide_right_tri_low", "b", b);
1927
check_multiplicable("mdivide_right_tri_low", "A", A, "b", b);
2028
if (b.size() == 0) {
2129
return {A.rows(), 0};
2230
}
2331

24-
Eigen::Matrix<T, R1, C2> A_mult_inv_b(A.rows(), b.cols());
25-
Eigen::Matrix<T, R1, C2> deriv_A_mult_inv_b(A.rows(), b.cols());
26-
Eigen::Matrix<T, R2, C2> deriv_b_mult_inv_b(b.rows(), b.cols());
2732
Eigen::Matrix<T, R1, C1> val_A(A.rows(), A.cols());
2833
Eigen::Matrix<T, R1, C1> deriv_A(A.rows(), A.cols());
2934
Eigen::Matrix<T, R2, C2> val_b(b.rows(), b.cols());
3035
Eigen::Matrix<T, R2, C2> deriv_b(b.rows(), b.cols());
3136
val_b.setZero();
3237
deriv_b.setZero();
3338

34-
for (size_type j = 0; j < A.cols(); j++) {
35-
for (size_type i = 0; i < A.rows(); i++) {
36-
val_A(i, j) = A(i, j).val_;
37-
deriv_A(i, j) = A(i, j).d_;
39+
const Eigen::Ref<const plain_type_t<EigMat1>>& A_ref = A;
40+
for (int j = 0; j < A.cols(); j++) {
41+
for (int i = 0; i < A.rows(); i++) {
42+
val_A.coeffRef(i, j) = A_ref.coeff(i, j).val_;
43+
deriv_A.coeffRef(i, j) = A_ref.coeff(i, j).d_;
3844
}
3945
}
4046

41-
for (size_type j = 0; j < b.cols(); j++) {
42-
for (size_type i = j; i < b.rows(); i++) {
43-
val_b(i, j) = b(i, j).val_;
44-
deriv_b(i, j) = b(i, j).d_;
47+
const Eigen::Ref<const plain_type_t<EigMat2>>& b_ref = b;
48+
for (int j = 0; j < b.cols(); j++) {
49+
for (int i = j; i < b.rows(); i++) {
50+
val_b.coeffRef(i, j) = b_ref.coeff(i, j).val_;
51+
deriv_b.coeffRef(i, j) = b_ref.coeff(i, j).d_;
4552
}
4653
}
4754

48-
A_mult_inv_b = mdivide_right(val_A, val_b);
49-
deriv_A_mult_inv_b = mdivide_right(deriv_A, val_b);
50-
deriv_b_mult_inv_b = mdivide_right(deriv_b, val_b);
51-
52-
Eigen::Matrix<T, R1, C2> deriv(A.rows(), b.cols());
53-
deriv = deriv_A_mult_inv_b - multiply(A_mult_inv_b, deriv_b_mult_inv_b);
54-
55-
return to_fvar(A_mult_inv_b, deriv);
55+
Eigen::Matrix<T, R1, C2> A_mult_inv_b = mdivide_right(val_A, val_b);
56+
return to_fvar(A_mult_inv_b,
57+
mdivide_right(deriv_A, val_b)
58+
- A_mult_inv_b * mdivide_right(deriv_b, val_b));
5659
}
5760

58-
template <typename T, int R1, int C1, int R2, int C2>
59-
inline Eigen::Matrix<fvar<T>, R1, C2> mdivide_right_tri_low(
60-
const Eigen::Matrix<fvar<T>, R1, C1> &A,
61-
const Eigen::Matrix<double, R2, C2> &b) {
61+
template <typename EigMat1, typename EigMat2,
62+
require_eigen_vt<is_fvar, EigMat1>* = nullptr,
63+
require_eigen_vt<std::is_arithmetic, EigMat2>* = nullptr>
64+
inline Eigen::Matrix<value_type_t<EigMat1>, EigMat1::RowsAtCompileTime,
65+
EigMat2::ColsAtCompileTime>
66+
mdivide_right_tri_low(const EigMat1& A, const EigMat2& b) {
67+
using T = typename value_type_t<EigMat1>::Scalar;
68+
constexpr int R1 = EigMat1::RowsAtCompileTime;
69+
constexpr int C1 = EigMat1::ColsAtCompileTime;
70+
6271
check_square("mdivide_right_tri_low", "b", b);
6372
check_multiplicable("mdivide_right_tri_low", "A", A, "b", b);
6473
if (b.size() == 0) {
6574
return {A.rows(), 0};
6675
}
6776

68-
Eigen::Matrix<T, R2, C2> deriv_b_mult_inv_b(b.rows(), b.cols());
6977
Eigen::Matrix<T, R1, C1> val_A(A.rows(), A.cols());
7078
Eigen::Matrix<T, R1, C1> deriv_A(A.rows(), A.cols());
71-
Eigen::Matrix<T, R2, C2> val_b(b.rows(), b.cols());
72-
val_b.setZero();
7379

80+
const Eigen::Ref<const plain_type_t<EigMat1>>& A_ref = A;
7481
for (int j = 0; j < A.cols(); j++) {
7582
for (int i = 0; i < A.rows(); i++) {
76-
val_A(i, j) = A(i, j).val_;
77-
deriv_A(i, j) = A(i, j).d_;
83+
val_A.coeffRef(i, j) = A_ref.coeff(i, j).val_;
84+
deriv_A.coeffRef(i, j) = A_ref.coeff(i, j).d_;
7885
}
7986
}
8087

81-
for (size_type j = 0; j < b.cols(); j++) {
82-
for (size_type i = j; i < b.rows(); i++) {
83-
val_b(i, j) = b(i, j);
84-
}
85-
}
88+
plain_type_t<EigMat2> val_b = b.template triangularView<Eigen::Lower>();
8689

8790
return to_fvar(mdivide_right(val_A, val_b), mdivide_right(deriv_A, val_b));
8891
}
8992

90-
template <typename T, int R1, int C1, int R2, int C2>
91-
inline Eigen::Matrix<fvar<T>, R1, C2> mdivide_right_tri_low(
92-
const Eigen::Matrix<double, R1, C1> &A,
93-
const Eigen::Matrix<fvar<T>, R2, C2> &b) {
93+
template <typename EigMat1, typename EigMat2,
94+
require_eigen_vt<std::is_arithmetic, EigMat1>* = nullptr,
95+
require_eigen_vt<is_fvar, EigMat2>* = nullptr>
96+
inline Eigen::Matrix<value_type_t<EigMat2>, EigMat1::RowsAtCompileTime,
97+
EigMat2::ColsAtCompileTime>
98+
mdivide_right_tri_low(const EigMat1& A, const EigMat2& b) {
99+
using T = typename value_type_t<EigMat2>::Scalar;
100+
constexpr int R1 = EigMat1::RowsAtCompileTime;
101+
constexpr int C1 = EigMat1::ColsAtCompileTime;
102+
constexpr int R2 = EigMat2::RowsAtCompileTime;
103+
constexpr int C2 = EigMat2::ColsAtCompileTime;
94104
check_square("mdivide_right_tri_low", "b", b);
95105
check_multiplicable("mdivide_right_tri_low", "A", A, "b", b);
96106
if (b.size() == 0) {
97107
return {A.rows(), 0};
98108
}
99109

100110
Eigen::Matrix<T, R1, C2> A_mult_inv_b(A.rows(), b.cols());
101-
Eigen::Matrix<T, R2, C2> deriv_b_mult_inv_b(b.rows(), b.cols());
102111
Eigen::Matrix<T, R2, C2> val_b(b.rows(), b.cols());
103112
Eigen::Matrix<T, R2, C2> deriv_b(b.rows(), b.cols());
104113
val_b.setZero();
@@ -112,12 +121,9 @@ inline Eigen::Matrix<fvar<T>, R1, C2> mdivide_right_tri_low(
112121
}
113122

114123
A_mult_inv_b = mdivide_right(A, val_b);
115-
deriv_b_mult_inv_b = mdivide_right(deriv_b, val_b);
116-
117-
Eigen::Matrix<T, R1, C2> deriv(A.rows(), b.cols());
118-
deriv = -multiply(A_mult_inv_b, deriv_b_mult_inv_b);
119124

120-
return to_fvar(A_mult_inv_b, deriv);
125+
return to_fvar(A_mult_inv_b,
126+
-multiply(A_mult_inv_b, mdivide_right(deriv_b, val_b)));
121127
}
122128

123129
} // namespace math

stan/math/fwd/fun/multiply_lower_tri_self_transpose.hpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,22 @@
99
namespace stan {
1010
namespace math {
1111

12-
template <typename T, int R, int C>
13-
inline Eigen::Matrix<fvar<T>, R, R> multiply_lower_tri_self_transpose(
14-
const Eigen::Matrix<fvar<T>, R, C>& m) {
12+
template <typename EigMat, require_eigen_vt<is_fvar, EigMat>* = nullptr>
13+
inline Eigen::Matrix<value_type_t<EigMat>, EigMat::RowsAtCompileTime,
14+
EigMat::RowsAtCompileTime>
15+
multiply_lower_tri_self_transpose(const EigMat& m) {
1516
if (m.rows() == 0) {
1617
return {};
1718
}
18-
Eigen::Matrix<fvar<T>, R, C> L(m.rows(), m.cols());
19+
plain_type_t<EigMat> L(m.rows(), m.cols());
1920
L.setZero();
2021

2122
for (size_type i = 0; i < m.rows(); i++) {
2223
for (size_type j = 0; (j < i + 1) && (j < m.cols()); j++) {
23-
L(i, j) = m(i, j);
24+
L.coeffRef(i, j) = m.coeff(i, j);
2425
}
2526
}
26-
return multiply(L, transpose(L));
27+
return multiply(L, L.transpose());
2728
}
2829

2930
} // namespace math

0 commit comments

Comments
 (0)