@@ -134,12 +134,12 @@ struct ReductionOverGroupWithAtomicFunctor
134134 InputOutputIterIndexerT arg_res_iter_indexer,
135135 InputRedIndexerT arg_reduced_dims_indexer,
136136 size_t reduction_size,
137- size_t iter_gws ,
137+ size_t iteration_size ,
138138 size_t reduction_size_per_wi)
139139 : inp_(data), out_(res), reduction_op_(reduction_op),
140140 identity_ (identity_val), inp_out_iter_indexer_(arg_res_iter_indexer),
141141 inp_reduced_dims_indexer_(arg_reduced_dims_indexer),
142- reduction_max_gid_(reduction_size), iter_gws_(iter_gws ),
142+ reduction_max_gid_(reduction_size), iter_gws_(iteration_size ),
143143 reductions_per_wi(reduction_size_per_wi)
144144 {
145145 }
@@ -528,6 +528,7 @@ struct ReductionOverGroupNoAtomicFunctor
528528 InputOutputIterIndexerT inp_out_iter_indexer_;
529529 InputRedIndexerT inp_reduced_dims_indexer_;
530530 size_t reduction_max_gid_ = 0 ;
531+ size_t iter_gws_ = 1 ;
531532 size_t reductions_per_wi = 16 ;
532533
533534public:
@@ -539,22 +540,25 @@ struct ReductionOverGroupNoAtomicFunctor
539540 InputOutputIterIndexerT arg_res_iter_indexer,
540541 InputRedIndexerT arg_reduced_dims_indexer,
541542 size_t reduction_size,
543+ size_t iteration_size,
542544 size_t reduction_size_per_wi)
543545 : inp_(data), out_(res), reduction_op_(reduction_op),
544546 identity_ (identity_val), inp_out_iter_indexer_(arg_res_iter_indexer),
545547 inp_reduced_dims_indexer_(arg_reduced_dims_indexer),
546- reduction_max_gid_(reduction_size),
548+ reduction_max_gid_(reduction_size), iter_gws_(iteration_size),
547549 reductions_per_wi(reduction_size_per_wi)
548550 {
549551 }
550552
551- void operator ()(sycl::nd_item<2 > it) const
553+ void operator ()(sycl::nd_item<1 > it) const
552554 {
553555
554- size_t iter_gid = it.get_global_id (0 );
555- size_t reduction_batch_id = it.get_group (1 );
556- size_t reduction_lid = it.get_local_id (1 );
557- size_t wg = it.get_local_range (1 ); // 0 <= reduction_lid < wg
556+ const size_t red_gws_ = it.get_global_range (0 ) / iter_gws_;
557+ const size_t iter_gid = it.get_global_id (0 ) / red_gws_;
558+ const size_t n_reduction_groups = it.get_group_range (0 ) / iter_gws_;
559+ const size_t reduction_batch_id = it.get_group (0 ) % n_reduction_groups;
560+ const size_t reduction_lid = it.get_local_id (0 );
561+ const size_t wg = it.get_local_range (0 ); // 0 <= reduction_lid < wg
558562
559563 // work-items sums over input with indices
560564 // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg
@@ -590,7 +594,7 @@ struct ReductionOverGroupNoAtomicFunctor
590594
591595 if (work_group.leader ()) {
592596 // each group writes to a different memory location
593- out_[out_iter_offset * it. get_group_range ( 1 ) + reduction_batch_id] =
597+ out_[out_iter_offset * n_reduction_groups + reduction_batch_id] =
594598 red_val_over_wg;
595599 }
596600 }
@@ -657,20 +661,20 @@ sycl::event sum_reduction_over_group_temps_strided_impl(
657661 assert (reduction_groups == 1 );
658662
659663 auto globalRange =
660- sycl::range<2 >{iter_nelems, reduction_groups * wg};
661- auto localRange = sycl::range<2 >{ 1 , wg};
664+ sycl::range<1 >{iter_nelems * reduction_groups * wg};
665+ auto localRange = sycl::range<1 >{ wg};
662666
663667 using KernelName = class sum_reduction_over_group_temps_krn <
664668 argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
665669 ReductionIndexerT>;
666670 cgh.parallel_for <KernelName>(
667- sycl::nd_range<2 >(globalRange, localRange),
671+ sycl::nd_range<1 >(globalRange, localRange),
668672 ReductionOverGroupNoAtomicFunctor<argTy, resTy, ReductionOpT,
669673 InputOutputIterIndexerT,
670674 ReductionIndexerT>(
671675 arg_tp, res_tp, ReductionOpT (), identity_val,
672676 in_out_iter_indexer, reduction_indexer, reduction_nelems,
673- reductions_per_wi));
677+ iter_nelems, reductions_per_wi));
674678 });
675679
676680 return comp_ev;
@@ -723,20 +727,20 @@ sycl::event sum_reduction_over_group_temps_strided_impl(
723727 reduction_shape_stride};
724728
725729 auto globalRange =
726- sycl::range<2 >{iter_nelems, reduction_groups * wg};
727- auto localRange = sycl::range<2 >{ 1 , wg};
730+ sycl::range<1 >{iter_nelems * reduction_groups * wg};
731+ auto localRange = sycl::range<1 >{ wg};
728732
729733 using KernelName = class sum_reduction_over_group_temps_krn <
730734 argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
731735 ReductionIndexerT>;
732736 cgh.parallel_for <KernelName>(
733- sycl::nd_range<2 >(globalRange, localRange),
737+ sycl::nd_range<1 >(globalRange, localRange),
734738 ReductionOverGroupNoAtomicFunctor<argTy, resTy, ReductionOpT,
735739 InputOutputIterIndexerT,
736740 ReductionIndexerT>(
737741 arg_tp, partially_reduced_tmp, ReductionOpT (), identity_val,
738742 in_out_iter_indexer, reduction_indexer, reduction_nelems,
739- preferrered_reductions_per_wi));
743+ iter_nelems, preferrered_reductions_per_wi));
740744 });
741745
742746 size_t remaining_reduction_nelems = reduction_groups;
@@ -778,20 +782,20 @@ sycl::event sum_reduction_over_group_temps_strided_impl(
778782 ReductionIndexerT reduction_indexer{};
779783
780784 auto globalRange =
781- sycl::range<2 >{iter_nelems, reduction_groups_ * wg};
782- auto localRange = sycl::range<2 >{ 1 , wg};
785+ sycl::range<1 >{iter_nelems * reduction_groups_ * wg};
786+ auto localRange = sycl::range<1 >{ wg};
783787
784788 using KernelName = class sum_reduction_over_group_temps_krn <
785789 resTy, resTy, ReductionOpT, InputOutputIterIndexerT,
786790 ReductionIndexerT>;
787791 cgh.parallel_for <KernelName>(
788- sycl::nd_range<2 >(globalRange, localRange),
792+ sycl::nd_range<1 >(globalRange, localRange),
789793 ReductionOverGroupNoAtomicFunctor<
790794 resTy, resTy, ReductionOpT, InputOutputIterIndexerT,
791795 ReductionIndexerT>(
792796 temp_arg, temp2_arg, ReductionOpT (), identity_val,
793797 in_out_iter_indexer, reduction_indexer,
794- remaining_reduction_nelems,
798+ remaining_reduction_nelems, iter_nelems,
795799 preferrered_reductions_per_wi));
796800 });
797801
@@ -834,20 +838,21 @@ sycl::event sum_reduction_over_group_temps_strided_impl(
834838 assert (reduction_groups == 1 );
835839
836840 auto globalRange =
837- sycl::range<2 >{iter_nelems, reduction_groups * wg};
838- auto localRange = sycl::range<2 >{ 1 , wg};
841+ sycl::range<1 >{iter_nelems * reduction_groups * wg};
842+ auto localRange = sycl::range<1 >{ wg};
839843
840844 using KernelName = class sum_reduction_over_group_temps_krn <
841845 argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
842846 ReductionIndexerT>;
843847 cgh.parallel_for <KernelName>(
844- sycl::nd_range<2 >(globalRange, localRange),
848+ sycl::nd_range<1 >(globalRange, localRange),
845849 ReductionOverGroupNoAtomicFunctor<resTy, resTy, ReductionOpT,
846850 InputOutputIterIndexerT,
847851 ReductionIndexerT>(
848852 temp_arg, res_tp, ReductionOpT (), identity_val,
849853 in_out_iter_indexer, reduction_indexer,
850- remaining_reduction_nelems, reductions_per_wi));
854+ remaining_reduction_nelems, iter_nelems,
855+ reductions_per_wi));
851856 });
852857
853858 sycl::event cleanup_host_task_event =
0 commit comments