2929#include < dpnp_iface.hpp>
3030#include " dpnp_fptr.hpp"
3131#include " dpnp_utils.hpp"
32+ #include " dpnpc_memory_adapter.hpp"
3233#include " queue_sycl.hpp"
3334
3435namespace mkl_blas = oneapi::mkl::blas::row_major;
@@ -39,8 +40,10 @@ void dpnp_cholesky_c(void* array1_in, void* result1, const size_t size, const si
3940{
4041 cl::sycl::event event;
4142
42- _DataType* in_array = reinterpret_cast <_DataType*>(array1_in);
43- _DataType* result = reinterpret_cast <_DataType*>(result1);
43+ DPNPC_ptr_adapter<_DataType> input1_ptr (array1_in, size, true );
44+ DPNPC_ptr_adapter<_DataType> result_ptr (result1, size, true , true );
45+ _DataType* in_array = input1_ptr.get_ptr ();
46+ _DataType* result = result_ptr.get_ptr ();
4447
4548 size_t iters = size / (data_size * data_size);
4649
@@ -97,8 +100,11 @@ void dpnp_cholesky_c(void* array1_in, void* result1, const size_t size, const si
97100template <typename _DataType>
98101void dpnp_det_c (void * array1_in, void * result1, size_t * shape, size_t ndim)
99102{
100- _DataType* array_1 = reinterpret_cast <_DataType*>(array1_in);
101- _DataType* result = reinterpret_cast <_DataType*>(result1);
103+ const size_t input_size = std::accumulate (shape, shape + ndim, 1 , std::multiplies<size_t >());
104+ if (!input_size)
105+ {
106+ return ;
107+ }
102108
103109 size_t n = shape[ndim - 1 ];
104110 size_t size_out = 1 ;
@@ -110,6 +116,11 @@ void dpnp_det_c(void* array1_in, void* result1, size_t* shape, size_t ndim)
110116 }
111117 }
112118
119+ DPNPC_ptr_adapter<_DataType> input1_ptr (array1_in, input_size, true );
120+ DPNPC_ptr_adapter<_DataType> result_ptr (result1, size_out, true , true );
121+ _DataType* array_1 = input1_ptr.get_ptr ();
122+ _DataType* result = result_ptr.get_ptr ();
123+
113124 for (size_t i = 0 ; i < size_out; i++)
114125 {
115126 _DataType matrix[n][n];
@@ -194,8 +205,17 @@ template <typename _DataType, typename _ResultType>
194205void dpnp_inv_c (void * array1_in, void * result1, size_t * shape, size_t ndim)
195206{
196207 (void )ndim; // avoid warning unused variable
197- _DataType* array_1 = reinterpret_cast <_DataType*>(array1_in);
198- _ResultType* result = reinterpret_cast <_ResultType*>(result1);
208+
209+ const size_t input_size = std::accumulate (shape, shape + ndim, 1 , std::multiplies<size_t >());
210+ if (!input_size)
211+ {
212+ return ;
213+ }
214+
215+ DPNPC_ptr_adapter<_DataType> input1_ptr (array1_in, input_size, true );
216+ DPNPC_ptr_adapter<_ResultType> result_ptr (result1, input_size, true , true );
217+ _DataType* array_1 = input1_ptr.get_ptr ();
218+ _ResultType* result = result_ptr.get_ptr ();
199219
200220 size_t n = shape[0 ];
201221
@@ -298,16 +318,21 @@ void dpnp_kron_c(void* array1_in,
298318 size_t * res_shape,
299319 size_t ndim)
300320{
301- _DataType1* array1 = reinterpret_cast <_DataType1*>(array1_in);
302- _DataType2* array2 = reinterpret_cast <_DataType2*>(array2_in);
303- _ResultType* result = reinterpret_cast <_ResultType*>(result1);
304-
305- size_t size = 1 ;
306- for (size_t i = 0 ; i < ndim; ++i)
321+ const size_t input1_size = std::accumulate (in1_shape, in1_shape + ndim, 1 , std::multiplies<size_t >());
322+ const size_t input2_size = std::accumulate (in2_shape, in2_shape + ndim, 1 , std::multiplies<size_t >());
323+ const size_t result_size = std::accumulate (res_shape, res_shape + ndim, 1 , std::multiplies<size_t >());
324+ if (!(result_size && input1_size && input2_size))
307325 {
308- size *= res_shape[i] ;
326+ return ;
309327 }
310328
329+ DPNPC_ptr_adapter<_DataType1> input1_ptr (array1_in, input1_size);
330+ DPNPC_ptr_adapter<_DataType2> input2_ptr (array2_in, input2_size);
331+ DPNPC_ptr_adapter<_ResultType> result_ptr (result1, result_size);
332+ _DataType1* array1 = input1_ptr.get_ptr ();
333+ _DataType2* array2 = input2_ptr.get_ptr ();
334+ _ResultType* result = result_ptr.get_ptr ();
335+
311336 size_t * _in1_shape = reinterpret_cast <size_t *>(dpnp_memory_alloc_c (ndim * sizeof (size_t )));
312337 size_t * _in2_shape = reinterpret_cast <size_t *>(dpnp_memory_alloc_c (ndim * sizeof (size_t )));
313338
@@ -322,7 +347,7 @@ void dpnp_kron_c(void* array1_in,
322347 get_shape_offsets_inkernel<size_t >(in2_shape, ndim, in2_offsets);
323348 get_shape_offsets_inkernel<size_t >(res_shape, ndim, res_offsets);
324349
325- cl::sycl::range<1 > gws (size );
350+ cl::sycl::range<1 > gws (result_size );
326351 auto kernel_parallel_for_func = [=](cl::sycl::id<1 > global_id) {
327352 const size_t idx = global_id[0 ];
328353
@@ -356,12 +381,18 @@ void dpnp_kron_c(void* array1_in,
356381template <typename _DataType>
357382void dpnp_matrix_rank_c (void * array1_in, void * result1, size_t * shape, size_t ndim)
358383{
359- _DataType* array_1 = reinterpret_cast <_DataType*>(array1_in);
360- _DataType* result = reinterpret_cast <_DataType*>(result1);
384+ const size_t input_size = std::accumulate (shape, shape + ndim, 1 , std::multiplies<size_t >());
385+ if (!input_size)
386+ {
387+ return ;
388+ }
389+
390+ DPNPC_ptr_adapter<_DataType> input1_ptr (array1_in, input_size);
391+ DPNPC_ptr_adapter<_DataType> result_ptr (result1, 1 );
392+ _DataType* array_1 = input1_ptr.get_ptr ();
393+ _DataType* result = result_ptr.get_ptr ();
361394
362395 size_t elems = 1 ;
363- const _DataType init_val = 0 ;
364- dpnp_memory_memcpy_c (result, &init_val, sizeof (_DataType)); // result[0] = 0;
365396 if (ndim > 1 )
366397 {
367398 elems = shape[0 ];
@@ -373,15 +404,18 @@ void dpnp_matrix_rank_c(void* array1_in, void* result1, size_t* shape, size_t nd
373404 }
374405 }
375406 }
407+
408+ _DataType acc = 0 ;
376409 for (size_t i = 0 ; i < elems; i++)
377410 {
378411 size_t ind = 0 ;
379412 for (size_t j = 0 ; j < ndim; j++)
380413 {
381414 ind += (shape[j] - 1 ) * i;
382415 }
383- result[ 0 ] += array_1[ind];
416+ acc += array_1[ind];
384417 }
418+ result[0 ] = acc;
385419
386420 return ;
387421}
@@ -391,7 +425,8 @@ void dpnp_qr_c(void* array1_in, void* result1, void* result2, void* result3, siz
391425{
392426 cl::sycl::event event;
393427
394- _InputDT* in_array = reinterpret_cast <_InputDT*>(array1_in);
428+ DPNPC_ptr_adapter<_InputDT> input1_ptr (array1_in, size_m * size_n, true );
429+ _InputDT* in_array = input1_ptr.get_ptr ();
395430
396431 // math lib func overrides input
397432 _ComputeDT* in_a = reinterpret_cast <_ComputeDT*>(dpnp_memory_alloc_c (size_m * size_n * sizeof (_ComputeDT)));
@@ -400,13 +435,17 @@ void dpnp_qr_c(void* array1_in, void* result1, void* result2, void* result3, siz
400435 {
401436 for (size_t j = 0 ; j < size_n; ++j)
402437 {
438+ // TODO transpose? use dpnp_transpose_c()
403439 in_a[j * size_m + i] = in_array[i * size_n + j];
404440 }
405441 }
406442
407- _ComputeDT* res_q = reinterpret_cast <_ComputeDT*>(result1);
408- _ComputeDT* res_r = reinterpret_cast <_ComputeDT*>(result2);
409- _ComputeDT* tau = reinterpret_cast <_ComputeDT*>(result3);
443+ DPNPC_ptr_adapter<_ComputeDT> result1_ptr (result1, size_m * size_m, true , true );
444+ DPNPC_ptr_adapter<_ComputeDT> result2_ptr (result2, size_m * size_n, true , true );
445+ DPNPC_ptr_adapter<_ComputeDT> result3_ptr (result3, std::min (size_m, size_n), true , true );
446+ _ComputeDT* res_q = result1_ptr.get_ptr ();
447+ _ComputeDT* res_r = result2_ptr.get_ptr ();
448+ _ComputeDT* tau = result3_ptr.get_ptr ();
410449
411450 const std::int64_t lda = size_m;
412451
@@ -487,18 +526,22 @@ void dpnp_svd_c(void* array1_in, void* result1, void* result2, void* result3, si
487526{
488527 cl::sycl::event event;
489528
490- _InputDT* in_array = reinterpret_cast <_InputDT*>(array1_in);
529+ DPNPC_ptr_adapter<_InputDT> input1_ptr (array1_in, size_m * size_n, true ); // TODO no need this if use dpnp_copy_to()
530+ _InputDT* in_array = input1_ptr.get_ptr ();
491531
492532 // math lib gesvd func overrides input
493533 _ComputeDT* in_a = reinterpret_cast <_ComputeDT*>(dpnp_memory_alloc_c (size_m * size_n * sizeof (_ComputeDT)));
494534 for (size_t it = 0 ; it < size_m * size_n; ++it)
495535 {
496- in_a[it] = in_array[it];
536+ in_a[it] = in_array[it]; // TODO Type conversion. memcpy can not be used directly. dpnp_copy_to() ?
497537 }
498538
499- _ComputeDT* res_u = reinterpret_cast <_ComputeDT*>(result1);
500- _SVDT* res_s = reinterpret_cast <_SVDT*>(result2);
501- _ComputeDT* res_vt = reinterpret_cast <_ComputeDT*>(result3);
539+ DPNPC_ptr_adapter<_ComputeDT> result1_ptr (result1, size_m * size_m, true , true );
540+ DPNPC_ptr_adapter<_SVDT> result2_ptr (result2, std::min (size_m, size_n), true , true );
541+ DPNPC_ptr_adapter<_ComputeDT> result3_ptr (result3, size_n * size_n, true , true );
542+ _ComputeDT* res_u = result1_ptr.get_ptr ();
543+ _SVDT* res_s = result2_ptr.get_ptr ();
544+ _ComputeDT* res_vt = result3_ptr.get_ptr ();
502545
503546 const std::int64_t m = size_m;
504547 const std::int64_t n = size_n;
0 commit comments