Skip to content

Commit 96f7dbc

Browse files
committed
addressed review comments
1 parent 1475ff7 commit 96f7dbc

8 files changed

Lines changed: 39 additions & 21 deletions

File tree

stan/math/prim/fun/simplex_constrain.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ namespace math {
2020
*
2121
* The transform is based on a centered stick-breaking process.
2222
*
23-
* @tparam T type of the vector
23+
* @tparam ColVec type of the vector
2424
* @param y Free vector input of dimensionality K - 1.
2525
* @return Simplex of dimensionality K.
2626
*/

stan/math/prim/fun/simplex_free.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ namespace math {
1919
* <p>The simplex transform is defined through a centered
2020
* stick-breaking process.
2121
*
22-
* @tparam T type of elements in the simplex
22+
* @tparam ColVec type of the simplex (must be a column vector)
2323
* @param x Simplex of dimensionality K.
2424
* @return Free vector of dimensionality (K-1) that transforms to
2525
* the simplex.

stan/math/prim/fun/singular_values.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ namespace math {
1313
* <p>See the documentation for <code>svd()</code> for
1414
* information on the singular values.
1515
*
16-
* @tparam T type of elements in the matrix
16+
* @tparam EigMat type of the matrix
1717
* @param m Specified matrix.
1818
* @return Singular values of the matrix.
1919
*/

stan/math/prim/fun/size_mvt.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@ namespace math {
2020
* only match scalars.
2121
* @throw std::invalid_argument since the type is a scalar.
2222
*/
23-
template <typename T, require_stan_scalar_t<T>* = nullptr>
24-
size_t size_mvt(const T& /* unused */) {
23+
template <typename ScalarT, require_stan_scalar_t<ScalarT>* = nullptr>
24+
size_t size_mvt(const ScalarT& /* unused */) {
2525
throw std::invalid_argument("size_mvt passed to an unrecognized type.");
2626
return 1U;
2727
}
2828

29-
template <typename T, require_eigen_t<T>* = nullptr>
30-
size_t size_mvt(const T& /* unused */) {
29+
template <typename EigenT, require_eigen_t<EigenT>* = nullptr>
30+
size_t size_mvt(const EigenT& /* unused */) {
3131
return 1U;
3232
}
3333

stan/math/prim/fun/softmax.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ namespace math {
3737
* \end{array}
3838
* \f$
3939
*
40-
* @tparam T type of elements in the vector
40+
* @tparam ColVec type of elements in the vector
4141
* @param[in] v Vector to transform.
4242
* @return Unit simplex result of the softmax transform of the vector.
4343
*/

stan/math/prim/fun/to_array_1d.hpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,11 @@ namespace math {
1414
template <typename EigMat, require_eigen_t<EigMat>* = nullptr>
1515
inline std::vector<value_type_t<EigMat>> to_array_1d(const EigMat& matrix) {
1616
using T_val = value_type_t<EigMat>;
17-
const Eigen::Ref<const Eigen::Matrix<T_val, EigMat::RowsAtCompileTime,
18-
EigMat::ColsAtCompileTime>>& mat_ref
17+
std::vector<T_val> result(matrix.size());
18+
Eigen::Map<Eigen::Matrix<T_val, EigMat::RowsAtCompileTime,
19+
EigMat::ColsAtCompileTime>>(result.data(), matrix.rows(),
20+
matrix.cols())
1921
= matrix;
20-
int matrix_size = matrix.size();
21-
std::vector<T_val> result(matrix_size);
22-
for (int i = 0; i < matrix_size; i++) {
23-
result[i] = mat_ref.data()[i];
24-
}
2522
return result;
2623
}
2724

stan/math/rev/fun/sd.hpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,10 @@ var sd(const T& m) {
3535
using T_map = std::decay_t<decltype(dtrs_map)>;
3636
using T_vi = promote_scalar_t<vari*, T_map>;
3737
using T_d = promote_scalar_t<double, T_map>;
38-
vari** varis
39-
= reinterpret_cast<vari**>(ChainableStack::instance_->memalloc_.alloc(
40-
dtrs_map.size() * sizeof(vari*)));
41-
double* partials
42-
= reinterpret_cast<double*>(ChainableStack::instance_->memalloc_.alloc(
43-
dtrs_map.size() * sizeof(double)));
38+
vari** varis = ChainableStack::instance_->memalloc_.alloc_array<vari*>(
39+
dtrs_map.size());
40+
double* partials = ChainableStack::instance_->memalloc_.alloc_array<double>(
41+
dtrs_map.size());
4442
Eigen::Map<T_vi> varis_map(varis, dtrs_map.rows(), dtrs_map.cols());
4543
Eigen::Map<T_d> partials_map(partials, dtrs_map.rows(), dtrs_map.cols());
4644

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#include <stan/math/prim.hpp>
2+
#include <test/unit/math/prim/fun/expect_matrix_eq.hpp>
3+
#include <gtest/gtest.h>
4+
#include <vector>
5+
#include <stdexcept>
6+
7+
using stan::math::to_array_1d;
8+
9+
TEST(MathMatrix, to_array_1d_matrix){
10+
Eigen::MatrixXd a(3,3);
11+
a << 1,2,3,4,5,6,7,8,9;
12+
std::vector<double> a_correct{1,4,7,2,5,8,3,6,9};
13+
std::vector<double> a_res = to_array_1d(a);
14+
expect_std_vector_eq(a_res, a_correct);
15+
}
16+
17+
TEST(MathMatrix, to_array_1d_matrix_block){
18+
Eigen::MatrixXd a(3,3);
19+
a << 1,2,3,4,5,6,7,8,9;
20+
std::vector<double> a_correct{2,5,3,6};
21+
std::vector<double> a_res = to_array_1d(a.block(0,1,2,2));
22+
expect_std_vector_eq(a_res, a_correct);
23+
}

0 commit comments

Comments
 (0)