Skip to content

Commit f1905f6

Browse files
authored
Merge pull request #2462 from stan-dev/feature/csr-sparse-data-varmat-vector
csr_matrix_time_vector data * var specialization
2 parents eae2d6d + 2a39e03 commit f1905f6

4 files changed

Lines changed: 104 additions & 1 deletion

File tree

stan/math/prim/fun/csr_matrix_times_vector.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ namespace math {
6969
* for a given sparse matrix.
7070
* @throw std::out_of_range if any of the indexes are out of range.
7171
*/
72-
template <typename T1, typename T2>
72+
template <typename T1, typename T2,
73+
require_not_t<conjunction<std::is_arithmetic<scalar_type_t<T1>>,
74+
is_var<scalar_type_t<T2>>>>* = nullptr>
7375
inline Eigen::Matrix<return_type_t<T1, T2>, Eigen::Dynamic, 1>
7476
csr_matrix_times_vector(int m, int n, const T1& w, const std::vector<int>& v,
7577
const std::vector<int>& u, const T2& b) {

stan/math/rev/fun.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
#include <stan/math/rev/fun/cov_matrix_constrain.hpp>
4242
#include <stan/math/rev/fun/cov_exp_quad.hpp>
4343
#include <stan/math/rev/fun/cov_matrix_constrain_lkj.hpp>
44+
#include <stan/math/rev/fun/csr_matrix_times_vector.hpp>
4445
#include <stan/math/rev/fun/determinant.hpp>
4546
#include <stan/math/rev/fun/diag_pre_multiply.hpp>
4647
#include <stan/math/rev/fun/diag_post_multiply.hpp>
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
#ifndef STAN_MATH_REV_FUN_CSR_MATRIX_TIMES_VECTOR_HPP
2+
#define STAN_MATH_REV_FUN_CSR_MATRIX_TIMES_VECTOR_HPP
3+
4+
#include <stan/math/prim/fun/Eigen.hpp>
5+
#include <stan/math/rev/core.hpp>
6+
#include <stan/math/prim/err.hpp>
7+
#include <stan/math/prim/fun/csr_u_to_z.hpp>
8+
#include <vector>
9+
10+
namespace stan {
11+
namespace math {
12+
13+
/**
14+
* \addtogroup csr_format
15+
* Return the multiplication of the sparse matrix (specified by
16+
* by values and indexing) by the specified dense vector.
17+
*
18+
* The sparse matrix X of dimension m by n is represented by the
19+
* vector w (of values), the integer array v (containing one-based
20+
* column index of each value), the integer array u (containing
21+
* one-based indexes of where each row starts in w).
22+
*
23+
* @tparam T1 type of the sparse matrix
24+
* @tparam T2 type of the dense vector
25+
* @param m Number of rows in matrix.
26+
* @param n Number of columns in matrix.
27+
* @param w Vector of non-zero values in matrix.
28+
* @param v Column index of each non-zero value, same
29+
* length as w.
30+
* @param u Index of where each row starts in w, length equal to
31+
* the number of rows plus one.
32+
* @param b Eigen vector which the matrix is multiplied by.
33+
* @return Dense vector for the product.
34+
* @throw std::domain_error if m and n are not positive or are nan.
35+
* @throw std::domain_error if the implied sparse matrix and b are
36+
* not multiplicable.
37+
* @throw std::invalid_argument if m/n/w/v/u are not internally
38+
* consistent, as defined by the indexing scheme. Extractors are
39+
* defined in Stan which guarantee a consistent set of m/n/w/v/u
40+
* for a given sparse matrix.
41+
* @throw std::out_of_range if any of the indexes are out of range.
42+
*/
43+
template <typename T1, typename T2, require_st_arithmetic<T1>* = nullptr,
44+
require_st_var<T2>* = nullptr>
45+
inline auto csr_matrix_times_vector(int m, int n, const T1& w,
46+
const std::vector<int>& v,
47+
const std::vector<int>& u, const T2& b) {
48+
check_positive("csr_matrix_times_vector", "m", m);
49+
check_positive("csr_matrix_times_vector", "n", n);
50+
check_size_match("csr_matrix_times_vector", "n", n, "b", b.size());
51+
check_size_match("csr_matrix_times_vector", "m", m, "u", u.size() - 1);
52+
check_size_match("csr_matrix_times_vector", "w", w.size(), "v", v.size());
53+
check_size_match("csr_matrix_times_vector", "u/z",
54+
u[m - 1] + csr_u_to_z(u, m - 1) - 1, "v", v.size());
55+
for (int i : v) {
56+
check_range("csr_matrix_times_vector", "v[]", n, i);
57+
}
58+
std::vector<int, arena_allocator<int>> arena_v(v.begin(), v.end());
59+
std::vector<int, arena_allocator<int>> arena_u(u.begin(), u.end());
60+
auto arena_w = to_arena(w);
61+
auto arena_b = to_arena(b);
62+
Eigen::Map<Eigen::SparseMatrix<scalar_type_t<T1>>> arena_sp_map(
63+
m, n, arena_b.size(), arena_v.data(), arena_u.data(), arena_w.data());
64+
using sparse_dense_mul_type
65+
= decltype((arena_sp_map * value_of(arena_b)).eval());
66+
using return_t = return_var_matrix_t<sparse_dense_mul_type, T1, T2>;
67+
arena_t<return_t> result = arena_sp_map * arena_b.val();
68+
reverse_pass_callback([arena_v, arena_u, arena_w, arena_b, result, m,
69+
n]() mutable {
70+
Eigen::Map<Eigen::SparseMatrix<scalar_type_t<T1>>> arena_sp_map(
71+
m, n, arena_b.size(), arena_v.data(), arena_u.data(), arena_w.data());
72+
arena_b.adj() += arena_sp_map.transpose() * result.adj();
73+
});
74+
75+
return return_t(result);
76+
}
77+
78+
} // namespace math
79+
} // namespace stan
80+
81+
#endif
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#include <test/unit/math/test_ad.hpp>
2+
#include <vector>
3+
4+
TEST(MathMixMatFun, csr_matrix_times_vector) {
5+
auto f = [](const auto& w, const auto& b) {
6+
using stan::math::csr_matrix_times_vector;
7+
std::vector<int> v{1, 2, 0, 2, 4, 2, 1, 4};
8+
std::vector<int> u{0, 2, 4, 5, 6, 8};
9+
return csr_matrix_times_vector(5, 5, w, v, u, b);
10+
};
11+
12+
Eigen::VectorXd w(8);
13+
w << 22, 7, 3, 5, 14, 1, 17, 8;
14+
15+
Eigen::VectorXd b(8);
16+
b << 1, 2, 3, 4, 5, 6, 7, 8;
17+
18+
stan::test::expect_ad(f, w, b);
19+
}

0 commit comments

Comments
 (0)