@@ -55,15 +55,12 @@ template <typename T> struct boolean_predicate
5555 }
5656};
5757
58- template <typename inpT,
59- typename outT,
60- typename PredicateT,
61- std::uint8_t wg_dim = 2 >
58+ template <typename inpT, typename outT, typename PredicateT>
6259struct all_reduce_wg_contig
6360{
64- void operator ()(sycl::nd_item<wg_dim > &ndit,
61+ void operator ()(sycl::nd_item<1 > &ndit,
6562 outT *out,
66- size_t &out_idx,
63+ const size_t &out_idx,
6764 const inpT *start,
6865 const inpT *end) const
6966 {
@@ -82,15 +79,12 @@ struct all_reduce_wg_contig
8279 }
8380};
8481
85- template <typename inpT,
86- typename outT,
87- typename PredicateT,
88- std::uint8_t wg_dim = 2 >
82+ template <typename inpT, typename outT, typename PredicateT>
8983struct any_reduce_wg_contig
9084{
91- void operator ()(sycl::nd_item<wg_dim > &ndit,
85+ void operator ()(sycl::nd_item<1 > &ndit,
9286 outT *out,
93- size_t &out_idx,
87+ const size_t &out_idx,
9488 const inpT *start,
9589 const inpT *end) const
9690 {
@@ -109,9 +103,9 @@ struct any_reduce_wg_contig
109103 }
110104};
111105
112- template <typename T, std:: uint8_t wg_dim = 2 > struct all_reduce_wg_strided
106+ template <typename T> struct all_reduce_wg_strided
113107{
114- void operator ()(sycl::nd_item<wg_dim > &ndit,
108+ void operator ()(sycl::nd_item<1 > &ndit,
115109 T *out,
116110 const size_t &out_idx,
117111 const T &local_val) const
@@ -129,9 +123,9 @@ template <typename T, std::uint8_t wg_dim = 2> struct all_reduce_wg_strided
129123 }
130124};
131125
132- template <typename T, std:: uint8_t wg_dim = 2 > struct any_reduce_wg_strided
126+ template <typename T> struct any_reduce_wg_strided
133127{
134- void operator ()(sycl::nd_item<wg_dim > &ndit,
128+ void operator ()(sycl::nd_item<1 > &ndit,
135129 T *out,
136130 const size_t &out_idx,
137131 const T &local_val) const
@@ -215,26 +209,28 @@ struct ContigBooleanReduction
215209 outT *out_ = nullptr ;
216210 GroupOp group_op_;
217211 size_t reduction_max_gid_ = 0 ;
212+ size_t iter_gws_ = 1 ;
218213 size_t reductions_per_wi = 16 ;
219214
220215public:
// Constructor: stores the input/output pointers and the group operation, and
// records the sizes the kernel body uses for index arithmetic.
//   inp, res               - stored in inp_ and out_
//   group_op               - stored in group_op_
//   reduction_size         - stored in reduction_max_gid_
//   iteration_size         - stored in iter_gws_ (parameter added by this
//                            change; the host code visible in this diff
//                            passes iter_nelems here)
//   reduction_size_per_wi  - stored in reductions_per_wi
221216 ContigBooleanReduction (const argT *inp,
222217 outT *res,
223218 GroupOp group_op,
224219 size_t reduction_size,
220+ size_t iteration_size,
225221 size_t reduction_size_per_wi)
226222 : inp_(inp), out_(res), group_op_(group_op),
227- reduction_max_gid_ (reduction_size),
223+ reduction_max_gid_ (reduction_size), iter_gws_(iteration_size),
228224 reductions_per_wi(reduction_size_per_wi)
229225 {
230226 }
231227
232- void operator ()(sycl::nd_item<2 > it) const
228+ void operator ()(sycl::nd_item<1 > it) const
233229 {
234-
235- size_t reduction_id = it.get_group (0 );
236- size_t reduction_batch_id = it. get_group ( 1 );
237- size_t wg_size = it.get_local_range (1 );
230+ const size_t red_gws_ = it. get_global_range ( 0 ) / iter_gws_;
231+ const size_t reduction_id = it.get_global_id (0 ) / red_gws_ ;
232+ const size_t reduction_batch_id = get_reduction_batch_id (it );
233+ size_t wg_size = it.get_local_range (0 );
238234
239235 size_t base = reduction_id * reduction_max_gid_;
240236 size_t start = base + reduction_batch_id * wg_size * reductions_per_wi;
@@ -244,6 +240,14 @@ struct ContigBooleanReduction
244240 // in group_op_
245241 group_op_ (it, out_, reduction_id, inp_ + start, inp_ + end);
246242 }
243+
244+ private:
// Returns the index of this work-group within its own reduction, computed
// from the flattened 1-D group index.
// The number of work-groups in dimension 0 is divided by iter_gws_ to get the
// number of work-groups per reduction; the group index modulo that count is
// the batch id.
// NOTE(review): assumes the group range is an exact multiple of iter_gws_
// (the launch shown in this diff uses gws = iter_nelems * reduction_groups * wg
// with lws = wg, which satisfies this) and that iter_gws_ is non-zero.
245+ size_t get_reduction_batch_id (sycl::nd_item<1 > const &it) const
246+ {
// Work-groups assigned to a single reduction.
247+ const size_t n_reduction_groups = it.get_group_range (0 ) / iter_gws_;
// Position of this work-group among the groups of its reduction.
248+ const size_t reduction_batch_id = it.get_group (0 ) % n_reduction_groups;
249+ return reduction_batch_id;
250+ }
247251};
248252
249253typedef sycl::event (*boolean_reduction_contig_impl_fn_ptr)(
@@ -332,7 +336,7 @@ boolean_reduction_contig_impl(sycl::queue exec_q,
332336 red_ev = exec_q.submit ([&](sycl::handler &cgh) {
333337 cgh.depends_on (init_ev);
334338
335- constexpr std::uint8_t group_dim = 2 ;
339+ constexpr std::uint8_t dim = 1 ;
336340
337341 constexpr size_t preferred_reductions_per_wi = 4 ;
338342 size_t reductions_per_wi =
@@ -344,15 +348,14 @@ boolean_reduction_contig_impl(sycl::queue exec_q,
344348 (reduction_nelems + reductions_per_wi * wg - 1 ) /
345349 (reductions_per_wi * wg);
346350
347- auto gws =
348- sycl::range<group_dim>{iter_nelems, reduction_groups * wg};
349- auto lws = sycl::range<group_dim>{1 , wg};
351+ auto gws = sycl::range<dim>{iter_nelems * reduction_groups * wg};
352+ auto lws = sycl::range<dim>{wg};
350353
351354 cgh.parallel_for <
352355 class boolean_reduction_contig_krn <argTy, resTy, GroupOpT>>(
353- sycl::nd_range<group_dim >(gws, lws),
356+ sycl::nd_range<dim >(gws, lws),
354357 ContigBooleanReduction<argTy, resTy, GroupOpT>(
355- arg_tp, res_tp, GroupOpT (), reduction_nelems,
358+ arg_tp, res_tp, GroupOpT (), reduction_nelems, iter_nelems,
356359 reductions_per_wi));
357360 });
358361 }
@@ -404,6 +407,7 @@ struct StridedBooleanReduction
404407 InputOutputIterIndexerT inp_out_iter_indexer_;
405408 InputRedIndexerT inp_reduced_dims_indexer_;
406409 size_t reduction_max_gid_ = 0 ;
410+ size_t iter_gws_ = 1 ;
407411 size_t reductions_per_wi = 16 ;
408412
409413public:
@@ -415,23 +419,24 @@ struct StridedBooleanReduction
415419 InputOutputIterIndexerT arg_res_iter_indexer,
416420 InputRedIndexerT arg_reduced_dims_indexer,
417421 size_t reduction_size,
422+ size_t iteration_size,
418423 size_t reduction_size_per_wi)
419424 : inp_(inp), out_(res), reduction_op_(reduction_op),
420425 group_op_ (group_op), identity_(identity_val),
421426 inp_out_iter_indexer_(arg_res_iter_indexer),
422427 inp_reduced_dims_indexer_(arg_reduced_dims_indexer),
423- reduction_max_gid_(reduction_size),
428+ reduction_max_gid_(reduction_size), iter_gws_(iteration_size),
424429 reductions_per_wi(reduction_size_per_wi)
425430 {
426431 }
427432
428- void operator ()(sycl::nd_item<2 > it) const
433+ void operator ()(sycl::nd_item<1 > it) const
429434 {
430-
431- size_t reduction_id = it.get_group (0 );
432- size_t reduction_batch_id = it. get_group ( 1 );
433- size_t reduction_lid = it.get_local_id (1 );
434- size_t wg_size = it.get_local_range (1 );
435+ const size_t red_gws_ = it. get_global_range ( 0 ) / iter_gws_;
436+ const size_t reduction_id = it.get_global_id (0 ) / red_gws_ ;
437+ const size_t reduction_batch_id = get_reduction_batch_id (it );
438+ const size_t reduction_lid = it.get_local_id (0 );
439+ const size_t wg_size = it.get_local_range (0 );
435440
436441 auto inp_out_iter_offsets_ = inp_out_iter_indexer_ (reduction_id);
437442 const py::ssize_t &inp_iter_offset =
@@ -462,6 +467,14 @@ struct StridedBooleanReduction
462467 // in group_op_
463468 group_op_ (it, out_, out_iter_offset, local_red_val);
464469 }
470+
471+ private:
// Returns the index of this work-group within its own reduction, computed
// from the flattened 1-D group index.
// The number of work-groups in dimension 0 is divided by iter_gws_ to get the
// number of work-groups per reduction; the group index modulo that count is
// the batch id.
// NOTE(review): assumes the group range is an exact multiple of iter_gws_
// (the launch shown in this diff uses gws = iter_nelems * reduction_groups * wg
// with lws = wg, which satisfies this) and that iter_gws_ is non-zero.
472+ size_t get_reduction_batch_id (sycl::nd_item<1 > const &it) const
473+ {
// Work-groups assigned to a single reduction.
474+ const size_t n_reduction_groups = it.get_group_range (0 ) / iter_gws_;
// Position of this work-group among the groups of its reduction.
475+ const size_t reduction_batch_id = it.get_group (0 ) % n_reduction_groups;
476+ return reduction_batch_id;
477+ }
465478};
466479
467480template <typename T1,
@@ -564,7 +577,7 @@ boolean_reduction_strided_impl(sycl::queue exec_q,
564577 red_ev = exec_q.submit ([&](sycl::handler &cgh) {
565578 cgh.depends_on (res_init_ev);
566579
567- constexpr std::uint8_t group_dim = 2 ;
580+ constexpr std::uint8_t dim = 1 ;
568581
569582 using InputOutputIterIndexerT =
570583 dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
@@ -587,20 +600,19 @@ boolean_reduction_strided_impl(sycl::queue exec_q,
587600 (reduction_nelems + reductions_per_wi * wg - 1 ) /
588601 (reductions_per_wi * wg);
589602
590- auto gws =
591- sycl::range<group_dim>{iter_nelems, reduction_groups * wg};
592- auto lws = sycl::range<group_dim>{1 , wg};
603+ auto gws = sycl::range<dim>{iter_nelems * reduction_groups * wg};
604+ auto lws = sycl::range<dim>{wg};
593605
594606 cgh.parallel_for <class boolean_reduction_strided_krn <
595607 argTy, resTy, RedOpT, GroupOpT, InputOutputIterIndexerT,
596608 ReductionIndexerT>>(
597- sycl::nd_range<group_dim >(gws, lws),
609+ sycl::nd_range<dim >(gws, lws),
598610 StridedBooleanReduction<argTy, resTy, RedOpT, GroupOpT,
599611 InputOutputIterIndexerT,
600612 ReductionIndexerT>(
601613 arg_tp, res_tp, RedOpT (), GroupOpT (), identity_val,
602614 in_out_iter_indexer, reduction_indexer, reduction_nelems,
603- reductions_per_wi));
615+ iter_nelems, reductions_per_wi));
604616 });
605617 }
606618 return red_ev;
0 commit comments