Skip to content

Commit 89b43ef

Browse files
authored
Merge pull request #1847 from bstatcomp/generalize_fun_q_r
generalize */fun q-r
2 parents c02589c + 1d109f3 commit 89b43ef

15 files changed

Lines changed: 158 additions & 200 deletions

stan/math/fwd/fun/quad_form.hpp

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,47 +14,43 @@ namespace math {
1414
* Symmetry of the resulting matrix is not guaranteed due to numerical
1515
* precision.
1616
*
17-
* @tparam Ta type of elements in the square matrix
18-
* @tparam Ra number of rows in the square matrix, can be Eigen::Dynamic
19-
* @tparam Ca number of columns in the square matrix, can be Eigen::Dynamic
20-
* @tparam Tb type of elements in the second matrix
21-
* @tparam Rb number of rows in the second matrix, can be Eigen::Dynamic
22-
* @tparam Cb number of columns in the second matrix, can be Eigen::Dynamic
17+
* @tparam EigMat1 type of the first (square) matrix
18+
* @tparam EigMat2 type of the second matrix
2319
*
2420
* @param A square matrix
2521
* @param B second matrix
2622
* @return The quadratic form, which is a symmetric matrix of size Cb.
2723
* @throws std::invalid_argument if A is not square, or if A cannot be
2824
* multiplied by B
2925
*/
30-
template <typename Ta, int Ra, int Ca, typename Tb, int Rb, int Cb,
31-
require_any_fvar_t<Ta, Tb>...>
32-
inline Eigen::Matrix<return_type_t<Ta, Tb>, Cb, Cb> quad_form(
33-
const Eigen::Matrix<Ta, Ra, Ca>& A, const Eigen::Matrix<Tb, Rb, Cb>& B) {
26+
template <typename EigMat1, typename EigMat2,
27+
require_all_eigen_t<EigMat1, EigMat2>* = nullptr,
28+
require_not_eigen_col_vector_t<EigMat2>* = nullptr,
29+
require_any_vt_fvar<EigMat1, EigMat2>* = nullptr>
30+
inline promote_scalar_t<return_type_t<EigMat1, EigMat2>, EigMat2> quad_form(
31+
const EigMat1& A, const EigMat2& B) {
3432
check_square("quad_form", "A", A);
3533
check_multiplicable("quad_form", "A", A, "B", B);
36-
return multiply(transpose(B), multiply(A, B));
34+
return multiply(B.transpose(), multiply(A, B));
3735
}
3836

3937
/**
4038
* Return the quadratic form \f$ B^T A B \f$.
4139
*
42-
* @tparam Ta type of elements in the square matrix
43-
* @tparam Ra number of rows in the square matrix, can be Eigen::Dynamic
44-
* @tparam Ca number of columns in the square matrix, can be Eigen::Dynamic
45-
* @tparam Tb type of elements in the vector
46-
* @tparam Rb number of rows in the vector, can be Eigen::Dynamic
40+
* @tparam EigMat type of the matrix
41+
* @tparam ColVec type of the vector
4742
*
4843
* @param A square matrix
4944
* @param B vector
5045
* @return The quadratic form (a scalar).
5146
* @throws std::invalid_argument if A is not square, or if A cannot be
5247
* multiplied by B
5348
*/
54-
template <typename Ta, int Ra, int Ca, typename Tb, int Rb,
55-
require_any_fvar_t<Ta, Tb>...>
56-
inline return_type_t<Ta, Tb> quad_form(const Eigen::Matrix<Ta, Ra, Ca>& A,
57-
const Eigen::Matrix<Tb, Rb, 1>& B) {
49+
template <typename EigMat, typename ColVec, require_eigen_t<EigMat>* = nullptr,
50+
require_eigen_col_vector_t<ColVec>* = nullptr,
51+
require_any_vt_fvar<EigMat, ColVec>* = nullptr>
52+
inline return_type_t<EigMat, ColVec> quad_form(const EigMat& A,
53+
const ColVec& B) {
5854
check_square("quad_form", "A", A);
5955
check_multiplicable("quad_form", "A", A, "B", B);
6056
return dot_product(B, multiply(A, B));

stan/math/fwd/fun/quad_form_sym.hpp

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,49 +13,45 @@ namespace math {
1313
*
1414
* Symmetry of the resulting matrix is guaranteed.
1515
*
16-
* @tparam TA type of elements in the symmetric matrix
17-
* @tparam RA number of rows in the symmetric matrix, can be Eigen::Dynamic
18-
* @tparam CA number of columns in the symmetric matrix, can be Eigen::Dynamic
19-
* @tparam TB type of elements in the second matrix
20-
* @tparam RB number of rows in the second matrix, can be Eigen::Dynamic
21-
* @tparam CB number of columns in the second matrix, can be Eigen::Dynamic
16+
* @tparam EigMat1 type of the first (symmetric) matrix
17+
* @tparam EigMat2 type of the second matrix
2218
*
2319
* @param A symmetric matrix
2420
* @param B second matrix
2521
* @return The quadratic form, which is a symmetric matrix of size CB.
2622
* @throws std::invalid_argument if A is not symmetric, or if A cannot be
2723
* multiplied by B
2824
*/
29-
template <typename TA, int RA, int CA, typename TB, int RB, int CB,
30-
require_any_fvar_t<TA, TB>...>
31-
inline Eigen::Matrix<return_type_t<TA, TB>, CB, CB> quad_form_sym(
32-
const Eigen::Matrix<TA, RA, CA>& A, const Eigen::Matrix<TB, RB, CB>& B) {
33-
using T = return_type_t<TA, TB>;
25+
template <typename EigMat1, typename EigMat2,
26+
require_all_eigen_t<EigMat1, EigMat2>* = nullptr,
27+
require_not_eigen_col_vector_t<EigMat2>* = nullptr,
28+
require_any_vt_fvar<EigMat1, EigMat2>* = nullptr>
29+
inline promote_scalar_t<return_type_t<EigMat1, EigMat2>, EigMat2> quad_form_sym(
30+
const EigMat1& A, const EigMat2& B) {
31+
using T_ret = return_type_t<EigMat1, EigMat2>;
3432
check_multiplicable("quad_form_sym", "A", A, "B", B);
3533
check_symmetric("quad_form_sym", "A", A);
36-
Eigen::Matrix<T, CB, CB> ret(multiply(transpose(B), multiply(A, B)));
37-
return T(0.5) * (ret + transpose(ret));
34+
promote_scalar_t<T_ret, EigMat2> ret(multiply(B.transpose(), multiply(A, B)));
35+
return T_ret(0.5) * (ret + ret.transpose());
3836
}
3937

4038
/**
4139
* Return the quadratic form \f$ B^T A B \f$ of a symmetric matrix.
4240
*
43-
* @tparam TA type of elements in the symmetric matrix
44-
* @tparam RA number of rows in the symmetric matrix, can be Eigen::Dynamic
45-
* @tparam CA number of columns in the symmetric matrix, can be Eigen::Dynamic
46-
* @tparam TB type of elements in the vector
47-
* @tparam RB number of rows in the vector, can be Eigen::Dynamic
41+
* @tparam EigMat type of the (symmetric) matrix
42+
* @tparam ColVec type of the vector
4843
*
4944
* @param A symmetric matrix
5045
* @param B vector
5146
* @return The quadratic form (a scalar).
5247
* @throws std::invalid_argument if A is not symmetric, or if A cannot be
5348
* multiplied by B
5449
*/
55-
template <typename TA, int RA, int CA, typename TB, int RB,
56-
require_any_fvar_t<TA, TB>...>
57-
inline return_type_t<TA, TB> quad_form_sym(const Eigen::Matrix<TA, RA, CA>& A,
58-
const Eigen::Matrix<TB, RB, 1>& B) {
50+
template <typename EigMat, typename ColVec, require_eigen_t<EigMat>* = nullptr,
51+
require_eigen_col_vector_t<ColVec>* = nullptr,
52+
require_any_vt_fvar<EigMat, ColVec>* = nullptr>
53+
inline return_type_t<EigMat, ColVec> quad_form_sym(const EigMat& A,
54+
const ColVec& B) {
5955
check_multiplicable("quad_form_sym", "A", A, "B", B);
6056
check_symmetric("quad_form_sym", "A", A);
6157
return dot_product(B, multiply(A, B));

stan/math/prim/fun/qr_Q.hpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef STAN_MATH_PRIM_FUN_QR_Q_HPP
22
#define STAN_MATH_PRIM_FUN_QR_Q_HPP
33

4+
#include <stan/math/prim/meta.hpp>
45
#include <stan/math/prim/err.hpp>
56
#include <stan/math/prim/fun/Eigen.hpp>
67
#include <algorithm>
@@ -11,14 +12,15 @@ namespace math {
1112
/**
1213
* Returns the orthogonal factor of the fat QR decomposition
1314
*
14-
* @tparam T type of elements in the matrix
15+
* @tparam EigMat type of the matrix
1516
* @param m Matrix.
1617
* @return Orthogonal matrix with maximal columns
1718
*/
18-
template <typename T>
19-
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic> qr_Q(
20-
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>& m) {
21-
using matrix_t = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>;
19+
template <typename EigMat, require_eigen_t<EigMat>* = nullptr>
20+
Eigen::Matrix<value_type_t<EigMat>, Eigen::Dynamic, Eigen::Dynamic> qr_Q(
21+
const EigMat& m) {
22+
using matrix_t
23+
= Eigen::Matrix<value_type_t<EigMat>, Eigen::Dynamic, Eigen::Dynamic>;
2224
check_nonzero_size("qr_Q", "m", m);
2325
Eigen::HouseholderQR<matrix_t> qr(m.rows(), m.cols());
2426
qr.compute(m);

stan/math/prim/fun/qr_R.hpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@ namespace math {
1111
/**
1212
* Returns the upper triangular factor of the fat QR decomposition
1313
*
14-
* @tparam T type of elements in the matrix
14+
* @tparam EigMat type of the matrix
1515
* @param m Matrix.
1616
* @return Upper triangular matrix with maximal rows
1717
*/
18-
template <typename T>
19-
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic> qr_R(
20-
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>& m) {
21-
using matrix_t = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>;
18+
template <typename EigMat, require_eigen_t<EigMat>* = nullptr>
19+
Eigen::Matrix<value_type_t<EigMat>, Eigen::Dynamic, Eigen::Dynamic> qr_R(
20+
const EigMat& m) {
21+
using matrix_t
22+
= Eigen::Matrix<value_type_t<EigMat>, Eigen::Dynamic, Eigen::Dynamic>;
2223
check_nonzero_size("qr_R", "m", m);
2324
Eigen::HouseholderQR<matrix_t> qr(m.rows(), m.cols());
2425
qr.compute(m);

stan/math/prim/fun/qr_thin_Q.hpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@ namespace math {
1111
/**
1212
* Returns the orthogonal factor of the thin QR decomposition
1313
*
14-
* @tparam T type of elements in the matrix
14+
* @tparam EigMat type of the matrix
1515
* @param m Matrix.
1616
* @return Orthogonal matrix with minimal columns
1717
*/
18-
template <typename T>
19-
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic> qr_thin_Q(
20-
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>& m) {
21-
using matrix_t = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>;
18+
template <typename EigMat, require_eigen_t<EigMat>* = nullptr>
19+
Eigen::Matrix<value_type_t<EigMat>, Eigen::Dynamic, Eigen::Dynamic> qr_thin_Q(
20+
const EigMat& m) {
21+
using matrix_t
22+
= Eigen::Matrix<value_type_t<EigMat>, Eigen::Dynamic, Eigen::Dynamic>;
2223
check_nonzero_size("qr_thin_Q", "m", m);
2324
Eigen::HouseholderQR<matrix_t> qr(m.rows(), m.cols());
2425
qr.compute(m);

stan/math/prim/fun/qr_thin_R.hpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@ namespace math {
1111
/**
1212
* Returns the upper triangular factor of the thin QR decomposition
1313
*
14-
* @tparam T type of elements in the matrix
14+
* @tparam EigMat type of the matrix
1515
* @param m Matrix.
1616
* @return Upper triangular matrix with minimal rows
1717
*/
18-
template <typename T>
19-
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic> qr_thin_R(
20-
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>& m) {
21-
using matrix_t = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>;
18+
template <typename EigMat, require_eigen_t<EigMat>* = nullptr>
19+
Eigen::Matrix<value_type_t<EigMat>, Eigen::Dynamic, Eigen::Dynamic> qr_thin_R(
20+
const EigMat& m) {
21+
using matrix_t
22+
= Eigen::Matrix<value_type_t<EigMat>, Eigen::Dynamic, Eigen::Dynamic>;
2223
check_nonzero_size("qr_thin_R", "m", m);
2324
Eigen::HouseholderQR<matrix_t> qr(m.rows(), m.cols());
2425
qr.compute(m);

stan/math/prim/fun/quad_form.hpp

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,23 @@ namespace math {
1313
* Symmetry of the resulting matrix is not guaranteed due to numerical
1414
* precision.
1515
*
16-
* @tparam RA number of rows in the square matrix, can be Eigen::Dynamic
17-
* @tparam CA number of columns in the square matrix, can be Eigen::Dynamic
18-
* @tparam RB number of rows in the second matrix, can be Eigen::Dynamic
19-
* @tparam CB number of columns in the second matrix, can be Eigen::Dynamic
20-
* @tparam T type of elements
16+
* @tparam EigMat1 type of the first (square) matrix
17+
* @tparam EigMat2 type of the second matrix
2118
*
2219
* @param A square matrix
2320
* @param B second matrix
24-
* @return The quadratic form, which is a symmetric matrix of size CB.
21+
* @return The quadratic form, which is a symmetric matrix.
2522
* @throws std::invalid_argument if A is not square, or if A cannot be
2623
* multiplied by B
2724
*/
28-
template <int RA, int CA, int RB, int CB, typename T>
29-
inline Eigen::Matrix<T, CB, CB> quad_form(const Eigen::Matrix<T, RA, CA>& A,
30-
const Eigen::Matrix<T, RB, CB>& B) {
25+
template <typename EigMat1, typename EigMat2,
26+
require_all_eigen_t<EigMat1, EigMat2>* = nullptr,
27+
require_not_eigen_col_vector_t<EigMat2>* = nullptr,
28+
require_vt_same<EigMat1, EigMat2>* = nullptr,
29+
require_all_vt_arithmetic<EigMat1, EigMat2>* = nullptr>
30+
inline Eigen::Matrix<value_type_t<EigMat2>, EigMat2::ColsAtCompileTime,
31+
EigMat2::ColsAtCompileTime>
32+
quad_form(const EigMat1& A, const EigMat2& B) {
3133
check_square("quad_form", "A", A);
3234
check_multiplicable("quad_form", "A", A, "B", B);
3335
return B.transpose() * A * B;
@@ -36,20 +38,20 @@ inline Eigen::Matrix<T, CB, CB> quad_form(const Eigen::Matrix<T, RA, CA>& A,
3638
/**
3739
* Return the quadratic form \f$ B^T A B \f$.
3840
*
39-
* @tparam RA number of rows in the square matrix, can be Eigen::Dynamic
40-
* @tparam CA number of columns in the square matrix, can be Eigen::Dynamic
41-
* @tparam RB number of rows in the vector, can be Eigen::Dynamic
42-
* @tparam T type of elements
41+
* @tparam EigMat type of the matrix
42+
* @tparam ColVec type of the vector
4343
*
4444
* @param A square matrix
4545
* @param B vector
4646
* @return The quadratic form (a scalar).
4747
* @throws std::invalid_argument if A is not square, or if A cannot be
4848
* multiplied by B
4949
*/
50-
template <int RA, int CA, int RB, typename T>
51-
inline T quad_form(const Eigen::Matrix<T, RA, CA>& A,
52-
const Eigen::Matrix<T, RB, 1>& B) {
50+
template <typename EigMat, typename ColVec, require_eigen_t<EigMat>* = nullptr,
51+
require_eigen_col_vector_t<ColVec>* = nullptr,
52+
require_vt_same<EigMat, ColVec>* = nullptr,
53+
require_all_vt_arithmetic<EigMat, ColVec>* = nullptr>
54+
inline value_type_t<EigMat> quad_form(const EigMat& A, const ColVec& B) {
5355
check_square("quad_form", "A", A);
5456
check_multiplicable("quad_form", "A", A, "B", B);
5557
return B.dot(A * B);

stan/math/prim/fun/quad_form_diag.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
namespace stan {
88
namespace math {
99

10-
template <typename T1, typename T2, int R, int C>
11-
inline Eigen::Matrix<return_type_t<T1, T2>, Eigen::Dynamic, Eigen::Dynamic>
12-
quad_form_diag(const Eigen::Matrix<T1, Eigen::Dynamic, Eigen::Dynamic>& mat,
13-
const Eigen::Matrix<T2, R, C>& vec) {
14-
check_vector("quad_form_diag", "vec", vec);
10+
template <typename EigMat, typename EigVec, require_eigen_t<EigMat>* = nullptr,
11+
require_eigen_vector_t<EigVec>* = nullptr>
12+
inline Eigen::Matrix<return_type_t<EigMat, EigVec>, Eigen::Dynamic,
13+
Eigen::Dynamic>
14+
quad_form_diag(const EigMat& mat, const EigVec& vec) {
1515
check_square("quad_form_diag", "mat", mat);
1616
check_size_match("quad_form_diag", "rows of mat", mat.rows(), "size of vec",
1717
vec.size());

stan/math/prim/fun/quad_form_sym.hpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,44 +12,44 @@ namespace math {
1212
*
1313
* Symmetry of the resulting matrix is guaranteed.
1414
*
15-
* @tparam RA number of rows in the symmetric matrix, can be Eigen::Dynamic
16-
* @tparam CA number of columns in the symmetric matrix, can be Eigen::Dynamic
17-
* @tparam RB number of rows in the second matrix, can be Eigen::Dynamic
18-
* @tparam CB number of columns in the second matrix, can be Eigen::Dynamic
19-
* @tparam TA type of elements
15+
* @tparam EigMat1 type of the first (symmetric) matrix
16+
* @tparam EigMat2 type of the second matrix
2017
*
2118
* @param A symmetric matrix
2219
* @param B second matrix
23-
* @return The quadratic form, which is a symmetric matrix of size CB.
20+
* @return The quadratic form, which is a symmetric matrix.
2421
* @throws std::invalid_argument if A is not symmetric, or if A cannot be
2522
* multiplied by B
2623
*/
27-
template <int RA, int CA, int RB, int CB, typename T>
28-
inline Eigen::Matrix<T, CB, CB> quad_form_sym(
29-
const Eigen::Matrix<T, RA, CA>& A, const Eigen::Matrix<T, RB, CB>& B) {
24+
template <typename EigMat1, typename EigMat2,
25+
require_all_eigen_t<EigMat1, EigMat2>* = nullptr,
26+
require_not_eigen_col_vector_t<EigMat2>* = nullptr,
27+
require_vt_same<EigMat1, EigMat2>* = nullptr,
28+
require_all_vt_arithmetic<EigMat1, EigMat2>* = nullptr>
29+
inline plain_type_t<EigMat2> quad_form_sym(const EigMat1& A, const EigMat2& B) {
3030
check_multiplicable("quad_form_sym", "A", A, "B", B);
3131
check_symmetric("quad_form_sym", "A", A);
32-
Eigen::Matrix<T, CB, CB> ret(B.transpose() * A * B);
33-
return T(0.5) * (ret + ret.transpose());
32+
plain_type_t<EigMat2> ret(B.transpose() * A * B);
33+
return value_type_t<EigMat2>(0.5) * (ret + ret.transpose());
3434
}
3535

3636
/**
3737
* Return the quadratic form \f$ B^T A B \f$ of a symmetric matrix.
3838
*
39-
* @tparam RA number of rows in the symmetric matrix, can be Eigen::Dynamic
40-
* @tparam CA number of columns in the symmetric matrix, can be Eigen::Dynamic
41-
* @tparam RB number of rows in the vector, can be Eigen::Dynamic
42-
* @tparam T type of elements
39+
* @tparam EigMat type of the (symmetric) matrix
40+
* @tparam ColVec type of the vector
4341
*
4442
* @param A symmetric matrix
4543
* @param B vector
4644
* @return The quadratic form (a scalar).
4745
* @throws std::invalid_argument if A is not symmetric, or if A cannot be
4846
* multiplied by B
4947
*/
50-
template <int RA, int CA, int RB, typename T>
51-
inline T quad_form_sym(const Eigen::Matrix<T, RA, CA>& A,
52-
const Eigen::Matrix<T, RB, 1>& B) {
48+
template <typename EigMat, typename ColVec, require_eigen_t<EigMat>* = nullptr,
49+
require_eigen_col_vector_t<ColVec>* = nullptr,
50+
require_vt_same<EigMat, ColVec>* = nullptr,
51+
require_all_vt_arithmetic<EigMat, ColVec>* = nullptr>
52+
inline value_type_t<EigMat> quad_form_sym(const EigMat& A, const ColVec& B) {
5353
check_multiplicable("quad_form_sym", "A", A, "B", B);
5454
check_symmetric("quad_form_sym", "A", A);
5555
return B.dot(A * B);

stan/math/prim/fun/rank.hpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,13 @@ namespace math {
1616
* @return number of components of v less than v[s].
1717
* @throw std::out_of_range if s is out of range.
1818
*/
19-
template <typename C>
19+
template <typename C, require_container_t<C>* = nullptr>
2020
inline int rank(const C& v, int s) {
2121
check_range("rank", "v", v.size(), s);
2222
--s; // adjust for indexing by one
23-
int count = 0;
24-
for (index_type_t<C> i = 0; i < v.size(); ++i) {
25-
if (v[i] < v[s]) {
26-
++count;
27-
}
28-
}
29-
return count;
23+
return apply_vector_unary<C>::reduce(v, [s](const auto& vec) {
24+
return (vec.array() < vec.coeff(s)).template cast<int>().sum();
25+
});
3026
}
3127

3228
} // namespace math

0 commit comments

Comments
 (0)