Skip to content

Commit a5aee5b

Browse files
committed
Boolean reductions transitioned from nd_range<2> to nd_range<1>
1 parent 5c4f980 commit a5aee5b

1 file changed

Lines changed: 53 additions & 41 deletions

File tree

dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp

Lines changed: 53 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,12 @@ template <typename T> struct boolean_predicate
5555
}
5656
};
5757

58-
template <typename inpT,
59-
typename outT,
60-
typename PredicateT,
61-
std::uint8_t wg_dim = 2>
58+
template <typename inpT, typename outT, typename PredicateT>
6259
struct all_reduce_wg_contig
6360
{
64-
void operator()(sycl::nd_item<wg_dim> &ndit,
61+
void operator()(sycl::nd_item<1> &ndit,
6562
outT *out,
66-
size_t &out_idx,
63+
const size_t &out_idx,
6764
const inpT *start,
6865
const inpT *end) const
6966
{
@@ -82,15 +79,12 @@ struct all_reduce_wg_contig
8279
}
8380
};
8481

85-
template <typename inpT,
86-
typename outT,
87-
typename PredicateT,
88-
std::uint8_t wg_dim = 2>
82+
template <typename inpT, typename outT, typename PredicateT>
8983
struct any_reduce_wg_contig
9084
{
91-
void operator()(sycl::nd_item<wg_dim> &ndit,
85+
void operator()(sycl::nd_item<1> &ndit,
9286
outT *out,
93-
size_t &out_idx,
87+
const size_t &out_idx,
9488
const inpT *start,
9589
const inpT *end) const
9690
{
@@ -109,9 +103,9 @@ struct any_reduce_wg_contig
109103
}
110104
};
111105

112-
template <typename T, std::uint8_t wg_dim = 2> struct all_reduce_wg_strided
106+
template <typename T> struct all_reduce_wg_strided
113107
{
114-
void operator()(sycl::nd_item<wg_dim> &ndit,
108+
void operator()(sycl::nd_item<1> &ndit,
115109
T *out,
116110
const size_t &out_idx,
117111
const T &local_val) const
@@ -129,9 +123,9 @@ template <typename T, std::uint8_t wg_dim = 2> struct all_reduce_wg_strided
129123
}
130124
};
131125

132-
template <typename T, std::uint8_t wg_dim = 2> struct any_reduce_wg_strided
126+
template <typename T> struct any_reduce_wg_strided
133127
{
134-
void operator()(sycl::nd_item<wg_dim> &ndit,
128+
void operator()(sycl::nd_item<1> &ndit,
135129
T *out,
136130
const size_t &out_idx,
137131
const T &local_val) const
@@ -215,26 +209,28 @@ struct ContigBooleanReduction
215209
outT *out_ = nullptr;
216210
GroupOp group_op_;
217211
size_t reduction_max_gid_ = 0;
212+
size_t iter_gws_ = 1;
218213
size_t reductions_per_wi = 16;
219214

220215
public:
221216
ContigBooleanReduction(const argT *inp,
222217
outT *res,
223218
GroupOp group_op,
224219
size_t reduction_size,
220+
size_t iteration_size,
225221
size_t reduction_size_per_wi)
226222
: inp_(inp), out_(res), group_op_(group_op),
227-
reduction_max_gid_(reduction_size),
223+
reduction_max_gid_(reduction_size), iter_gws_(iteration_size),
228224
reductions_per_wi(reduction_size_per_wi)
229225
{
230226
}
231227

232-
void operator()(sycl::nd_item<2> it) const
228+
void operator()(sycl::nd_item<1> it) const
233229
{
234-
235-
size_t reduction_id = it.get_group(0);
236-
size_t reduction_batch_id = it.get_group(1);
237-
size_t wg_size = it.get_local_range(1);
230+
const size_t red_gws_ = it.get_global_range(0) / iter_gws_;
231+
const size_t reduction_id = it.get_global_id(0) / red_gws_;
232+
const size_t reduction_batch_id = get_reduction_batch_id(it);
233+
size_t wg_size = it.get_local_range(0);
238234

239235
size_t base = reduction_id * reduction_max_gid_;
240236
size_t start = base + reduction_batch_id * wg_size * reductions_per_wi;
@@ -244,6 +240,14 @@ struct ContigBooleanReduction
244240
// in group_op_
245241
group_op_(it, out_, reduction_id, inp_ + start, inp_ + end);
246242
}
243+
244+
private:
245+
// Returns the index of the calling work-group within its own reduction,
// computed from the flattened one-dimensional launch: the number of
// work-groups along dimension 0 is divided by iter_gws_ to obtain the number
// of work-groups per reduction, and the work-group index is reduced modulo
// that count.
// NOTE(review): assumes the work-group count is an exact multiple of
// iter_gws_, which holds for the visible launch (global size
// iter_nelems * reduction_groups * wg, local size wg) -- confirm for any
// other caller.
size_t get_reduction_batch_id(sycl::nd_item<1> const &it) const
246+
{
247+
// work-groups assigned to each individual reduction
const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_;
248+
// offset of this work-group among the work-groups of its reduction
const size_t reduction_batch_id = it.get_group(0) % n_reduction_groups;
249+
return reduction_batch_id;
250+
}
247251
};
248252

249253
typedef sycl::event (*boolean_reduction_contig_impl_fn_ptr)(
@@ -332,7 +336,7 @@ boolean_reduction_contig_impl(sycl::queue exec_q,
332336
red_ev = exec_q.submit([&](sycl::handler &cgh) {
333337
cgh.depends_on(init_ev);
334338

335-
constexpr std::uint8_t group_dim = 2;
339+
constexpr std::uint8_t dim = 1;
336340

337341
constexpr size_t preferred_reductions_per_wi = 4;
338342
size_t reductions_per_wi =
@@ -344,15 +348,14 @@ boolean_reduction_contig_impl(sycl::queue exec_q,
344348
(reduction_nelems + reductions_per_wi * wg - 1) /
345349
(reductions_per_wi * wg);
346350

347-
auto gws =
348-
sycl::range<group_dim>{iter_nelems, reduction_groups * wg};
349-
auto lws = sycl::range<group_dim>{1, wg};
351+
auto gws = sycl::range<dim>{iter_nelems * reduction_groups * wg};
352+
auto lws = sycl::range<dim>{wg};
350353

351354
cgh.parallel_for<
352355
class boolean_reduction_contig_krn<argTy, resTy, GroupOpT>>(
353-
sycl::nd_range<group_dim>(gws, lws),
356+
sycl::nd_range<dim>(gws, lws),
354357
ContigBooleanReduction<argTy, resTy, GroupOpT>(
355-
arg_tp, res_tp, GroupOpT(), reduction_nelems,
358+
arg_tp, res_tp, GroupOpT(), reduction_nelems, iter_nelems,
356359
reductions_per_wi));
357360
});
358361
}
@@ -404,6 +407,7 @@ struct StridedBooleanReduction
404407
InputOutputIterIndexerT inp_out_iter_indexer_;
405408
InputRedIndexerT inp_reduced_dims_indexer_;
406409
size_t reduction_max_gid_ = 0;
410+
size_t iter_gws_ = 1;
407411
size_t reductions_per_wi = 16;
408412

409413
public:
@@ -415,23 +419,24 @@ struct StridedBooleanReduction
415419
InputOutputIterIndexerT arg_res_iter_indexer,
416420
InputRedIndexerT arg_reduced_dims_indexer,
417421
size_t reduction_size,
422+
size_t iteration_size,
418423
size_t reduction_size_per_wi)
419424
: inp_(inp), out_(res), reduction_op_(reduction_op),
420425
group_op_(group_op), identity_(identity_val),
421426
inp_out_iter_indexer_(arg_res_iter_indexer),
422427
inp_reduced_dims_indexer_(arg_reduced_dims_indexer),
423-
reduction_max_gid_(reduction_size),
428+
reduction_max_gid_(reduction_size), iter_gws_(iteration_size),
424429
reductions_per_wi(reduction_size_per_wi)
425430
{
426431
}
427432

428-
void operator()(sycl::nd_item<2> it) const
433+
void operator()(sycl::nd_item<1> it) const
429434
{
430-
431-
size_t reduction_id = it.get_group(0);
432-
size_t reduction_batch_id = it.get_group(1);
433-
size_t reduction_lid = it.get_local_id(1);
434-
size_t wg_size = it.get_local_range(1);
435+
const size_t red_gws_ = it.get_global_range(0) / iter_gws_;
436+
const size_t reduction_id = it.get_global_id(0) / red_gws_;
437+
const size_t reduction_batch_id = get_reduction_batch_id(it);
438+
const size_t reduction_lid = it.get_local_id(0);
439+
const size_t wg_size = it.get_local_range(0);
435440

436441
auto inp_out_iter_offsets_ = inp_out_iter_indexer_(reduction_id);
437442
const py::ssize_t &inp_iter_offset =
@@ -462,6 +467,14 @@ struct StridedBooleanReduction
462467
// in group_op_
463468
group_op_(it, out_, out_iter_offset, local_red_val);
464469
}
470+
471+
private:
472+
// Returns the index of the calling work-group within its own reduction,
// computed from the flattened one-dimensional launch: the number of
// work-groups along dimension 0 is divided by iter_gws_ to obtain the number
// of work-groups per reduction, and the work-group index is reduced modulo
// that count.
// NOTE(review): assumes the work-group count is an exact multiple of
// iter_gws_, which holds for the visible launch (global size
// iter_nelems * reduction_groups * wg, local size wg) -- confirm for any
// other caller.
size_t get_reduction_batch_id(sycl::nd_item<1> const &it) const
473+
{
474+
// work-groups assigned to each individual reduction
const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_;
475+
// offset of this work-group among the work-groups of its reduction
const size_t reduction_batch_id = it.get_group(0) % n_reduction_groups;
476+
return reduction_batch_id;
477+
}
465478
};
466479

467480
template <typename T1,
@@ -564,7 +577,7 @@ boolean_reduction_strided_impl(sycl::queue exec_q,
564577
red_ev = exec_q.submit([&](sycl::handler &cgh) {
565578
cgh.depends_on(res_init_ev);
566579

567-
constexpr std::uint8_t group_dim = 2;
580+
constexpr std::uint8_t dim = 1;
568581

569582
using InputOutputIterIndexerT =
570583
dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
@@ -587,20 +600,19 @@ boolean_reduction_strided_impl(sycl::queue exec_q,
587600
(reduction_nelems + reductions_per_wi * wg - 1) /
588601
(reductions_per_wi * wg);
589602

590-
auto gws =
591-
sycl::range<group_dim>{iter_nelems, reduction_groups * wg};
592-
auto lws = sycl::range<group_dim>{1, wg};
603+
auto gws = sycl::range<dim>{iter_nelems * reduction_groups * wg};
604+
auto lws = sycl::range<dim>{wg};
593605

594606
cgh.parallel_for<class boolean_reduction_strided_krn<
595607
argTy, resTy, RedOpT, GroupOpT, InputOutputIterIndexerT,
596608
ReductionIndexerT>>(
597-
sycl::nd_range<group_dim>(gws, lws),
609+
sycl::nd_range<dim>(gws, lws),
598610
StridedBooleanReduction<argTy, resTy, RedOpT, GroupOpT,
599611
InputOutputIterIndexerT,
600612
ReductionIndexerT>(
601613
arg_tp, res_tp, RedOpT(), GroupOpT(), identity_val,
602614
in_out_iter_indexer, reduction_indexer, reduction_nelems,
603-
reductions_per_wi));
615+
iter_nelems, reductions_per_wi));
604616
});
605617
}
606618
return red_ev;

0 commit comments

Comments
 (0)