@@ -215,7 +215,7 @@ template <typename _DataType, typename _ResultType>
215215class dpnp_trace_c_kernel ;
216216
217217template <typename _DataType, typename _ResultType>
218- void dpnp_trace_c (const void * array1_in, void * result1, const size_t * shape , const size_t ndim)
218+ void dpnp_trace_c (const void * array1_in, void * result1, const size_t * shape_ , const size_t ndim)
219219{
220220 cl::sycl::event event;
221221
@@ -227,7 +227,7 @@ void dpnp_trace_c(const void* array1_in, void* result1, const size_t* shape, con
227227 const _DataType* array_in = reinterpret_cast <const _DataType*>(array1_in);
228228 _ResultType* result = reinterpret_cast <_ResultType*>(result1);
229229
230- if (shape == nullptr )
230+ if (shape_ == nullptr )
231231 {
232232 return ;
233233 }
@@ -240,32 +240,37 @@ void dpnp_trace_c(const void* array1_in, void* result1, const size_t* shape, con
240240 size_t size = 1 ;
241241 for (size_t i = 0 ; i < ndim - 1 ; ++i)
242242 {
243- size *= shape [i];
243+ size *= shape_ [i];
244244 }
245245
246246 if (size == 0 )
247247 {
248248 return ;
249249 }
250250
251+ size_t * shape = reinterpret_cast <size_t *>(dpnp_memory_alloc_c (ndim * sizeof (size_t )));
252+ auto memcpy_event = DPNP_QUEUE.memcpy (shape, shape_, ndim * sizeof (size_t ));
253+
251254 cl::sycl::range<1 > gws (size);
252255 auto kernel_parallel_for_func = [=](cl::sycl::id<1 > global_id) {
253256 size_t i = global_id[0 ];
254- _DataType elem = 0 ;
257+ result[i] = 0 ;
255258 for (size_t j = 0 ; j < shape[ndim - 1 ]; ++j)
256259 {
257- elem += array_in[i * shape[ndim - 1 ] + j];
260+ result[i] += array_in[i * shape[ndim - 1 ] + j];
258261 }
259- result[i] = elem;
260262 };
261263
262264 auto kernel_func = [&](cl::sycl::handler& cgh) {
265+ cgh.depends_on ({memcpy_event});
263266 cgh.parallel_for <class dpnp_trace_c_kernel <_DataType, _ResultType>>(gws, kernel_parallel_for_func);
264267 };
265268
266269 event = DPNP_QUEUE.submit (kernel_func);
267270
268271 event.wait ();
272+
273+ dpnp_memory_free_c (shape);
269274}
270275
271276template <typename _DataType>
0 commit comments