Skip to content

Commit 3b2dbbd

Browse files
authored
Merge pull request #1848 from bstatcomp/generalize_fun_s_t
Generalize functions starting with s and t
2 parents 17a67ad + 72af837 commit 3b2dbbd

28 files changed

Lines changed: 325 additions & 363 deletions

stan/math/fwd/fun/softmax.hpp

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,36 +8,39 @@
88
namespace stan {
99
namespace math {
1010

11-
template <typename T>
12-
inline Eigen::Matrix<fvar<T>, Eigen::Dynamic, 1> softmax(
13-
const Eigen::Matrix<fvar<T>, Eigen::Dynamic, 1>& alpha) {
11+
template <typename ColVec,
12+
require_eigen_col_vector_vt<is_fvar, ColVec>* = nullptr>
13+
inline auto softmax(const ColVec& alpha) {
1414
using Eigen::Dynamic;
1515
using Eigen::Matrix;
16+
using T = typename value_type_t<ColVec>::Scalar;
1617

1718
Matrix<T, Dynamic, 1> alpha_t(alpha.size());
1819
for (int k = 0; k < alpha.size(); ++k) {
19-
alpha_t(k) = alpha(k).val_;
20+
alpha_t.coeffRef(k) = alpha.coeff(k).val_;
2021
}
2122

2223
Matrix<T, Dynamic, 1> softmax_alpha_t = softmax(alpha_t);
2324

2425
Matrix<fvar<T>, Dynamic, 1> softmax_alpha(alpha.size());
2526
for (int k = 0; k < alpha.size(); ++k) {
26-
softmax_alpha(k).val_ = softmax_alpha_t(k);
27-
softmax_alpha(k).d_ = 0;
27+
softmax_alpha.coeffRef(k).val_ = softmax_alpha_t.coeff(k);
28+
softmax_alpha.coeffRef(k).d_ = 0;
2829
}
2930

3031
for (int m = 0; m < alpha.size(); ++m) {
3132
T negative_alpha_m_d_times_softmax_alpha_t_m
32-
= -alpha(m).d_ * softmax_alpha_t(m);
33+
= -alpha.coeff(m).d_ * softmax_alpha_t.coeff(m);
3334
for (int k = 0; k < alpha.size(); ++k) {
3435
if (m == k) {
35-
softmax_alpha(k).d_
36-
+= softmax_alpha_t(k)
37-
* (alpha(m).d_ + negative_alpha_m_d_times_softmax_alpha_t_m);
36+
softmax_alpha.coeffRef(k).d_
37+
+= softmax_alpha_t.coeff(k)
38+
* (alpha.coeff(m).d_
39+
+ negative_alpha_m_d_times_softmax_alpha_t_m);
3840
} else {
39-
softmax_alpha(k).d_
40-
+= negative_alpha_m_d_times_softmax_alpha_t_m * softmax_alpha_t(k);
41+
softmax_alpha.coeffRef(k).d_
42+
+= softmax_alpha_t.coeff(k)
43+
* negative_alpha_m_d_times_softmax_alpha_t_m;
4144
}
4245
}
4346
}

stan/math/fwd/fun/trace_quad_form.hpp

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,29 +11,16 @@
1111
namespace stan {
1212
namespace math {
1313

14-
template <int RA, int CA, int RB, int CB, typename T>
15-
inline fvar<T> trace_quad_form(const Eigen::Matrix<fvar<T>, RA, CA> &A,
16-
const Eigen::Matrix<fvar<T>, RB, CB> &B) {
14+
template <typename EigMat1, typename EigMat2,
15+
require_all_eigen_t<EigMat1, EigMat2>* = nullptr,
16+
require_any_vt_fvar<EigMat1, EigMat2>* = nullptr>
17+
inline return_type_t<EigMat1, EigMat2> trace_quad_form(const EigMat1& A,
18+
const EigMat2& B) {
1719
check_square("trace_quad_form", "A", A);
1820
check_multiplicable("trace_quad_form", "A", A, "B", B);
19-
return trace(multiply(transpose(B), multiply(A, B)));
21+
return B.cwiseProduct(multiply(A, B)).sum();
2022
}
2123

22-
template <int RA, int CA, int RB, int CB, typename T>
23-
inline fvar<T> trace_quad_form(const Eigen::Matrix<fvar<T>, RA, CA> &A,
24-
const Eigen::Matrix<double, RB, CB> &B) {
25-
check_square("trace_quad_form", "A", A);
26-
check_multiplicable("trace_quad_form", "A", A, "B", B);
27-
return trace(multiply(transpose(B), multiply(A, B)));
28-
}
29-
30-
template <int RA, int CA, int RB, int CB, typename T>
31-
inline fvar<T> trace_quad_form(const Eigen::Matrix<double, RA, CA> &A,
32-
const Eigen::Matrix<fvar<T>, RB, CB> &B) {
33-
check_square("trace_quad_form", "A", A);
34-
check_multiplicable("trace_quad_form", "A", A, "B", B);
35-
return trace(multiply(transpose(B), multiply(A, B)));
36-
}
3724
} // namespace math
3825
} // namespace stan
3926

stan/math/prim/fun/scale_matrix_exp_multiply.hpp

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,19 @@ namespace math {
1616
*
1717
* Specialized for double values for efficiency.
1818
*
19-
* @tparam Cb number of columns in matrix B, can be Eigen::Dynamic
19+
* @tparam EigMat1 type of the first matrix
20+
* @tparam EigMat2 type of the second matrix
21+
*
2022
* @param[in] A Matrix
2123
* @param[in] B Matrix
2224
* @param[in] t double
23-
* @return exponential of At multiplies B
25+
* @return exponential of At multiplied by B
2426
*/
25-
template <int Cb>
26-
inline Eigen::Matrix<double, -1, Cb> scale_matrix_exp_multiply(
27-
const double& t, const Eigen::MatrixXd& A,
28-
const Eigen::Matrix<double, -1, Cb>& B) {
27+
template <typename EigMat1, typename EigMat2,
28+
require_all_eigen_t<EigMat1, EigMat2>* = nullptr,
29+
require_all_vt_same<double, EigMat1, EigMat2>* = nullptr>
30+
inline Eigen::Matrix<double, Eigen::Dynamic, EigMat2::ColsAtCompileTime>
31+
scale_matrix_exp_multiply(const double& t, const EigMat1& A, const EigMat2& B) {
2932
check_square("scale_matrix_exp_multiply", "input matrix", A);
3033
check_multiplicable("scale_matrix_exp_multiply", "A", A, "B", B);
3134
if (A.size() == 0) {
@@ -39,20 +42,22 @@ inline Eigen::Matrix<double, -1, Cb> scale_matrix_exp_multiply(
3942
* Return product of exp(At) and B, where A is a NxN matrix,
4043
* B is a NxCb matrix and t is a scalar.
4144
*
42-
* Generic implementation when arguments are not double.
45+
* Generic implementation when arguments are not all double.
4346
*
44-
* @tparam Ta scalar type matrix A
45-
* @tparam Tb scalar type matrix B
46-
* @tparam Cb number of columns in matrix B, can be Eigen::Dynamic
47+
* @tparam Tt type of \c t
48+
* @tparam EigMat1 type of the first matrix
49+
* @tparam EigMat2 type of the second matrix
4750
* @param[in] A Matrix
4851
* @param[in] B Matrix
4952
* @param[in] t double
50-
* @return exponential of At multiplies B
53+
* @return exponential of At multiplied by B
5154
*/
52-
template <typename Tt, typename Ta, typename Tb, int Cb>
53-
inline Eigen::Matrix<return_type_t<Tt, Ta, Tb>, -1, Cb>
54-
scale_matrix_exp_multiply(const Tt& t, const Eigen::Matrix<Ta, -1, -1>& A,
55-
const Eigen::Matrix<Tb, -1, Cb>& B) {
55+
template <typename Tt, typename EigMat1, typename EigMat2,
56+
require_all_eigen_t<EigMat1, EigMat2>* = nullptr,
57+
require_any_not_vt_same<double, Tt, EigMat1, EigMat2>* = nullptr>
58+
inline Eigen::Matrix<return_type_t<Tt, EigMat1, EigMat2>, Eigen::Dynamic,
59+
EigMat2::ColsAtCompileTime>
60+
scale_matrix_exp_multiply(const Tt& t, const EigMat1& A, const EigMat2& B) {
5661
check_square("scale_matrix_exp_multiply", "input matrix", A);
5762
check_multiplicable("scale_matrix_exp_multiply", "A", A, "B", B);
5863
if (A.size() == 0) {

stan/math/prim/fun/sd.hpp

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,34 +13,17 @@ namespace math {
1313

1414
/**
1515
* Returns the unbiased sample standard deviation of the
16-
* coefficients in the specified column vector.
16+
* coefficients in the specified std vector, column vector, row vector, or
17+
* matrix.
1718
*
18-
* @tparam T type of elements in the vector
19-
* @param v Specified vector.
20-
* @return Sample variance of vector.
21-
*/
22-
template <typename T>
23-
inline return_type_t<T> sd(const std::vector<T>& v) {
24-
check_nonzero_size("sd", "v", v);
25-
if (v.size() == 1) {
26-
return 0.0;
27-
}
28-
return sqrt(variance(v));
29-
}
30-
31-
/**
32-
* Returns the unbiased sample standard deviation of the
33-
* coefficients in the specified vector, row vector, or matrix.
34-
*
35-
* @tparam T type of elements in the vector, row vector, or matrix
36-
* @tparam R number of rows, can be Eigen::Dynamic
37-
* @tparam C number of columns, can be Eigen::Dynamic
19+
* @tparam T type of the container
3820
*
39-
* @param m Specified vector, row vector or matrix.
21+
* @param m Specified container.
4022
* @return Sample variance.
4123
*/
42-
template <typename T, int R, int C>
43-
inline return_type_t<T> sd(const Eigen::Matrix<T, R, C>& m) {
24+
template <typename T, require_container_t<T>* = nullptr,
25+
require_not_vt_var<T>* = nullptr>
26+
inline return_type_t<T> sd(const T& m) {
4427
using std::sqrt;
4528
check_nonzero_size("sd", "m", m);
4629
if (m.size() == 1) {

stan/math/prim/fun/segment.hpp

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,38 +9,23 @@ namespace stan {
99
namespace math {
1010

1111
/**
12-
* Return the specified number of elements as a vector starting
13-
* from the specified element - 1 of the specified vector.
12+
* Return the specified number of elements as a row/column vector starting
13+
* from the specified element - 1 of the specified row/column vector.
1414
*
15-
* @tparam T type of elements in the vector
15+
* @tparam T type of the vector
1616
*/
17-
template <typename T>
18-
inline Eigen::Matrix<T, Eigen::Dynamic, 1> segment(
19-
const Eigen::Matrix<T, Eigen::Dynamic, 1>& v, size_t i, size_t n) {
17+
template <typename EigVec, require_eigen_vector_t<EigVec>* = nullptr>
18+
inline plain_type_t<EigVec> segment(const EigVec& v, size_t i, size_t n) {
2019
check_greater("segment", "n", i, 0.0);
21-
check_less_or_equal("segment", "n", i, static_cast<size_t>(v.rows()));
20+
check_less_or_equal("segment", "n", i, static_cast<size_t>(v.size()));
2221
if (n != 0) {
2322
check_greater("segment", "n", i + n - 1, 0.0);
2423
check_less_or_equal("segment", "n", i + n - 1,
25-
static_cast<size_t>(v.rows()));
24+
static_cast<size_t>(v.size()));
2625
}
2726
return v.segment(i - 1, n);
2827
}
2928

30-
template <typename T>
31-
inline Eigen::Matrix<T, 1, Eigen::Dynamic> segment(
32-
const Eigen::Matrix<T, 1, Eigen::Dynamic>& v, size_t i, size_t n) {
33-
check_greater("segment", "n", i, 0.0);
34-
check_less_or_equal("segment", "n", i, static_cast<size_t>(v.cols()));
35-
if (n != 0) {
36-
check_greater("segment", "n", i + n - 1, 0.0);
37-
check_less_or_equal("segment", "n", i + n - 1,
38-
static_cast<size_t>(v.cols()));
39-
}
40-
41-
return v.segment(i - 1, n);
42-
}
43-
4429
template <typename T>
4530
std::vector<T> segment(const std::vector<T>& sv, size_t i, size_t n) {
4631
check_greater("segment", "i", i, 0.0);

stan/math/prim/fun/simplex_constrain.hpp

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,28 +20,27 @@ namespace math {
2020
*
2121
* The transform is based on a centered stick-breaking process.
2222
*
23-
* @tparam T type of elements in the vector
23+
* @tparam ColVec type of the vector
2424
* @param y Free vector input of dimensionality K - 1.
2525
* @return Simplex of dimensionality K.
2626
*/
27-
template <typename T>
28-
Eigen::Matrix<T, Eigen::Dynamic, 1> simplex_constrain(
29-
const Eigen::Matrix<T, Eigen::Dynamic, 1>& y) {
27+
template <typename ColVec, require_eigen_col_vector_t<ColVec>* = nullptr>
28+
auto simplex_constrain(const ColVec& y) {
3029
// cut & paste simplex_constrain(Eigen::Matrix, T) w/o Jacobian
3130
using Eigen::Dynamic;
3231
using Eigen::Matrix;
3332
using std::log;
34-
using size_type = index_type_t<Matrix<T, Dynamic, 1>>;
33+
using T = value_type_t<ColVec>;
3534

3635
int Km1 = y.size();
3736
Matrix<T, Dynamic, 1> x(Km1 + 1);
3837
T stick_len(1.0);
39-
for (size_type k = 0; k < Km1; ++k) {
40-
T z_k(inv_logit(y(k) - log(Km1 - k)));
41-
x(k) = stick_len * z_k;
42-
stick_len -= x(k);
38+
for (Eigen::Index k = 0; k < Km1; ++k) {
39+
T z_k = inv_logit(y.coeff(k) - log(Km1 - k));
40+
x.coeffRef(k) = stick_len * z_k;
41+
stick_len -= x.coeff(k);
4342
}
44-
x(Km1) = stick_len;
43+
x.coeffRef(Km1) = stick_len;
4544
return x;
4645
}
4746

@@ -53,34 +52,32 @@ Eigen::Matrix<T, Eigen::Dynamic, 1> simplex_constrain(
5352
* The simplex transform is defined through a centered
5453
* stick-breaking process.
5554
*
56-
* @tparam T type of elements in the vector
55+
* @tparam ColVec type of the vector
5756
* @param y Free vector input of dimensionality K - 1.
5857
* @param lp Log probability reference to increment.
5958
* @return Simplex of dimensionality K.
6059
*/
61-
template <typename T>
62-
Eigen::Matrix<T, Eigen::Dynamic, 1> simplex_constrain(
63-
const Eigen::Matrix<T, Eigen::Dynamic, 1>& y, T& lp) {
60+
template <typename ColVec, require_eigen_col_vector_t<ColVec>* = nullptr>
61+
auto simplex_constrain(const ColVec& y, value_type_t<ColVec>& lp) {
6462
using Eigen::Dynamic;
6563
using Eigen::Matrix;
6664
using std::log;
67-
68-
using size_type = index_type_t<Matrix<T, Dynamic, 1>>;
65+
using T = value_type_t<ColVec>;
6966

7067
int Km1 = y.size(); // K = Km1 + 1
7168
Matrix<T, Dynamic, 1> x(Km1 + 1);
7269
T stick_len(1.0);
73-
for (size_type k = 0; k < Km1; ++k) {
70+
for (Eigen::Index k = 0; k < Km1; ++k) {
7471
double eq_share = -log(Km1 - k); // = logit(1.0/(Km1 + 1 - k));
75-
T adj_y_k(y(k) + eq_share);
76-
T z_k(inv_logit(adj_y_k));
77-
x(k) = stick_len * z_k;
72+
T adj_y_k = y.coeff(k) + eq_share;
73+
T z_k = inv_logit(adj_y_k);
74+
x.coeffRef(k) = stick_len * z_k;
7875
lp += log(stick_len);
7976
lp -= log1p_exp(-adj_y_k);
8077
lp -= log1p_exp(adj_y_k);
81-
stick_len -= x(k); // equivalently *= (1 - z_k);
78+
stick_len -= x.coeff(k); // equivalently *= (1 - z_k);
8279
}
83-
x(Km1) = stick_len; // no Jacobian contrib for last dim
80+
x.coeffRef(Km1) = stick_len; // no Jacobian contrib for last dim
8481
return x;
8582
}
8683

stan/math/prim/fun/simplex_free.hpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,26 +19,25 @@ namespace math {
1919
* <p>The simplex transform is defined through a centered
2020
* stick-breaking process.
2121
*
22-
* @tparam T type of elements in the simplex
22+
* @tparam ColVec type of the simplex (must be a column vector)
2323
* @param x Simplex of dimensionality K.
2424
* @return Free vector of dimensionality (K-1) that transforms to
2525
* the simplex.
2626
* @throw std::domain_error if x is not a valid simplex
2727
*/
28-
template <typename T>
29-
Eigen::Matrix<T, Eigen::Dynamic, 1> simplex_free(
30-
const Eigen::Matrix<T, Eigen::Dynamic, 1>& x) {
28+
template <typename ColVec, require_eigen_col_vector_t<ColVec>* = nullptr>
29+
auto simplex_free(const ColVec& x) {
3130
using std::log;
32-
using size_type = index_type_t<Eigen::Matrix<T, Eigen::Dynamic, 1>>;
31+
using T = value_type_t<ColVec>;
3332

3433
check_simplex("stan::math::simplex_free", "Simplex variable", x);
3534
int Km1 = x.size() - 1;
3635
Eigen::Matrix<T, Eigen::Dynamic, 1> y(Km1);
37-
T stick_len(x(Km1));
38-
for (size_type k = Km1; --k >= 0;) {
39-
stick_len += x(k);
40-
T z_k(x(k) / stick_len);
41-
y(k) = logit(z_k) + log(Km1 - k);
36+
T stick_len = x.coeff(Km1);
37+
for (Eigen::Index k = Km1; --k >= 0;) {
38+
stick_len += x.coeff(k);
39+
T z_k = x.coeff(k) / stick_len;
40+
y.coeffRef(k) = logit(z_k) + log(Km1 - k);
4241
// note: log(Km1 - k) = logit(1.0 / (Km1 + 1 - k));
4342
}
4443
return y;

stan/math/prim/fun/singular_values.hpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef STAN_MATH_PRIM_FUN_SINGULAR_VALUES_HPP
22
#define STAN_MATH_PRIM_FUN_SINGULAR_VALUES_HPP
33

4+
#include <stan/math/prim/meta.hpp>
45
#include <stan/math/prim/fun/Eigen.hpp>
56

67
namespace stan {
@@ -12,18 +13,19 @@ namespace math {
1213
* <p>See the documentation for <code>svd()</code> for
1314
* information on the singular values.
1415
*
15-
* @tparam T type of elements in the matrix
16+
* @tparam EigMat type of the matrix
1617
* @param m Specified matrix.
1718
* @return Singular values of the matrix.
1819
*/
19-
template <typename T>
20-
Eigen::Matrix<T, Eigen::Dynamic, 1> singular_values(
21-
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>& m) {
20+
template <typename EigMat, require_eigen_matrix_t<EigMat>* = nullptr>
21+
Eigen::Matrix<value_type_t<EigMat>, Eigen::Dynamic, 1> singular_values(
22+
const EigMat& m) {
2223
if (m.size() == 0) {
2324
return {};
2425
}
2526

26-
return Eigen::JacobiSVD<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic> >(m)
27+
return Eigen::JacobiSVD<Eigen::Matrix<value_type_t<EigMat>, Eigen::Dynamic,
28+
Eigen::Dynamic> >(m)
2729
.singularValues();
2830
}
2931

0 commit comments

Comments
 (0)