|
| 1 | +#ifndef STAN_MATH_REV_FUN_CSR_MATRIX_TIMES_VECTOR_HPP |
| 2 | +#define STAN_MATH_REV_FUN_CSR_MATRIX_TIMES_VECTOR_HPP |
| 3 | + |
| 4 | +#include <stan/math/prim/fun/Eigen.hpp> |
| 5 | +#include <stan/math/rev/core.hpp> |
| 6 | +#include <stan/math/prim/err.hpp> |
| 7 | +#include <stan/math/prim/fun/csr_u_to_z.hpp> |
| 8 | +#include <vector> |
| 9 | + |
| 10 | +namespace stan { |
| 11 | +namespace math { |
| 12 | + |
| 13 | +/** |
| 14 | + * \addtogroup csr_format |
| 15 | + * Return the multiplication of the sparse matrix (specified by |
| 16 | + * by values and indexing) by the specified dense vector. |
| 17 | + * |
| 18 | + * The sparse matrix X of dimension m by n is represented by the |
| 19 | + * vector w (of values), the integer array v (containing one-based |
| 20 | + * column index of each value), the integer array u (containing |
| 21 | + * one-based indexes of where each row starts in w). |
| 22 | + * |
| 23 | + * @tparam T1 type of the sparse matrix |
| 24 | + * @tparam T2 type of the dense vector |
| 25 | + * @param m Number of rows in matrix. |
| 26 | + * @param n Number of columns in matrix. |
| 27 | + * @param w Vector of non-zero values in matrix. |
| 28 | + * @param v Column index of each non-zero value, same |
| 29 | + * length as w. |
| 30 | + * @param u Index of where each row starts in w, length equal to |
| 31 | + * the number of rows plus one. |
| 32 | + * @param b Eigen vector which the matrix is multiplied by. |
| 33 | + * @return Dense vector for the product. |
| 34 | + * @throw std::domain_error if m and n are not positive or are nan. |
| 35 | + * @throw std::domain_error if the implied sparse matrix and b are |
| 36 | + * not multiplicable. |
| 37 | + * @throw std::invalid_argument if m/n/w/v/u are not internally |
| 38 | + * consistent, as defined by the indexing scheme. Extractors are |
| 39 | + * defined in Stan which guarantee a consistent set of m/n/w/v/u |
| 40 | + * for a given sparse matrix. |
| 41 | + * @throw std::out_of_range if any of the indexes are out of range. |
| 42 | + */ |
| 43 | +template <typename T1, typename T2, require_st_arithmetic<T1>* = nullptr, |
| 44 | + require_st_var<T2>* = nullptr> |
| 45 | +inline auto csr_matrix_times_vector(int m, int n, const T1& w, |
| 46 | + const std::vector<int>& v, |
| 47 | + const std::vector<int>& u, const T2& b) { |
| 48 | + check_positive("csr_matrix_times_vector", "m", m); |
| 49 | + check_positive("csr_matrix_times_vector", "n", n); |
| 50 | + check_size_match("csr_matrix_times_vector", "n", n, "b", b.size()); |
| 51 | + check_size_match("csr_matrix_times_vector", "m", m, "u", u.size() - 1); |
| 52 | + check_size_match("csr_matrix_times_vector", "w", w.size(), "v", v.size()); |
| 53 | + check_size_match("csr_matrix_times_vector", "u/z", |
| 54 | + u[m - 1] + csr_u_to_z(u, m - 1) - 1, "v", v.size()); |
| 55 | + for (int i : v) { |
| 56 | + check_range("csr_matrix_times_vector", "v[]", n, i); |
| 57 | + } |
| 58 | + std::vector<int, arena_allocator<int>> arena_v(v.begin(), v.end()); |
| 59 | + std::vector<int, arena_allocator<int>> arena_u(u.begin(), u.end()); |
| 60 | + auto arena_w = to_arena(w); |
| 61 | + auto arena_b = to_arena(b); |
| 62 | + Eigen::Map<Eigen::SparseMatrix<scalar_type_t<T1>>> arena_sp_map( |
| 63 | + m, n, arena_b.size(), arena_v.data(), arena_u.data(), arena_w.data()); |
| 64 | + using sparse_dense_mul_type |
| 65 | + = decltype((arena_sp_map * value_of(arena_b)).eval()); |
| 66 | + using return_t = return_var_matrix_t<sparse_dense_mul_type, T1, T2>; |
| 67 | + arena_t<return_t> result = arena_sp_map * arena_b.val(); |
| 68 | + reverse_pass_callback([arena_v, arena_u, arena_w, arena_b, result, m, |
| 69 | + n]() mutable { |
| 70 | + Eigen::Map<Eigen::SparseMatrix<scalar_type_t<T1>>> arena_sp_map( |
| 71 | + m, n, arena_b.size(), arena_v.data(), arena_u.data(), arena_w.data()); |
| 72 | + arena_b.adj() += arena_sp_map.transpose() * result.adj(); |
| 73 | + }); |
| 74 | + |
| 75 | + return return_t(result); |
| 76 | +} |
| 77 | + |
| 78 | +} // namespace math |
| 79 | +} // namespace stan |
| 80 | + |
| 81 | +#endif |
0 commit comments