@@ -77,9 +77,14 @@ template <typename _KernelNameSpecialization1, typename _KernelNameSpecializatio
7777class dpnp_dot_c_kernel ;
7878
7979template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
80- cl::sycl::event dot (cl::sycl::queue &queue,
81- _DataType_output *result_out, _DataType_input1 *input1_in, _DataType_input2 *input2_in, size_t input1_strides, size_t input2_strides, size_t size,
82- const cl::sycl::vector_class<cl::sycl::event> &dependencies = {})
80+ cl::sycl::event dot (cl::sycl::queue& queue,
81+ _DataType_output* result_out,
82+ _DataType_input1* input1_in,
83+ _DataType_input2* input2_in,
84+ size_t input1_strides,
85+ size_t input2_strides,
86+ size_t size,
87+ const cl::sycl::vector_class<cl::sycl::event>& dependencies = {})
8388{
8489 (void )dependencies;
8590
@@ -100,16 +105,15 @@ cl::sycl::event dot(cl::sycl::queue &queue,
100105 else
101106 {
102107#if LIBSYCL_VERSION_GREATER(5, 3, 0)
103- event = queue.submit ([&](sycl::handler &cgh)
104- {
108+ event = queue.submit ([&](sycl::handler& cgh) {
105109 cgh.parallel_for (sycl::range<1 >{size},
106- cl::sycl::reduction (result_out,
107- std::plus<_DataType_output>(),
108- cl::sycl::property::reduction::initialize_to_identity{}),
109- [=](cl::sycl::id<1 > idx, auto & sum)
110- {
111- sum += static_cast <_DataType_output>(input1_in[idx * input1_strides]) * static_cast <_DataType_output>(input2_in[idx * input2_strides]);
112- });
110+ cl::sycl::reduction (result_out,
111+ std::plus<_DataType_output>(),
112+ cl::sycl::property::reduction::initialize_to_identity{}),
113+ [=](cl::sycl::id<1 > idx, auto & sum) {
114+ sum += static_cast <_DataType_output>(input1_in[idx * input1_strides]) *
115+ static_cast <_DataType_output>(input2_in[idx * input2_strides]);
116+ });
113117 });
114118 // for some reason a few such kernels cannot run in parallel
115119 // looks like a bug in level0 because with opencl works fine
@@ -190,7 +194,7 @@ void dpnp_dot_c(void* result_out,
190194 {
191195 // there is no support of strides in multiply function
192196 // so result can be wrong if input array has non-standard (c-contiguous) strides
193- dpnp_multiply_c<_DataType_output, _DataType_input1, _DataType_input2>(result, \
197+ dpnp_multiply_c<_DataType_output, _DataType_input1, _DataType_input2>(result,
194198 input1_in,
195199 input1_size,
196200 input1_shape,
@@ -207,7 +211,8 @@ void dpnp_dot_c(void* result_out,
207211 if ((input1_ndim == 1 ) && (input2_ndim == 1 ))
208212 {
209213 assert (input1_size == input2_size);
210- cl::sycl::event event = dot (DPNP_QUEUE, result, input1, input2, input1_strides[0 ], input2_strides[0 ], input1_size);
214+ cl::sycl::event event =
215+ dot (DPNP_QUEUE, result, input1, input2, input1_strides[0 ], input2_strides[0 ], input1_size);
211216 event.wait ();
212217 return ;
213218 }
@@ -225,7 +230,7 @@ void dpnp_dot_c(void* result_out,
225230 }
226231 else
227232 {
228- for (size_t i = 0 ; i < ext_input1_ndim; ++i)
233+ for (size_t i = 0 ; i < ext_input1_ndim; ++i)
229234 {
230235 ext_input1_shape[i] = input1_shape[i];
231236 ext_input1_strides[i] = input1_strides[i];
@@ -243,7 +248,7 @@ void dpnp_dot_c(void* result_out,
243248 }
244249 else
245250 {
246- for (size_t i = 0 ; i < ext_input2_ndim; ++i)
251+ for (size_t i = 0 ; i < ext_input2_ndim; ++i)
247252 {
248253 ext_input2_shape[i] = input2_shape[i];
249254 ext_input2_strides[i] = input2_strides[i];
@@ -258,7 +263,7 @@ void dpnp_dot_c(void* result_out,
258263 }
259264 else
260265 {
261- for (size_t i = 0 ; i < ext_result_ndim; ++i)
266+ for (size_t i = 0 ; i < ext_result_ndim; ++i)
262267 {
263268 ext_result_shape[i] = result_shape[i];
264269 }
@@ -274,21 +279,25 @@ void dpnp_dot_c(void* result_out,
274279 // (looks like there are such another cases)
275280 if ((ext_input1_ndim == 2 && ext_input2_ndim == 2 ) &&
276281 (ext_input1_strides[0 ] == 1 || ext_input1_strides[1 ] == 1 ) &&
277- (ext_input2_strides[0 ] == 1 || ext_input2_strides[1 ] == 1 )
278- )
282+ (ext_input2_strides[0 ] == 1 || ext_input2_strides[1 ] == 1 ))
279283 {
280284// there is a difference of behavior with trans and sizes params in previous version of GEMM
281285// only new version is supported, in case of old version computation goes in common way
282286#if INTEL_MKL_VERSION >= 20210004
283- oneapi::mkl::transpose trans1 = ext_input1_strides[0 ] == 1 ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans;
284- oneapi::mkl::transpose trans2 = ext_input2_strides[0 ] == 1 ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans;
287+ oneapi::mkl::transpose trans1 =
288+ ext_input1_strides[0 ] == 1 ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans;
289+ oneapi::mkl::transpose trans2 =
290+ ext_input2_strides[0 ] == 1 ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans;
285291
286292 const size_t size_m = ext_input1_shape[0 ];
287293 const size_t size_n = ext_input2_shape[1 ];
288294 const size_t size_k = ext_input1_shape[1 ];
289295
290- const std::int64_t lda = trans1 == oneapi::mkl::transpose::nontrans ? ext_input1_strides[0 ] : ext_input1_strides[1 ];
291- const std::int64_t ldb = trans2 == oneapi::mkl::transpose::nontrans ? ext_input2_strides[0 ] : ext_input2_strides[1 ];;
296+ const std::int64_t lda =
297+ trans1 == oneapi::mkl::transpose::nontrans ? ext_input1_strides[0 ] : ext_input1_strides[1 ];
298+ const std::int64_t ldb =
299+ trans2 == oneapi::mkl::transpose::nontrans ? ext_input2_strides[0 ] : ext_input2_strides[1 ];
300+ ;
292301 // definition of ldc will be different for result with non-standard (c-contiguous) strides
293302 // const std::int64_t ldc = result_strides[0] == 1 ? result_strides[1] : result_strides[0];
294303 const std::int64_t ldc = size_n;
@@ -326,20 +335,20 @@ void dpnp_dot_c(void* result_out,
326335 size_t * result_offsets = new size_t [ext_result_ndim];
327336 get_shape_offsets_inkernel (ext_result_shape, ext_result_ndim, result_offsets);
328337
329- for (size_t i = 0 ; i < result_size; ++i)
338+ for (size_t i = 0 ; i < result_size; ++i)
330339 {
331340 get_xyz_by_id (i, ext_result_ndim, result_offsets, res_coords);
332341
333342 _DataType_output* dot_res = result + i;
334343
335344 _DataType_input1* dot_in1 = input1;
336- for (size_t j = 0 ; j < ext_input1_ndim - 1 ; ++j)
345+ for (size_t j = 0 ; j < ext_input1_ndim - 1 ; ++j)
337346 {
338347 dot_in1 = dot_in1 + res_coords[j] * ext_input1_strides[j];
339348 }
340349
341350 _DataType_input2* dot_in2 = input2;
342- for (size_t j = 0 ; j < ext_input2_ndim - 2 ; ++j)
351+ for (size_t j = 0 ; j < ext_input2_ndim - 2 ; ++j)
343352 {
344353 dot_in2 = dot_in2 + res_coords[ext_input1_ndim - 1 + j] * ext_input2_strides[j];
345354 }
@@ -357,7 +366,6 @@ void dpnp_dot_c(void* result_out,
357366 delete[] ext_input2_shape;
358367 delete[] ext_input2_strides;
359368 delete[] ext_result_shape;
360-
361369}
362370
363371template <typename _DataType, typename _ResultType>
0 commit comments