1111namespace stan {
1212namespace math {
1313
14- template <typename T, int R1, int C1, int R2, int C2>
15- inline Eigen::Matrix<fvar<T>, R1, C1> mdivide_right_tri_low (
16- const Eigen::Matrix<fvar<T>, R1, C1> &A,
17- const Eigen::Matrix<fvar<T>, R2, C2> &b) {
14+ template <typename EigMat1, typename EigMat2,
15+ require_all_eigen_vt<is_fvar, EigMat1, EigMat2>* = nullptr ,
16+ require_vt_same<EigMat1, EigMat2>* = nullptr >
17+ inline Eigen::Matrix<value_type_t <EigMat1>, EigMat1::RowsAtCompileTime,
18+ EigMat2::ColsAtCompileTime>
19+ mdivide_right_tri_low (const EigMat1& A, const EigMat2& b) {
20+ using T = typename value_type_t <EigMat1>::Scalar;
21+ constexpr int R1 = EigMat1::RowsAtCompileTime;
22+ constexpr int C1 = EigMat1::ColsAtCompileTime;
23+ constexpr int R2 = EigMat2::RowsAtCompileTime;
24+ constexpr int C2 = EigMat2::ColsAtCompileTime;
25+
1826 check_square (" mdivide_right_tri_low" , " b" , b);
1927 check_multiplicable (" mdivide_right_tri_low" , " A" , A, " b" , b);
2028 if (b.size () == 0 ) {
2129 return {A.rows (), 0 };
2230 }
2331
24- Eigen::Matrix<T, R1, C2> A_mult_inv_b (A.rows (), b.cols ());
25- Eigen::Matrix<T, R1, C2> deriv_A_mult_inv_b (A.rows (), b.cols ());
26- Eigen::Matrix<T, R2, C2> deriv_b_mult_inv_b (b.rows (), b.cols ());
2732 Eigen::Matrix<T, R1, C1> val_A (A.rows (), A.cols ());
2833 Eigen::Matrix<T, R1, C1> deriv_A (A.rows (), A.cols ());
2934 Eigen::Matrix<T, R2, C2> val_b (b.rows (), b.cols ());
3035 Eigen::Matrix<T, R2, C2> deriv_b (b.rows (), b.cols ());
3136 val_b.setZero ();
3237 deriv_b.setZero ();
3338
34- for (size_type j = 0 ; j < A.cols (); j++) {
35- for (size_type i = 0 ; i < A.rows (); i++) {
36- val_A (i, j) = A (i, j).val_ ;
37- deriv_A (i, j) = A (i, j).d_ ;
39+ const Eigen::Ref<const plain_type_t <EigMat1>>& A_ref = A;
40+ for (int j = 0 ; j < A.cols (); j++) {
41+ for (int i = 0 ; i < A.rows (); i++) {
42+ val_A.coeffRef (i, j) = A_ref.coeff (i, j).val_ ;
43+ deriv_A.coeffRef (i, j) = A_ref.coeff (i, j).d_ ;
3844 }
3945 }
4046
41- for (size_type j = 0 ; j < b.cols (); j++) {
42- for (size_type i = j; i < b.rows (); i++) {
43- val_b (i, j) = b (i, j).val_ ;
44- deriv_b (i, j) = b (i, j).d_ ;
47+ const Eigen::Ref<const plain_type_t <EigMat2>>& b_ref = b;
48+ for (int j = 0 ; j < b.cols (); j++) {
49+ for (int i = j; i < b.rows (); i++) {
50+ val_b.coeffRef (i, j) = b_ref.coeff (i, j).val_ ;
51+ deriv_b.coeffRef (i, j) = b_ref.coeff (i, j).d_ ;
4552 }
4653 }
4754
48- A_mult_inv_b = mdivide_right (val_A, val_b);
49- deriv_A_mult_inv_b = mdivide_right (deriv_A, val_b);
50- deriv_b_mult_inv_b = mdivide_right (deriv_b, val_b);
51-
52- Eigen::Matrix<T, R1, C2> deriv (A.rows (), b.cols ());
53- deriv = deriv_A_mult_inv_b - multiply (A_mult_inv_b, deriv_b_mult_inv_b);
54-
55- return to_fvar (A_mult_inv_b, deriv);
55+ Eigen::Matrix<T, R1, C2> A_mult_inv_b = mdivide_right (val_A, val_b);
56+ return to_fvar (A_mult_inv_b,
57+ mdivide_right (deriv_A, val_b)
58+ - A_mult_inv_b * mdivide_right (deriv_b, val_b));
5659}
5760
58- template <typename T, int R1, int C1, int R2, int C2>
59- inline Eigen::Matrix<fvar<T>, R1, C2> mdivide_right_tri_low (
60- const Eigen::Matrix<fvar<T>, R1, C1> &A,
61- const Eigen::Matrix<double , R2, C2> &b) {
61+ template <typename EigMat1, typename EigMat2,
62+ require_eigen_vt<is_fvar, EigMat1>* = nullptr ,
63+ require_eigen_vt<std::is_arithmetic, EigMat2>* = nullptr >
64+ inline Eigen::Matrix<value_type_t <EigMat1>, EigMat1::RowsAtCompileTime,
65+ EigMat2::ColsAtCompileTime>
66+ mdivide_right_tri_low (const EigMat1& A, const EigMat2& b) {
67+ using T = typename value_type_t <EigMat1>::Scalar;
68+ constexpr int R1 = EigMat1::RowsAtCompileTime;
69+ constexpr int C1 = EigMat1::ColsAtCompileTime;
70+
6271 check_square (" mdivide_right_tri_low" , " b" , b);
6372 check_multiplicable (" mdivide_right_tri_low" , " A" , A, " b" , b);
6473 if (b.size () == 0 ) {
6574 return {A.rows (), 0 };
6675 }
6776
68- Eigen::Matrix<T, R2, C2> deriv_b_mult_inv_b (b.rows (), b.cols ());
6977 Eigen::Matrix<T, R1, C1> val_A (A.rows (), A.cols ());
7078 Eigen::Matrix<T, R1, C1> deriv_A (A.rows (), A.cols ());
71- Eigen::Matrix<T, R2, C2> val_b (b.rows (), b.cols ());
72- val_b.setZero ();
7379
80+ const Eigen::Ref<const plain_type_t <EigMat1>>& A_ref = A;
7481 for (int j = 0 ; j < A.cols (); j++) {
7582 for (int i = 0 ; i < A.rows (); i++) {
76- val_A (i, j) = A (i, j).val_ ;
77- deriv_A (i, j) = A (i, j).d_ ;
83+ val_A. coeffRef (i, j) = A_ref. coeff (i, j).val_ ;
84+ deriv_A. coeffRef (i, j) = A_ref. coeff (i, j).d_ ;
7885 }
7986 }
8087
81- for (size_type j = 0 ; j < b.cols (); j++) {
82- for (size_type i = j; i < b.rows (); i++) {
83- val_b (i, j) = b (i, j);
84- }
85- }
88+ plain_type_t <EigMat2> val_b = b.template triangularView <Eigen::Lower>();
8689
8790 return to_fvar (mdivide_right (val_A, val_b), mdivide_right (deriv_A, val_b));
8891}
8992
90- template <typename T, int R1, int C1, int R2, int C2>
91- inline Eigen::Matrix<fvar<T>, R1, C2> mdivide_right_tri_low (
92- const Eigen::Matrix<double , R1, C1> &A,
93- const Eigen::Matrix<fvar<T>, R2, C2> &b) {
93+ template <typename EigMat1, typename EigMat2,
94+ require_eigen_vt<std::is_arithmetic, EigMat1>* = nullptr ,
95+ require_eigen_vt<is_fvar, EigMat2>* = nullptr >
96+ inline Eigen::Matrix<value_type_t <EigMat2>, EigMat1::RowsAtCompileTime,
97+ EigMat2::ColsAtCompileTime>
98+ mdivide_right_tri_low (const EigMat1& A, const EigMat2& b) {
99+ using T = typename value_type_t <EigMat2>::Scalar;
100+ constexpr int R1 = EigMat1::RowsAtCompileTime;
101+ constexpr int C1 = EigMat1::ColsAtCompileTime;
102+ constexpr int R2 = EigMat2::RowsAtCompileTime;
103+ constexpr int C2 = EigMat2::ColsAtCompileTime;
94104 check_square (" mdivide_right_tri_low" , " b" , b);
95105 check_multiplicable (" mdivide_right_tri_low" , " A" , A, " b" , b);
96106 if (b.size () == 0 ) {
97107 return {A.rows (), 0 };
98108 }
99109
100110 Eigen::Matrix<T, R1, C2> A_mult_inv_b (A.rows (), b.cols ());
101- Eigen::Matrix<T, R2, C2> deriv_b_mult_inv_b (b.rows (), b.cols ());
102111 Eigen::Matrix<T, R2, C2> val_b (b.rows (), b.cols ());
103112 Eigen::Matrix<T, R2, C2> deriv_b (b.rows (), b.cols ());
104113 val_b.setZero ();
@@ -112,12 +121,9 @@ inline Eigen::Matrix<fvar<T>, R1, C2> mdivide_right_tri_low(
112121 }
113122
114123 A_mult_inv_b = mdivide_right (A, val_b);
115- deriv_b_mult_inv_b = mdivide_right (deriv_b, val_b);
116-
117- Eigen::Matrix<T, R1, C2> deriv (A.rows (), b.cols ());
118- deriv = -multiply (A_mult_inv_b, deriv_b_mult_inv_b);
119124
120- return to_fvar (A_mult_inv_b, deriv);
125+ return to_fvar (A_mult_inv_b,
126+ -multiply (A_mult_inv_b, mdivide_right (deriv_b, val_b)));
121127}
122128
123129} // namespace math
0 commit comments