@@ -346,50 +346,80 @@ void dpnp_remainder_c(void* result_out,
346346 const size_t input2_shape_ndim,
347347 const size_t * where)
348348{
349- (void )input1_shape;
350- (void )input1_shape_ndim;
351- (void )input2_size;
352- (void )input2_shape;
353- (void )input2_shape_ndim;
354349 (void )where;
355350
356- cl::sycl::event event;
357- _DataType_input1* input1 = reinterpret_cast <_DataType_input1*>(const_cast <void *>(input1_in));
358- _DataType_input2* input2 = reinterpret_cast <_DataType_input2*>(const_cast <void *>(input2_in));
351+ if (!input1_size || !input2_size)
352+ {
353+ return ;
354+ }
355+
356+ _DataType_input1* input1_data = reinterpret_cast <_DataType_input1*>(const_cast <void *>(input1_in));
357+ _DataType_input2* input2_data = reinterpret_cast <_DataType_input2*>(const_cast <void *>(input2_in));
359358 _DataType_output* result = reinterpret_cast <_DataType_output*>(result_out);
360359
361- if constexpr ((std::is_same<_DataType_input1, double >::value || std::is_same<_DataType_input1, float >::value) &&
362- std::is_same<_DataType_input2, _DataType_input1>::value)
360+ std::vector<size_t > result_shape = get_result_shape (input1_shape, input1_shape_ndim,
361+ input2_shape, input2_shape_ndim);
362+
363+ DPNPC_id<_DataType_input1>* input1_it;
364+ const size_t input1_it_size_in_bytes = sizeof (DPNPC_id<_DataType_input1>);
365+ input1_it = reinterpret_cast <DPNPC_id<_DataType_input1>*>(dpnp_memory_alloc_c (input1_it_size_in_bytes));
366+ new (input1_it) DPNPC_id<_DataType_input1>(input1_data, input1_shape, input1_shape_ndim);
367+
368+ input1_it->broadcast_to_shape (result_shape);
369+
370+ DPNPC_id<_DataType_input2>* input2_it;
371+ const size_t input2_it_size_in_bytes = sizeof (DPNPC_id<_DataType_input2>);
372+ input2_it = reinterpret_cast <DPNPC_id<_DataType_input2>*>(dpnp_memory_alloc_c (input2_it_size_in_bytes));
373+ new (input2_it) DPNPC_id<_DataType_input2>(input2_data, input2_shape, input2_shape_ndim);
374+
375+ input2_it->broadcast_to_shape (result_shape);
376+
377+ const size_t result_size = input1_it->get_output_size ();
378+
379+ cl::sycl::range<1 > gws (result_size);
380+ auto kernel_parallel_for_func = [=](cl::sycl::id<1 > global_id) {
381+ const size_t i = global_id[0 ];
382+ const _DataType_output input1_elem = (*input1_it)[i];
383+ const _DataType_output input2_elem = (*input2_it)[i];
384+ double fmod_res = cl::sycl::fmod ((double )input1_elem, (double )input2_elem);
385+ double add = fmod_res + input2_elem;
386+ result[i] = cl::sycl::fmod (add, (double )input2_elem);
387+
388+ };
389+ auto kernel_func = [&](cl::sycl::handler& cgh) {
390+ cgh.parallel_for <class dpnp_remainder_c_kernel <_DataType_output, _DataType_input1, _DataType_input2>>(
391+ gws, kernel_parallel_for_func);
392+ };
393+
394+ cl::sycl::event event;
395+
396+ if (input1_size == input2_size)
363397 {
364- event = oneapi::mkl::vm::fmod (DPNP_QUEUE, input1_size, input1, input2, result);
365- event.wait ();
366- event = oneapi::mkl::vm::add (DPNP_QUEUE, input1_size, result, input2, result);
367- event.wait ();
368- event = oneapi::mkl::vm::fmod (DPNP_QUEUE, input1_size, result, input2, result);
398+ if constexpr ((std::is_same<_DataType_input1, double >::value ||
399+ std::is_same<_DataType_input1, float >::value) &&
400+ std::is_same<_DataType_input2, _DataType_input1>::value)
401+ {
402+ event = oneapi::mkl::vm::fmod (DPNP_QUEUE, input1_size, input1_data, input2_data, result);
403+ event.wait ();
404+ event = oneapi::mkl::vm::add (DPNP_QUEUE, input1_size, result, input2_data, result);
405+ event.wait ();
406+ event = oneapi::mkl::vm::fmod (DPNP_QUEUE, input1_size, result, input2_data, result);
407+ }
408+ else
409+ {
410+ event = DPNP_QUEUE.submit (kernel_func);
411+ }
369412 }
370413 else
371414 {
372- cl::sycl::range<1 > gws (input1_size);
373- auto kernel_parallel_for_func = [=](cl::sycl::id<1 > global_id) {
374- size_t i = global_id[0 ]; /* for (size_t i = 0; i < size; ++i)*/
375- {
376- _DataType_input1 input_elem1 = input1[i];
377- _DataType_input2 input_elem2 = input2[i];
378- double fmod = cl::sycl::fmod ((double )input_elem1, (double )input_elem2);
379- double add = fmod + input_elem2;
380- result[i] = cl::sycl::fmod (add, (double )input_elem2);
381- }
382- };
383-
384- auto kernel_func = [&](cl::sycl::handler& cgh) {
385- cgh.parallel_for <class dpnp_remainder_c_kernel <_DataType_input1, _DataType_input2, _DataType_output>>(
386- gws, kernel_parallel_for_func);
387- };
388-
389415 event = DPNP_QUEUE.submit (kernel_func);
390416 }
391417
392418 event.wait ();
419+
420+ input1_it->~DPNPC_id ();
421+ input2_it->~DPNPC_id ();
422+
393423}
394424
395425template <typename _KernelNameSpecialization1, typename _KernelNameSpecialization2, typename _KernelNameSpecialization3>
0 commit comments