@@ -13,49 +13,46 @@ namespace math {
1313 *
1414 * Symmetry of the resulting matrix is guaranteed.
1515 *
16- * @tparam TA type of elements in the symmetric matrix
17- * @tparam RA number of rows in the symmetric matrix, can be Eigen::Dynamic
18- * @tparam CA number of columns in the symmetric matrix, can be Eigen::Dynamic
19- * @tparam TB type of elements in the second matrix
20- * @tparam RB number of rows in the second matrix, can be Eigen::Dynamic
21- * @tparam CB number of columns in the second matrix, can be Eigen::Dynamic
16+ * @tparam EigMat1 type of the first (symmetric) matrix
17+ * @tparam EigMat2 type of the second matrix
2218 *
2319 * @param A symmetric matrix
2420 * @param B second matrix
2521 * @return The quadratic form, which is a symmetric matrix of size CB.
2622 * @throws std::invalid_argument if A is not symmetric, or if A cannot be
2723 * multiplied by B
2824 */
29- template <typename TA, int RA, int CA, typename TB, int RB, int CB,
30- require_any_fvar_t <TA, TB>...>
31- inline Eigen::Matrix<return_type_t <TA, TB>, CB, CB> quad_form_sym (
32- const Eigen::Matrix<TA, RA, CA>& A, const Eigen::Matrix<TB, RB, CB>& B) {
33- using T = return_type_t <TA, TB>;
25+ template <typename EigMat1, typename EigMat2,
26+ require_all_eigen_t <EigMat1, EigMat2>* = nullptr ,
27+ require_not_eigen_col_vector_t <EigMat2>* = nullptr ,
28+ require_any_vt_fvar<EigMat1, EigMat2>* = nullptr >
29+ inline promote_scalar_t <return_type_t <EigMat1, EigMat2>, EigMat2> quad_form_sym (
30+ const EigMat1& A, const EigMat2& B) {
31+ using T_ret = return_type_t <EigMat1, EigMat2>;
3432 check_multiplicable (" quad_form_sym" , " A" , A, " B" , B);
3533 check_symmetric (" quad_form_sym" , " A" , A);
36- Eigen::Matrix<T, CB, CB> ret (multiply (transpose (B), multiply (A, B)));
37- return T (0.5 ) * (ret + transpose (ret));
34+ promote_scalar_t <T_ret, EigMat2> ret (
35+ multiply (B.transpose (), multiply (A, B)));
36+ return T_ret (0.5 ) * (ret + ret.transpose ());
3837}
3938
4039/* *
4140 * Return the quadratic form \f$ B^T A B \f$ of a symmetric matrix.
4241 *
43- * @tparam TA type of elements in the symmetric matrix
44- * @tparam RA number of rows in the symmetric matrix, can be Eigen::Dynamic
45- * @tparam CA number of columns in the symmetric matrix, can be Eigen::Dynamic
46- * @tparam TB type of elements in the vector
47- * @tparam RB number of rows in the vector, can be Eigen::Dynamic
42+ * @tparam EigMat type of the (symmetric) matrix
43+ * @tparam ColVec type of the vector
4844 *
4945 * @param A symmetric matrix
5046 * @param B vector
5147 * @return The quadratic form (a scalar).
5248 * @throws std::invalid_argument if A is not symmetric, or if A cannot be
5349 * multiplied by B
5450 */
55- template <typename TA, int RA, int CA, typename TB, int RB,
56- require_any_fvar_t <TA, TB>...>
57- inline return_type_t <TA, TB> quad_form_sym (const Eigen::Matrix<TA, RA, CA>& A,
58- const Eigen::Matrix<TB, RB, 1 >& B) {
51+ template <typename EigMat, typename ColVec, require_eigen_t <EigMat>* = nullptr ,
52+ require_eigen_col_vector_t <ColVec>* = nullptr ,
53+ require_any_vt_fvar<EigMat, ColVec>* = nullptr >
54+ inline return_type_t <EigMat, ColVec> quad_form_sym (const EigMat& A,
55+ const ColVec& B) {
5956 check_multiplicable (" quad_form_sym" , " A" , A, " B" , B);
6057 check_symmetric (" quad_form_sym" , " A" , A);
6158 return dot_product (B, multiply (A, B));
0 commit comments