@@ -232,6 +232,9 @@ typedef sycl::event (*sum_reduction_strided_impl_fn_ptr)(
232232template <typename T1, typename T2, typename T3, typename T4, typename T5>
233233class sum_reduction_over_group_with_atomics_krn ;
234234
235+ template <typename T1, typename T2>
236+ class sum_reduction_over_group_with_atomics_init_krn ;
237+
235238template <typename T1, typename T2, typename T3, typename T4, typename T5>
236239class sum_reduction_seq_strided_krn ;
237240
@@ -305,13 +308,16 @@ sycl::event sum_reduction_over_group_with_atomics_strided_impl(
305308 iter_shape_and_strides + 2 * iter_nd;
306309 IndexerT res_indexer (iter_nd, iter_res_offset, res_shape,
307310 res_strides);
308-
311+ using InitKernelName =
312+ class sum_reduction_over_group_with_atomics_init_krn <resTy,
313+ argTy>;
309314 cgh.depends_on (depends);
310315
311- cgh.parallel_for (sycl::range<1 >(iter_nelems), [=](sycl::id<1 > id) {
312- auto res_offset = res_indexer (id[0 ]);
313- res_tp[res_offset] = identity_val;
314- });
316+ cgh.parallel_for <InitKernelName>(
317+ sycl::range<1 >(iter_nelems), [=](sycl::id<1 > id) {
318+ auto res_offset = res_indexer (id[0 ]);
319+ res_tp[res_offset] = identity_val;
320+ });
315321 });
316322
317323 sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
0 commit comments