@@ -122,6 +122,7 @@ struct ReductionOverGroupWithAtomicFunctor
122122 InputOutputIterIndexerT inp_out_iter_indexer_;
123123 InputRedIndexerT inp_reduced_dims_indexer_;
124124 size_t reduction_max_gid_ = 0 ;
125+ size_t iter_gws_ = 1 ;
125126 size_t reductions_per_wi = 16 ;
126127
127128public:
@@ -133,22 +134,23 @@ struct ReductionOverGroupWithAtomicFunctor
133134 InputOutputIterIndexerT arg_res_iter_indexer,
134135 InputRedIndexerT arg_reduced_dims_indexer,
135136 size_t reduction_size,
137+ size_t iter_gws,
136138 size_t reduction_size_per_wi)
137139 : inp_(data), out_(res), reduction_op_(reduction_op),
138140 identity_ (identity_val), inp_out_iter_indexer_(arg_res_iter_indexer),
139141 inp_reduced_dims_indexer_(arg_reduced_dims_indexer),
140- reduction_max_gid_(reduction_size),
142+ reduction_max_gid_(reduction_size), iter_gws_(iter_gws),
141143 reductions_per_wi(reduction_size_per_wi)
142144 {
143145 }
144146
145- void operator ()(sycl::nd_item<2 > it) const
147+ void operator ()(sycl::nd_item<1 > it) const
146148 {
147-
148- size_t iter_gid = it.get_global_id (0 );
149- size_t reduction_batch_id = it. get_group ( 1 );
150- size_t reduction_lid = it.get_local_id (1 );
151- size_t wg = it.get_local_range (1 ); // 0 <= reduction_lid < wg
149+ const size_t red_gws_ = it. get_global_range ( 0 ) / iter_gws_;
150+ const size_t iter_gid = it.get_global_id (0 ) / red_gws_ ;
151+ const size_t reduction_batch_id = get_reduction_batch_id (it );
152+ const size_t reduction_lid = it.get_local_id (0 );
153+ const size_t wg = it.get_local_range (0 ); // 0 <= reduction_lid < wg
152154
153155 // work-items sums over input with indices
154156 // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg
@@ -202,6 +204,14 @@ struct ReductionOverGroupWithAtomicFunctor
202204 }
203205 }
204206 }
207+
208+ private:
// Returns the index of this work-group's reduction batch within its
// iteration element, for the flattened 1-D launch.
//
// The total number of work-groups in dimension 0 is divided by iter_gws_
// to obtain the number of reduction groups per iteration element; the
// work-group's linear id taken modulo that count is the batch id.
// At the launch sites visible in this file the global range is
// iter_nelems * reduction_groups * wg with local range wg, and iter_nelems
// is passed as iter_gws, so the quotient equals reduction_groups there.
// NOTE(review): no guard is visible here for iter_gws_ == 0 or for a group
// range smaller than iter_gws_ (either would make a divisor zero) --
// presumably callers guarantee iter_gws_ >= 1 and an exact multiple; confirm.
209+ size_t get_reduction_batch_id (sycl::nd_item<1 > const &it) const
210+ {
211+ const size_t n_reduction_groups = it.get_group_range (0 ) / iter_gws_;
212+ const size_t reduction_batch_id = it.get_group (0 ) % n_reduction_groups;
213+ return reduction_batch_id;
214+ }
205215};
206216
207217typedef sycl::event (*sum_reduction_strided_impl_fn_ptr)(
@@ -343,21 +353,21 @@ sycl::event sum_reduction_over_group_with_atomics_strided_impl(
343353 }
344354
345355 auto globalRange =
346- sycl::range<2 >{iter_nelems, reduction_groups * wg};
347- auto localRange = sycl::range<2 >{ 1 , wg};
356+ sycl::range<1 >{iter_nelems * reduction_groups * wg};
357+ auto localRange = sycl::range<1 >{ wg};
348358
349359 using KernelName = class sum_reduction_over_group_with_atomics_krn <
350360 argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
351361 ReductionIndexerT>;
352362
353363 cgh.parallel_for <KernelName>(
354- sycl::nd_range<2 >(globalRange, localRange),
364+ sycl::nd_range<1 >(globalRange, localRange),
355365 ReductionOverGroupWithAtomicFunctor<argTy, resTy, ReductionOpT,
356366 InputOutputIterIndexerT,
357367 ReductionIndexerT>(
358368 arg_tp, res_tp, ReductionOpT (), identity_val,
359369 in_out_iter_indexer, reduction_indexer, reduction_nelems,
360- reductions_per_wi));
370+ iter_nelems, reductions_per_wi));
361371 });
362372
363373 return comp_ev;
@@ -480,21 +490,21 @@ sycl::event sum_reduction_over_group_with_atomics_contig_impl(
480490 }
481491
482492 auto globalRange =
483- sycl::range<2 >{iter_nelems, reduction_groups * wg};
484- auto localRange = sycl::range<2 >{ 1 , wg};
493+ sycl::range<1 >{iter_nelems * reduction_groups * wg};
494+ auto localRange = sycl::range<1 >{ wg};
485495
486496 using KernelName = class sum_reduction_over_group_with_atomics_krn <
487497 argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
488498 ReductionIndexerT>;
489499
490500 cgh.parallel_for <KernelName>(
491- sycl::nd_range<2 >(globalRange, localRange),
501+ sycl::nd_range<1 >(globalRange, localRange),
492502 ReductionOverGroupWithAtomicFunctor<argTy, resTy, ReductionOpT,
493503 InputOutputIterIndexerT,
494504 ReductionIndexerT>(
495505 arg_tp, res_tp, ReductionOpT (), identity_val,
496506 in_out_iter_indexer, reduction_indexer, reduction_nelems,
497- reductions_per_wi));
507+ iter_nelems, reductions_per_wi));
498508 });
499509
500510 return comp_ev;
0 commit comments