@@ -533,7 +533,8 @@ template <typename _KernelNameSpecialization>
533533class dpnp_matmul_c_kernel ;
534534
535535template <typename _DataType>
536- void dpnp_matmul_c (void * result_out,
536+ void dpnp_matmul_c (DPCTLSyclQueueRef q_ref,
537+ void * result_out,
537538 const size_t result_size,
538539 const size_t result_ndim,
539540 const shape_elem_type* result_shape,
@@ -569,13 +570,12 @@ void dpnp_matmul_c(void* result_out,
569570 return ;
570571 }
571572
573+ sycl::queue q = *(reinterpret_cast <sycl::queue*>(q_ref));
572574 sycl::event event;
573- DPNPC_ptr_adapter<_DataType> input1_ptr (input1_in, size_m * size_k);
574- DPNPC_ptr_adapter<_DataType> input2_ptr (input2_in, size_k * size_n);
575- DPNPC_ptr_adapter<_DataType> result_ptr (result_out, size_m * size_n, false , true );
576- _DataType* array_1 = input1_ptr.get_ptr ();
577- _DataType* array_2 = input2_ptr.get_ptr ();
578- _DataType* result = result_ptr.get_ptr ();
575+
576+ _DataType* array_1 = reinterpret_cast <_DataType*>(const_cast <void *>(input1_in));
577+ _DataType* array_2 = reinterpret_cast <_DataType*>(const_cast <void *>(input2_in));
578+ _DataType* result = reinterpret_cast <_DataType*>(result_out);
579579
580580 if constexpr (std::is_same<_DataType, double >::value || std::is_same<_DataType, float >::value)
581581 {
@@ -584,7 +584,7 @@ void dpnp_matmul_c(void* result_out,
584584 const std::int64_t ldb = std::max<size_t >(1UL , size_n); // First dimensions of array_2
585585 const std::int64_t ldc = std::max<size_t >(1UL , size_n); // Fast dimensions of result
586586
587- event = mkl_blas::gemm (DPNP_QUEUE ,
587+ event = mkl_blas::gemm (q ,
588588 oneapi::mkl::transpose::nontrans,
589589 oneapi::mkl::transpose::nontrans,
590590 size_n,
@@ -632,11 +632,70 @@ void dpnp_matmul_c(void* result_out,
632632 cgh.parallel_for <class dpnp_matmul_c_kernel <_DataType>>(gws, kernel_parallel_for_func);
633633 };
634634
635- event = DPNP_QUEUE .submit (kernel_func);
635+ event = q .submit (kernel_func);
636636 }
637637 event.wait ();
638638}
639639
640+ template <typename _DataType>
641+ void dpnp_matmul_c (void * result_out,
642+ const size_t result_size,
643+ const size_t result_ndim,
644+ const shape_elem_type* result_shape,
645+ const shape_elem_type* result_strides,
646+ const void * input1_in,
647+ const size_t input1_size,
648+ const size_t input1_ndim,
649+ const shape_elem_type* input1_shape,
650+ const shape_elem_type* input1_strides,
651+ const void * input2_in,
652+ const size_t input2_size,
653+ const size_t input2_ndim,
654+ const shape_elem_type* input2_shape,
655+ const shape_elem_type* input2_strides)
656+ {
657+ DPCTLSyclQueueRef q_ref = reinterpret_cast <DPCTLSyclQueueRef>(&DPNP_QUEUE);
658+ dpnp_matmul_c<_DataType>(q_ref,
659+ result_out, result_size, result_ndim, result_shape, result_strides,
660+ input1_in, input1_size, input1_ndim, input1_shape, input1_strides,
661+ input2_in, input2_size, input2_ndim, input2_shape, input2_strides);
662+ }
663+
664+ template <typename _DataType>
665+ void (*dpnp_matmul_default_c)(void *,
666+ const size_t ,
667+ const size_t ,
668+ const shape_elem_type*,
669+ const shape_elem_type*,
670+ const void *,
671+ const size_t ,
672+ const size_t ,
673+ const shape_elem_type*,
674+ const shape_elem_type*,
675+ const void *,
676+ const size_t ,
677+ const size_t ,
678+ const shape_elem_type*,
679+ const shape_elem_type*) = dpnp_matmul_c<_DataType>;
680+
681+ template <typename _DataType>
682+ void (*dpnp_matmul_ext_c)(DPCTLSyclQueueRef,
683+ void *,
684+ const size_t ,
685+ const size_t ,
686+ const shape_elem_type*,
687+ const shape_elem_type*,
688+ const void *,
689+ const size_t ,
690+ const size_t ,
691+ const shape_elem_type*,
692+ const shape_elem_type*,
693+ const void *,
694+ const size_t ,
695+ const size_t ,
696+ const shape_elem_type*,
697+ const shape_elem_type*) = dpnp_matmul_c<_DataType>;
698+
640699void func_map_init_linalg (func_map_t & fmap)
641700{
642701 fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_BLN][eft_BLN] = {eft_BLN, (void *)dpnp_astype_c<bool , bool >};
@@ -702,10 +761,15 @@ void func_map_init_linalg(func_map_t& fmap)
702761 fmap[DPNPFuncName::DPNP_FN_INITVAL][eft_DBL][eft_DBL] = {eft_DBL, (void *)dpnp_initval_c<double >};
703762 fmap[DPNPFuncName::DPNP_FN_INITVAL][eft_C128][eft_C128] = {eft_C128, (void *)dpnp_initval_c<std::complex <double >>};
704763
705- fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_INT][eft_INT] = {eft_INT, (void *)dpnp_matmul_c<int32_t >};
706- fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_LNG][eft_LNG] = {eft_LNG, (void *)dpnp_matmul_c<int64_t >};
707- fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_FLT][eft_FLT] = {eft_FLT, (void *)dpnp_matmul_c<float >};
708- fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_DBL][eft_DBL] = {eft_DBL, (void *)dpnp_matmul_c<double >};
764+ fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_INT][eft_INT] = {eft_INT, (void *)dpnp_matmul_default_c<int32_t >};
765+ fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_LNG][eft_LNG] = {eft_LNG, (void *)dpnp_matmul_default_c<int64_t >};
766+ fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_FLT][eft_FLT] = {eft_FLT, (void *)dpnp_matmul_default_c<float >};
767+ fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_DBL][eft_DBL] = {eft_DBL, (void *)dpnp_matmul_default_c<double >};
768+
769+ fmap[DPNPFuncName::DPNP_FN_MATMUL_EXT][eft_INT][eft_INT] = {eft_INT, (void *)dpnp_matmul_ext_c<int32_t >};
770+ fmap[DPNPFuncName::DPNP_FN_MATMUL_EXT][eft_LNG][eft_LNG] = {eft_LNG, (void *)dpnp_matmul_ext_c<int64_t >};
771+ fmap[DPNPFuncName::DPNP_FN_MATMUL_EXT][eft_FLT][eft_FLT] = {eft_FLT, (void *)dpnp_matmul_ext_c<float >};
772+ fmap[DPNPFuncName::DPNP_FN_MATMUL_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void *)dpnp_matmul_ext_c<double >};
709773
710774 return ;
711775}
0 commit comments