3434#include < dpnp_iface.hpp>
3535
3636namespace mkl_blas = oneapi::mkl::blas;
37+ namespace mkl_blas_cm = oneapi::mkl::blas::column_major;
3738namespace mkl_blas_rm = oneapi::mkl::blas::row_major;
3839namespace mkl_lapack = oneapi::mkl::lapack;
3940
@@ -227,12 +228,10 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
227228 DPCTLSyclEventRef event_ref = nullptr ;
228229 sycl::queue q = *(reinterpret_cast <sycl::queue *>(q_ref));
229230
230- DPNPC_ptr_adapter<_DataType_input1> input1_ptr (q_ref, input1_in,
231- input1_size);
232- DPNPC_ptr_adapter<_DataType_input2> input2_ptr (q_ref, input2_in,
233- input2_size);
234- _DataType_input1 *input1 = input1_ptr.get_ptr ();
235- _DataType_input2 *input2 = input2_ptr.get_ptr ();
231+ _DataType_input1 *input1 =
232+ static_cast <_DataType_input1 *>(const_cast <void *>(input1_in));
233+ _DataType_input2 *input2 =
234+ static_cast <_DataType_input2 *>(const_cast <void *>(input2_in));
236235 _DataType_output *result = reinterpret_cast <_DataType_output *>(result_out);
237236
238237 if (!input1_size || !input2_size) {
@@ -257,10 +256,12 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
257256 // if both arrays are vectors
258257 if ((input1_ndim == 1 ) && (input2_ndim == 1 )) {
259258 assert (input1_size == input2_size);
259+
260260 sycl::event event = dot (q, result, input1, input2, input1_strides[0 ],
261261 input2_strides[0 ], input1_size);
262- event.wait ();
263- return event_ref;
262+
263+ event_ref = reinterpret_cast <DPCTLSyclEventRef>(&event);
264+ return DPCTLEvent_Copy (event_ref);
264265 }
265266
266267 // 1D vector
@@ -297,13 +298,17 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
297298 size_t ext_result_ndim =
298299 ((input1_ndim == 1 ) || (input2_ndim == 1 )) ? 2 : result_ndim;
299300 shape_elem_type *ext_result_shape = new shape_elem_type[ext_result_ndim];
301+ shape_elem_type *ext_result_strides = new shape_elem_type[ext_result_ndim];
300302 if ((input1_ndim == 1 ) || (input2_ndim == 1 )) {
301303 ext_result_shape[0 ] = ext_input1_shape[0 ];
302304 ext_result_shape[1 ] = ext_input2_shape[1 ];
305+ ext_result_strides[0 ] = 0 ;
306+ ext_result_strides[1 ] = result_strides[0 ];
303307 }
304308 else {
305309 for (size_t i = 0 ; i < ext_result_ndim; ++i) {
306310 ext_result_shape[i] = result_shape[i];
311+ ext_result_strides[i] = result_strides[i];
307312 }
308313 }
309314
@@ -316,80 +321,89 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
316321 // check if GEMM can be executed (strides)
317322 // TODO: rewrite the condition in general case for ndims > 2
318323 // (looks like there are other such cases)
319-
320324 if (ext_input1_ndim == 2 && ext_input2_ndim == 2 ) {
321- // there is a difference of behavior with trans and sizes params in previous
322- // version of GEMM only new version is supported, in case of old version
323- // computation goes in common way
324- #if INTEL_MKL_VERSION >= 20210004
325- // is mat1 F-contiguous, C-contiguous
326- bool mat1_f_contig =
327- (((ext_input1_shape[0 ] == 1 ) || (ext_input1_strides[0 ] == 1 )) &&
328- ((ext_input1_shape[1 ] == 1 ) ||
329- (ext_input1_strides[1 ] == ext_input1_shape[0 ])));
330- bool mat1_c_contig =
331- (((ext_input1_shape[1 ] == 1 ) || (ext_input1_strides[1 ] == 1 )) &&
332- ((ext_input1_shape[0 ] == 1 ) ||
333- (ext_input1_strides[0 ] == ext_input1_shape[1 ])));
334- // is mat2 F-contiguous, C-contiguous
335- bool mat2_f_contig =
336- (((ext_input2_shape[0 ] == 1 ) || (ext_input2_strides[0 ] == 1 )) &&
337- ((ext_input2_shape[1 ] == 1 ) ||
338- (ext_input2_strides[1 ] == ext_input2_shape[0 ])));
339- bool mat2_c_contig =
340- (((ext_input2_shape[1 ] == 1 ) || (ext_input2_strides[1 ] == 1 )) &&
341- ((ext_input2_shape[0 ] == 1 ) ||
342- (ext_input2_strides[0 ] == ext_input2_shape[1 ])));
343-
344- if ((mat1_f_contig || mat1_c_contig) &&
345- (mat2_f_contig || mat2_c_contig)) {
346- oneapi::mkl::transpose trans1 =
347- (mat1_f_contig && !mat1_c_contig)
348- ? oneapi::mkl::transpose::trans
349- : oneapi::mkl::transpose::nontrans;
350- oneapi::mkl::transpose trans2 =
351- (mat2_f_contig && !mat2_c_contig)
352- ? oneapi::mkl::transpose::trans
353- : oneapi::mkl::transpose::nontrans;
325+ // OneMKL gemm supports only arrays contiguous on inner dimension,
326+ // so stride for at least one dimension should be equal to 1
327+ if ((ext_input1_strides[0 ] == 1 || ext_input1_strides[1 ] == 1 ) &&
328+ (ext_input2_strides[0 ] == 1 || ext_input2_strides[1 ] == 1 ) &&
329+ (ext_result_strides[0 ] == 1 || ext_result_strides[1 ] == 1 ))
330+ {
331+ const bool isRowmA =
332+ (ext_input1_strides[1 ] == 1 || ext_input1_strides[0 ] == 0 );
333+ const bool isRowmB =
334+ (ext_input2_strides[1 ] == 1 || ext_input2_strides[1 ] == 0 );
335+ const bool isRowmC =
336+ (ext_result_strides[1 ] == 1 || ext_result_strides[0 ] == 0 );
337+
338+ oneapi::mkl::transpose transA =
339+ (isRowmA != isRowmC) ? oneapi::mkl::transpose::trans
340+ : oneapi::mkl::transpose::nontrans;
341+ oneapi::mkl::transpose transB =
342+ (isRowmB != isRowmC) ? oneapi::mkl::transpose::trans
343+ : oneapi::mkl::transpose::nontrans;
354344
355345 const size_t size_m = ext_input1_shape[0 ];
356346 const size_t size_n = ext_input2_shape[1 ];
357347 const size_t size_k = ext_input1_shape[1 ];
358348
359- const std::int64_t lda =
360- trans1 == oneapi::mkl::transpose::nontrans
361- ? ext_input1_strides[0 ]
362- : ext_input1_strides[1 ];
363- const std::int64_t ldb =
364- trans2 == oneapi::mkl::transpose::nontrans
365- ? ext_input2_strides[0 ]
366- : ext_input2_strides[1 ];
367-
368- // definition of ldc will be another for result with
369- // non-standard (c-contiguous) strides const std::int64_t ldc =
370- // result_strides[0] == 1 ? result_strides[1] :
371- // result_strides[0];
372- const std::int64_t ldc = size_n;
349+ auto getLdaLdc = [](const bool isRown, shape_elem_type *strides,
350+ shape_elem_type *shapes) {
351+ if (isRown) {
352+ return (strides[0 ] != 0 ) ? strides[0 ] : shapes[1 ];
353+ }
354+ return strides[1 ];
355+ };
356+
357+ const std::int64_t lda = static_cast <std::int64_t >(
358+ getLdaLdc (isRowmA, ext_input1_strides, ext_input1_shape));
359+ const std::int64_t ldb = static_cast <std::int64_t >(
360+ isRowmB ? ext_input2_strides[0 ] : ext_input2_strides[1 ]);
361+ const std::int64_t ldc = static_cast <std::int64_t >(
362+ getLdaLdc (isRowmC, ext_result_strides, ext_result_shape));
363+
364+ constexpr _DataType_output alpha = 1 ;
365+ constexpr _DataType_output beta = 0 ;
366+
367+ std::stringstream error_msg;
368+ std::int64_t info = 0 ;
373369
374370 try {
375- sycl::event event = mkl_blas_rm::gemm (
376- q, trans1, trans2, size_m, size_n, size_k,
377- _DataType_output (1 ), // alpha
378- input1, lda, input2, ldb,
379- _DataType_output (0 ), // beta
380- result, ldc);
381- event.wait ();
382- delete[] ext_input1_shape;
383- delete[] ext_input1_strides;
384- delete[] ext_input2_shape;
385- delete[] ext_input2_strides;
386- delete[] ext_result_shape;
387-
388- return event_ref;
371+ if (isRowmC) {
372+ mkl_blas_rm::gemm (q, transA, transB, size_m, size_n,
373+ size_k, alpha, input1, lda, input2,
374+ ldb, beta, result, ldc)
375+ .wait ();
376+ }
377+ else {
378+ mkl_blas_cm::gemm (q, transA, transB, size_m, size_n,
379+ size_k, alpha, input1, lda, input2,
380+ ldb, beta, result, ldc)
381+ .wait ();
382+ }
383+ } catch (mkl_lapack::exception const &e) {
384+ error_msg << " Unexpected MKL exception caught during "
385+ " gemm() call:\n reason: "
386+ << e.what () << " \n info: " << e.info ();
387+ info = e.info ();
389388 } catch (const std::exception &e) {
390- // do nothing, proceed to general case
389+ error_msg << " Unexpected SYCL exception caught during "
390+ " gemm() call:\n "
391+ << e.what ();
392+ info = -1 ;
391393 }
392- #endif
394+
395+ if (info != 0 ) // an unexpected error occurred
396+ {
397+ throw std::runtime_error (error_msg.str ());
398+ }
399+
400+ delete[] ext_input1_shape;
401+ delete[] ext_input1_strides;
402+ delete[] ext_input2_shape;
403+ delete[] ext_input2_strides;
404+ delete[] ext_result_shape;
405+ delete[] ext_result_strides;
406+ return event_ref;
393407 }
394408 }
395409 }
@@ -437,6 +451,7 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
437451 delete[] ext_input2_shape;
438452 delete[] ext_input2_strides;
439453 delete[] ext_result_shape;
454+ delete[] ext_result_strides;
440455
441456 return event_ref;
442457}
0 commit comments