@@ -41,9 +41,10 @@ namespace math {
4141 * @throw std::out_of_range if any of the indexes are out of range.
4242 */
4343template <typename T1, typename T2, require_st_arithmetic<T1>* = nullptr ,
44- require_st_var<T2>* = nullptr >
45- inline auto csr_matrix_times_vector (int m, int n, const T1& w, const std::vector<int >& v,
46- const std::vector<int >& u, const T2& b) {
44+ require_st_var<T2>* = nullptr >
45+ inline auto csr_matrix_times_vector (int m, int n, const T1& w,
46+ const std::vector<int >& v,
47+ const std::vector<int >& u, const T2& b) {
4748 check_positive (" csr_matrix_times_vector" , " m" , m);
4849 check_positive (" csr_matrix_times_vector" , " n" , n);
4950 check_size_match (" csr_matrix_times_vector" , " n" , n, " b" , b.size ());
@@ -58,12 +59,16 @@ inline auto csr_matrix_times_vector(int m, int n, const T1& w, const std::vector
5859 std::vector<int , arena_allocator<int >> arena_u (u.begin (), u.end ());
5960 auto arena_w = to_arena (w);
6061 auto arena_b = to_arena (b);
61- Eigen::Map<Eigen::SparseMatrix<scalar_type_t <T1>>> arena_sp_map (m, n, arena_b.size (), arena_v.data (), arena_u.data (), arena_w.data ());
62- using sparse_dense_mul_type = decltype ((arena_sp_map * value_of (arena_b)).eval ());
62+ Eigen::Map<Eigen::SparseMatrix<scalar_type_t <T1>>> arena_sp_map (
63+ m, n, arena_b.size (), arena_v.data (), arena_u.data (), arena_w.data ());
64+ using sparse_dense_mul_type
65+ = decltype ((arena_sp_map * value_of (arena_b)).eval ());
6366 using return_t = return_var_matrix_t <sparse_dense_mul_type, T1, T2>;
6467 arena_t <return_t > result = arena_sp_map * arena_b.val ();
65- reverse_pass_callback ([arena_v, arena_u, arena_w, arena_b, result, m, n]() mutable {
66- Eigen::Map<Eigen::SparseMatrix<scalar_type_t <T1>>> arena_sp_map (m, n, arena_b.size (), arena_v.data (), arena_u.data (), arena_w.data ());
68+ reverse_pass_callback ([arena_v, arena_u, arena_w, arena_b, result, m,
69+ n]() mutable {
70+ Eigen::Map<Eigen::SparseMatrix<scalar_type_t <T1>>> arena_sp_map (
71+ m, n, arena_b.size (), arena_v.data (), arena_u.data (), arena_w.data ());
6772 arena_b.adj () += arena_sp_map.transpose () * result.adj ();
6873 });
6974
0 commit comments