@@ -290,10 +290,6 @@ class DPNPC_id final
290290 get_shape_offsets_inkernel<size_type>(output_shape, output_shape_size, output_shape_strides);
291291
292292 iteration_size = 1 ;
293-
294- // make thread private storage for each shape by multiplying memory
295- sycl_output_xyz =
296- reinterpret_cast <size_type*>(dpnp_memory_alloc_c (output_size * output_shape_size_in_bytes));
297293 }
298294 }
299295
@@ -400,10 +396,6 @@ class DPNPC_id final
400396 {
401397 axes_shape_strides[i] = input_shape_strides[axes[i]];
402398 }
403-
404- // make thread private storage for each shape by multiplying memory
405- sycl_output_xyz =
406- reinterpret_cast <size_type*>(dpnp_memory_alloc_c (output_size * output_shape_size_in_bytes));
407399 }
408400 }
409401
@@ -485,35 +477,30 @@ class DPNPC_id final
485477 {
486478 assert (output_global_id < output_size);
487479
488- // use thread private storage
489- size_type* sycl_output_xyz_thread = sycl_output_xyz + (output_global_id * output_shape_size);
490-
491- get_xyz_by_id_inkernel (output_global_id, output_shape_strides, output_shape_size, sycl_output_xyz_thread);
492-
493480 for (size_t iit = 0 , oit = 0 ; iit < input_shape_size; ++iit)
494481 {
495482 if (std::find (axes.begin (), axes.end (), iit) == axes.end ())
496483 {
497- input_global_id += (sycl_output_xyz_thread[oit] * input_shape_strides[iit]);
484+ const size_type output_xyz_id = get_xyz_id_by_id_inkernel (output_global_id, output_shape_strides,
485+ output_shape_size, oit);
486+ input_global_id += (output_xyz_id * input_shape_strides[iit]);
498487 ++oit;
499488 }
500489 }
501490 }
502491 else if (broadcast_use)
503492 {
504493 assert (output_global_id < output_size);
505-
506- // use thread private storage
507- size_type* sycl_output_xyz_thread = sycl_output_xyz + (output_global_id * output_shape_size);
508-
509- get_xyz_by_id_inkernel (output_global_id, output_shape_strides, output_shape_size, sycl_output_xyz_thread);
494+ assert (input_shape_size <= output_shape_size);
510495
511496 for (int irit = input_shape_size - 1 , orit = output_shape_size - 1 ; irit >= 0 ; --irit, --orit)
512497 {
513498 size_type* broadcast_axes_end = broadcast_axes + broadcast_axes_size;
514499 if (std::find (broadcast_axes, broadcast_axes_end, orit) == broadcast_axes_end)
515500 {
516- input_global_id += (sycl_output_xyz_thread[orit] * input_shape_strides[irit]);
501+ const size_type output_xyz_id = get_xyz_id_by_id_inkernel (output_global_id, output_shape_strides,
502+ output_shape_size, orit);
503+ input_global_id += (output_xyz_id * input_shape_strides[irit]);
517504 }
518505 }
519506 }
@@ -565,10 +552,8 @@ class DPNPC_id final
565552 output_shape_size = size_type{};
566553 dpnp_memory_free_c (output_shape);
567554 dpnp_memory_free_c (output_shape_strides);
568- dpnp_memory_free_c (sycl_output_xyz);
569555 output_shape = nullptr ;
570556 output_shape_strides = nullptr ;
571- sycl_output_xyz = nullptr ;
572557 }
573558
574559 void free_memory ()
@@ -602,9 +587,6 @@ class DPNPC_id final
602587 size_type iteration_shape_size = size_type{};
603588 size_type* iteration_shape_strides = nullptr ;
604589 size_type* axes_shape_strides = nullptr ;
605-
606- // data allocated to use inside SYCL kernels
607- size_type* sycl_output_xyz = nullptr ;
608590};
609591
610592#endif // DPNP_ITERATOR_H
0 commit comments