@@ -533,22 +533,23 @@ template <typename _KernelNameSpecialization>
533533class dpnp_matmul_c_kernel ;
534534
535535template <typename _DataType>
536- void dpnp_matmul_c (DPCTLSyclQueueRef q_ref,
537- void * result_out,
538- const size_t result_size,
539- const size_t result_ndim,
540- const shape_elem_type* result_shape,
541- const shape_elem_type* result_strides,
542- const void * input1_in,
543- const size_t input1_size,
544- const size_t input1_ndim,
545- const shape_elem_type* input1_shape,
546- const shape_elem_type* input1_strides,
547- const void * input2_in,
548- const size_t input2_size,
549- const size_t input2_ndim,
550- const shape_elem_type* input2_shape,
551- const shape_elem_type* input2_strides)
536+ DPCTLSyclEventRef dpnp_matmul_c (DPCTLSyclQueueRef q_ref,
537+ void * result_out,
538+ const size_t result_size,
539+ const size_t result_ndim,
540+ const shape_elem_type* result_shape,
541+ const shape_elem_type* result_strides,
542+ const void * input1_in,
543+ const size_t input1_size,
544+ const size_t input1_ndim,
545+ const shape_elem_type* input1_shape,
546+ const shape_elem_type* input1_strides,
547+ const void * input2_in,
548+ const size_t input2_size,
549+ const size_t input2_ndim,
550+ const shape_elem_type* input2_shape,
551+ const shape_elem_type* input2_strides,
552+ const DPCTLEventVectorRef dep_event_vec_ref)
552553{
553554 (void )result_size;
554555 (void )result_ndim;
@@ -561,16 +562,19 @@ void dpnp_matmul_c(DPCTLSyclQueueRef q_ref,
561562 (void )input2_ndim;
562563 (void )input2_strides;
563564
565+ DPCTLSyclEventRef event_ref = nullptr ;
566+
564567 size_t size_m = input1_shape[0 ];
565568 size_t size_n = input2_shape[1 ];
566569 size_t size_k = input1_shape[1 ];
567570
568571 if (!size_m || !size_n || !size_k)
569572 {
570- return ;
573+ return event_ref ;
571574 }
572575
573576 sycl::queue q = *(reinterpret_cast <sycl::queue*>(q_ref));
577+ std::vector<sycl::event> dep_events = cast_event_vector (dep_event_vec_ref);
574578 sycl::event event;
575579
576580 _DataType* array_1 = reinterpret_cast <_DataType*>(const_cast <void *>(input1_in));
@@ -597,7 +601,8 @@ void dpnp_matmul_c(DPCTLSyclQueueRef q_ref,
597601 lda,
598602 _DataType (0 ),
599603 result,
600- ldc);
604+ ldc,
605+ dep_events);
601606 }
602607 else
603608 {
@@ -629,12 +634,16 @@ void dpnp_matmul_c(DPCTLSyclQueueRef q_ref,
629634 };
630635
631636 auto kernel_func = [&](sycl::handler& cgh) {
637+ cgh.depends_on (dep_events);
632638 cgh.parallel_for <class dpnp_matmul_c_kernel <_DataType>>(gws, kernel_parallel_for_func);
633639 };
634640
635641 event = q.submit (kernel_func);
636642 }
637- event.wait ();
643+
644+ event_ref = reinterpret_cast <DPCTLSyclEventRef>(&event);
645+
646+ return DPCTLEvent_Copy (event_ref);
638647}
639648
640649template <typename _DataType>
@@ -655,10 +664,26 @@ void dpnp_matmul_c(void* result_out,
655664 const shape_elem_type* input2_strides)
656665{
657666 DPCTLSyclQueueRef q_ref = reinterpret_cast <DPCTLSyclQueueRef>(&DPNP_QUEUE);
658- dpnp_matmul_c<_DataType>(q_ref,
659- result_out, result_size, result_ndim, result_shape, result_strides,
660- input1_in, input1_size, input1_ndim, input1_shape, input1_strides,
661- input2_in, input2_size, input2_ndim, input2_shape, input2_strides);
667+ DPCTLEventVectorRef dep_event_vec_ref = nullptr ;
668+ DPCTLSyclEventRef event_ref = dpnp_matmul_c<_DataType>(q_ref,
669+ result_out,
670+ result_size,
671+ result_ndim,
672+ result_shape,
673+ result_strides,
674+ input1_in,
675+ input1_size,
676+ input1_ndim,
677+ input1_shape,
678+ input1_strides,
679+ input2_in,
680+ input2_size,
681+ input2_ndim,
682+ input2_shape,
683+ input2_strides,
684+ dep_event_vec_ref);
685+ sycl::event event = *(reinterpret_cast <sycl::event*>(event_ref));
686+ event.wait_and_throw ();
662687}
663688
664689template <typename _DataType>
@@ -679,22 +704,23 @@ void (*dpnp_matmul_default_c)(void*,
679704 const shape_elem_type*) = dpnp_matmul_c<_DataType>;
680705
681706template <typename _DataType>
682- void (*dpnp_matmul_ext_c)(DPCTLSyclQueueRef,
683- void *,
684- const size_t ,
685- const size_t ,
686- const shape_elem_type*,
687- const shape_elem_type*,
688- const void *,
689- const size_t ,
690- const size_t ,
691- const shape_elem_type*,
692- const shape_elem_type*,
693- const void *,
694- const size_t ,
695- const size_t ,
696- const shape_elem_type*,
697- const shape_elem_type*) = dpnp_matmul_c<_DataType>;
707+ DPCTLSyclEventRef (*dpnp_matmul_ext_c)(DPCTLSyclQueueRef,
708+ void *,
709+ const size_t ,
710+ const size_t ,
711+ const shape_elem_type*,
712+ const shape_elem_type*,
713+ const void *,
714+ const size_t ,
715+ const size_t ,
716+ const shape_elem_type*,
717+ const shape_elem_type*,
718+ const void *,
719+ const size_t ,
720+ const size_t ,
721+ const shape_elem_type*,
722+ const shape_elem_type*,
723+ const DPCTLEventVectorRef) = dpnp_matmul_c<_DataType>;
698724
699725void func_map_init_linalg (func_map_t & fmap)
700726{
0 commit comments