Skip to content

Commit 7fb33f7

Browse files
authored
Merge pull request #1843 from stan-dev/feature/rearrange_reduce_sum_args
Reorder reduce sum args
2 parents 8ef0522 + b2ebd1a commit 7fb33f7

7 files changed

Lines changed: 26 additions & 26 deletions

File tree

stan/math/fwd/functor/reduce_sum.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ struct reduce_sum_impl {
3838
* This specialization is not parallelized and works for any autodiff types.
3939
*
4040
* An instance, f, of `ReduceFunction` should have the signature:
41-
* T f(int start, int end, Vec&& vmapped_subset, std::ostream* msgs,
41+
* T f(Vec&& vmapped_subset, int start, int end, std::ostream* msgs,
4242
* Args&&... args)
4343
*
4444
* `ReduceFunction` must be default constructible without any arguments
@@ -73,7 +73,7 @@ struct reduce_sum_impl {
7373
}
7474

7575
if (auto_partitioning) {
76-
return ReduceFunction()(0, vmapped.size() - 1, std::forward<Vec>(vmapped),
76+
return ReduceFunction()(std::forward<Vec>(vmapped), 0, vmapped.size() - 1,
7777
msgs, std::forward<Args>(args)...);
7878
} else {
7979
return_type_t<Vec, Args...> sum = 0.0;
@@ -88,7 +88,7 @@ struct reduce_sum_impl {
8888
sub_slice.emplace_back(vmapped[i]);
8989
}
9090

91-
sum += ReduceFunction()(start, end, std::forward<Vec>(sub_slice), msgs,
91+
sum += ReduceFunction()(std::forward<Vec>(sub_slice), start, end, msgs,
9292
std::forward<Args>(args)...);
9393
}
9494
return sum;

stan/math/prim/functor/reduce_sum.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ struct reduce_sum_impl<ReduceFunction, require_arithmetic_t<ReturnType>,
8585

8686
sum_ += apply(
8787
[&](auto&&... args) {
88-
return ReduceFunction()(r.begin(), r.end() - 1, sub_slice, msgs_,
88+
return ReduceFunction()(sub_slice, r.begin(), r.end() - 1, msgs_,
8989
args...);
9090
},
9191
args_tuple_);
@@ -107,7 +107,7 @@ struct reduce_sum_impl<ReduceFunction, require_arithmetic_t<ReturnType>,
107107
* arithmetic types.
108108
*
109109
* ReduceFunction must define an operator() with the same signature as:
110-
* double f(int start, int end, Vec&& vmapped_subset, std::ostream* msgs,
110+
* double f(Vec&& vmapped_subset, int start, int end, std::ostream* msgs,
111111
* Args&&... args)
112112
*
113113
* `ReduceFunction` must be default constructible without any arguments
@@ -174,7 +174,7 @@ struct reduce_sum_impl<ReduceFunction, require_arithmetic_t<ReturnType>,
174174
* This defers to reduce_sum_impl for the appropriate implementation
175175
*
176176
* ReduceFunction must define an operator() with the same signature as:
177-
* T f(int start, int end, Vec&& vmapped_subset, std::ostream* msgs, Args&&...
177+
* T f(Vec&& vmapped_subset, int start, int end, std::ostream* msgs, Args&&...
178178
* args)
179179
*
180180
* `ReduceFunction` must be default constructible without any arguments
@@ -209,7 +209,7 @@ inline auto reduce_sum(Vec&& vmapped, int grainsize, std::ostream* msgs,
209209
return return_type(0.0);
210210
}
211211

212-
return ReduceFunction()(0, vmapped.size() - 1, std::forward<Vec>(vmapped),
212+
return ReduceFunction()(std::forward<Vec>(vmapped), 0, vmapped.size() - 1,
213213
msgs, std::forward<Args>(args)...);
214214
#endif
215215
}

stan/math/prim/functor/reduce_sum_static.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace math {
2121
* This defers to reduce_sum_impl for the appropriate implementation
2222
*
2323
* ReduceFunction must define an operator() with the same signature as:
24-
* T f(int start, int end, Vec&& vmapped_subset, std::ostream* msgs, Args&&...
24+
* T f(Vec&& vmapped_subset, int start, int end, std::ostream* msgs, Args&&...
2525
* args)
2626
*
2727
* `ReduceFunction` must be default constructible without any arguments
@@ -58,7 +58,7 @@ auto reduce_sum_static(Vec&& vmapped, int grainsize, std::ostream* msgs,
5858
return return_type(0);
5959
}
6060

61-
return ReduceFunction()(0, vmapped.size() - 1, std::forward<Vec>(vmapped),
61+
return ReduceFunction()(std::forward<Vec>(vmapped), 0, vmapped.size() - 1,
6262
msgs, std::forward<Args>(args)...);
6363
#endif
6464
}

stan/math/rev/functor/reduce_sum.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ struct reduce_sum_impl<ReduceFunction, require_var_t<ReturnType>, ReturnType,
116116
// Perform calculation
117117
var sub_sum_v = apply(
118118
[&](auto&&... args) {
119-
return ReduceFunction()(r.begin(), r.end() - 1, local_sub_slice,
119+
return ReduceFunction()(local_sub_slice, r.begin(), r.end() - 1,
120120
msgs_, args...);
121121
},
122122
args_tuple_local_copy);
@@ -164,7 +164,7 @@ struct reduce_sum_impl<ReduceFunction, require_var_t<ReturnType>, ReturnType,
164164
* mode autodiff.
165165
*
166166
* ReduceFunction must define an operator() with the same signature as:
167-
* var f(int start, int end, Vec&& vmapped_subset, std::ostream* msgs,
167+
* var f(Vec&& vmapped_subset, int start, int end, std::ostream* msgs,
168168
* Args&&... args)
169169
*
170170
* `ReduceFunction` must be default constructible without any arguments

test/unit/math/prim/functor/reduce_sum_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,8 @@ TEST(StanMathPrim_reduce_sum, sum) {
237237
std::vector<int> threading_test_global;
238238
struct threading_test_lpdf {
239239
template <typename T1>
240-
inline auto operator()(std::size_t start, std::size_t end,
241-
const std::vector<T1>&, std::ostream* msgs) const {
240+
inline auto operator()(const std::vector<T1>&, std::size_t start,
241+
std::size_t end, std::ostream* msgs) const {
242242
threading_test_global[start] = tbb::this_task_arena::current_thread_index();
243243

244244
return stan::return_type_t<T1>(0);

test/unit/math/prim/functor/reduce_sum_util.hpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ struct count_lpdf {
2121
count_lpdf() {}
2222

2323
// does the reduction in the sub-slice start to end
24-
inline T operator()(std::size_t start, std::size_t end,
25-
const std::vector<int>& sub_slice, std::ostream* msgs,
24+
inline T operator()(const std::vector<int>& sub_slice, std::size_t start,
25+
std::size_t end, std::ostream* msgs,
2626
const std::vector<T>& lambda,
2727
const std::vector<int>& idata) const {
2828
return stan::math::poisson_lpmf(sub_slice, lambda[0]);
@@ -34,8 +34,8 @@ struct nesting_count_lpdf {
3434
nesting_count_lpdf() {}
3535

3636
// does the reduction in the sub-slice start to end
37-
inline T operator()(std::size_t start, std::size_t end,
38-
const std::vector<int>& sub_slice, std::ostream* msgs,
37+
inline T operator()(const std::vector<int>& sub_slice, std::size_t start,
38+
std::size_t end, std::ostream* msgs,
3939
const std::vector<T>& lambda,
4040
const std::vector<int>& idata) const {
4141
return stan::math::reduce_sum<count_lpdf<T>>(sub_slice, 5, msgs, lambda,
@@ -64,7 +64,7 @@ struct sum_lpdf {
6464
}
6565

6666
template <typename T, typename... Args>
67-
inline auto operator()(std::size_t start, std::size_t end, T&& sub_slice,
67+
inline auto operator()(T&& sub_slice, std::size_t start, std::size_t end,
6868
std::ostream* msgs, Args&&... args) const {
6969
using return_type = stan::return_type_t<T, Args...>;
7070

@@ -77,7 +77,7 @@ struct sum_lpdf {
7777

7878
struct start_end_lpdf {
7979
template <typename T1, typename T2>
80-
inline auto operator()(std::size_t start, std::size_t end, T1&&,
80+
inline auto operator()(T1&&, std::size_t start, std::size_t end,
8181
std::ostream* msgs, T2&& data) const {
8282
stan::return_type_t<T1, T2> sum = 0;
8383
EXPECT_GE(start, 0);
@@ -97,8 +97,8 @@ struct slice_group_count_lpdf {
9797
slice_group_count_lpdf() {}
9898

9999
// does the reduction in the sub-slice start to end
100-
inline T operator()(std::size_t start, std::size_t end,
101-
const std::vector<T>& lambda_slice, std::ostream* msgs,
100+
inline T operator()(const std::vector<T>& lambda_slice, std::size_t start,
101+
std::size_t end, std::ostream* msgs,
102102
const std::vector<int>& y,
103103
const std::vector<int>& gsidx) const {
104104
const std::size_t num_groups = end - start + 1;
@@ -121,7 +121,7 @@ struct grouped_count_lpdf {
121121

122122
// does the reduction in the sub-slice start to end
123123
template <typename VecInt1, typename VecT, typename VecInt2>
124-
inline T operator()(std::size_t start, std::size_t end, VecInt1&& sub_slice,
124+
inline T operator()(VecInt1&& sub_slice, std::size_t start, std::size_t end,
125125
std::ostream* msgs, VecT&& lambda, VecInt2&& gidx) const {
126126
const std::size_t num_terms = end - start + 1;
127127
std::decay_t<VecT> lambda_slice(num_terms);
@@ -168,8 +168,8 @@ auto reduce_sum_sum_lpdf = [](auto&& data, auto&&... args) {
168168
template <int grainsize>
169169
struct static_check_lpdf {
170170
template <typename T>
171-
inline auto operator()(std::size_t start, std::size_t end,
172-
const std::vector<int>&, std::ostream* msgs,
171+
inline auto operator()(const std::vector<int>&, std::size_t start,
172+
std::size_t end, std::ostream* msgs,
173173
const std::vector<T>& data) const {
174174
T sum = 0;
175175
EXPECT_LE(end - start + 1, grainsize);

test/unit/math/rev/functor/reduce_sum_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -394,8 +394,8 @@ TEST(StanMathRev_reduce_sum, slice_group_gradient) {
394394
std::vector<int> threading_test_global;
395395
struct threading_test_lpdf {
396396
template <typename T1>
397-
inline auto operator()(std::size_t start, std::size_t end,
398-
const std::vector<T1>&, std::ostream* msgs) const {
397+
inline auto operator()(const std::vector<T1>&, std::size_t start,
398+
std::size_t end, std::ostream* msgs) const {
399399
threading_test_global[start] = tbb::this_task_arena::current_thread_index();
400400

401401
return stan::return_type_t<T1>(0);

0 commit comments

Comments
 (0)