3434#include " queue_sycl.hpp"
3535
3636namespace mkl_blas = oneapi::mkl::blas;
37+ namespace mkl_blas_rm = oneapi::mkl::blas::row_major;
3738namespace mkl_lapack = oneapi::mkl::lapack;
3839
3940template <typename _DataType, typename _ResultType>
@@ -75,6 +76,82 @@ void dpnp_astype_c(const void* array1_in, void* result1, const size_t size)
7576template <typename _KernelNameSpecialization1, typename _KernelNameSpecialization2, typename _KernelNameSpecialization3>
7677class dpnp_dot_c_kernel ;
7778
79+ template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
80+ cl::sycl::event dot (cl::sycl::queue &queue,
81+ _DataType_output *result_out, _DataType_input1 *input1_in, _DataType_input2 *input2_in, size_t input1_strides, size_t input2_strides, size_t size,
82+ const cl::sycl::vector_class<cl::sycl::event> &dependencies = {})
83+ {
84+ (void )dependencies;
85+
86+ cl::sycl::event event;
87+
88+ if constexpr ((std::is_same<_DataType_input1, double >::value || std::is_same<_DataType_input1, float >::value) &&
89+ std::is_same<_DataType_input2, _DataType_input1>::value &&
90+ std::is_same<_DataType_output, _DataType_input1>::value)
91+ {
92+ event = oneapi::mkl::blas::dot (queue,
93+ size,
94+ input1_in,
95+ input1_strides, // input1 stride
96+ input2_in,
97+ input2_strides, // input2 stride
98+ result_out);
99+ }
100+ else
101+ {
102+ #if LIBSYCL_VERSION_GREATER(5, 3, 0)
103+ event = queue.submit ([&](sycl::handler &cgh)
104+ {
105+ cgh.parallel_for (sycl::range<1 >{size},
106+ cl::sycl::reduction (result_out,
107+ std::plus<_DataType_output>(),
108+ cl::sycl::property::reduction::initialize_to_identity{}),
109+ [=](cl::sycl::id<1 > idx, auto & sum)
110+ {
111+ sum += static_cast <_DataType_output>(input1_in[idx * input1_strides]) * static_cast <_DataType_output>(input2_in[idx * input2_strides]);
112+ });
113+ });
114+ // for some reason few such kernels cannot work in parallel
115+ // looks like a bug in level0 because with opencl works fine
116+ // that is why we call wait here
117+ event.wait ();
118+ #else
119+ _DataType_output* local_mem =
120+ reinterpret_cast <_DataType_output*>(dpnp_memory_alloc_c (size * sizeof (_DataType_output)));
121+
122+ // what about reduction??
123+ cl::sycl::range<1 > gws (size);
124+
125+ auto kernel_parallel_for_func = [=](cl::sycl::id<1 > global_id) {
126+ const size_t index = global_id[0 ];
127+ local_mem[index] = input1_in[index * input1_strides] * input2_in[index * input2_strides];
128+ };
129+
130+ auto kernel_func = [&](cl::sycl::handler& cgh) {
131+ cgh.parallel_for <class dpnp_dot_c_kernel <_DataType_output, _DataType_input1, _DataType_input2>>(
132+ gws, kernel_parallel_for_func);
133+ };
134+
135+ event = DPNP_QUEUE.submit (kernel_func);
136+
137+ event.wait ();
138+
139+ auto policy = oneapi::dpl::execution::make_device_policy<
140+ class dpnp_dot_c_kernel <_DataType_output, _DataType_input1, _DataType_input2>>(DPNP_QUEUE);
141+
142+ _DataType_output accumulator = 0 ;
143+ accumulator =
144+ std::reduce (policy, local_mem, local_mem + size, _DataType_output (0 ), std::plus<_DataType_output>());
145+ policy.queue ().wait ();
146+
147+ dpnp_memory_memcpy_c (result_out, &accumulator, sizeof (_DataType_output)); // result[0] = accumulator;
148+
149+ free (local_mem, DPNP_QUEUE);
150+ #endif
151+ }
152+ return event;
153+ }
154+
78155template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
79156void dpnp_dot_c (void * result_out,
80157 const size_t result_size,
@@ -92,78 +169,195 @@ void dpnp_dot_c(void* result_out,
92169 const size_t * input2_shape,
93170 const size_t * input2_strides)
94171{
95- (void )input1_shape;
96- (void )input1_ndim;
97- (void )input2_shape;
98- (void )input2_ndim;
99-
100- (void )result_size;
101- (void )result_ndim;
102- (void )result_shape;
103172 (void )result_strides;
104- (void )input1_strides;
105- (void )input2_strides;
106173
107- cl::sycl::event event;
108174 DPNPC_ptr_adapter<_DataType_input1> input1_ptr (input1_in, input1_size);
109175 DPNPC_ptr_adapter<_DataType_input2> input2_ptr (input2_in, input2_size);
110176
111177 _DataType_input1* input1 = input1_ptr.get_ptr ();
112178 _DataType_input2* input2 = input2_ptr.get_ptr ();
113179 _DataType_output* result = reinterpret_cast <_DataType_output*>(result_out);
114180
115- if (!input1_size)
181+ if (!input1_size || !input2_size )
116182 {
183+ _DataType_output val = _DataType_output (0 );
184+ dpnp_initval_c<_DataType_output>(result, &val, result_size);
117185 return ;
118186 }
119187
120- if constexpr ((std::is_same<_DataType_input1, double >::value || std::is_same<_DataType_input1, float >::value) &&
121- std::is_same<_DataType_input2, _DataType_input1>::value &&
122- std::is_same<_DataType_output, _DataType_input1>::value)
188+ // scalar
189+ if ((input1_ndim == 0 ) || (input2_ndim == 0 ))
123190 {
124- event = mkl_blas::dot (DPNP_QUEUE,
125- input1_size,
126- input1,
127- 1 , // input1 stride
128- input2,
129- 1 , // input2 stride
130- result);
191+ // the multiply function does not support strides,
192+ // so the result can be wrong if an input array has non-standard (non-c-contiguous) strides
193+ dpnp_multiply_c<_DataType_output, _DataType_input1, _DataType_input2>(result, \
194+ input1_in,
195+ input1_size,
196+ input1_shape,
197+ input1_ndim,
198+ input2_in,
199+ input2_size,
200+ input2_shape,
201+ input2_ndim,
202+ NULL );
203+ return ;
204+ }
205+
206+ // if both arrays are vectors
207+ if ((input1_ndim == 1 ) && (input2_ndim == 1 ))
208+ {
209+ assert (input1_size == input2_size);
210+ cl::sycl::event event = dot (DPNP_QUEUE, result, input1, input2, input1_strides[0 ], input2_strides[0 ], input1_size);
131211 event.wait ();
212+ return ;
213+ }
214+
215+ // 1D vector
216+ size_t ext_input1_ndim = input1_ndim == 1 ? 2 : input1_ndim;
217+ size_t * ext_input1_shape = new size_t [ext_input1_ndim];
218+ size_t * ext_input1_strides = new size_t [ext_input1_ndim];
219+ if (input1_ndim == 1 )
220+ {
221+ ext_input1_shape[0 ] = 1 ;
222+ ext_input1_shape[1 ] = input1_shape[0 ];
223+ ext_input1_strides[0 ] = 0 ;
224+ ext_input1_strides[1 ] = input1_strides[0 ];
132225 }
133226 else
134227 {
135- _DataType_output* local_mem =
136- reinterpret_cast <_DataType_output*>(dpnp_memory_alloc_c (input1_size * sizeof (_DataType_output)));
228+ for (size_t i = 0 ; i < ext_input1_ndim; ++i)
229+ {
230+ ext_input1_shape[i] = input1_shape[i];
231+ ext_input1_strides[i] = input1_strides[i];
232+ }
233+ }
234+ size_t ext_input2_ndim = input2_ndim == 1 ? 2 : input2_ndim;
235+ size_t * ext_input2_shape = new size_t [ext_input2_ndim];
236+ size_t * ext_input2_strides = new size_t [ext_input2_ndim];
237+ if (input2_ndim == 1 )
238+ {
239+ ext_input2_shape[0 ] = input2_shape[0 ];
240+ ext_input2_shape[1 ] = 1 ;
241+ ext_input2_strides[0 ] = input2_strides[0 ];
242+ ext_input2_strides[1 ] = 0 ;
243+ }
244+ else
245+ {
246+ for (size_t i = 0 ; i < ext_input2_ndim; ++i)
247+ {
248+ ext_input2_shape[i] = input2_shape[i];
249+ ext_input2_strides[i] = input2_strides[i];
250+ }
251+ }
252+ size_t ext_result_ndim = ((input1_ndim == 1 ) || (input2_ndim == 1 )) ? 2 : result_ndim;
253+ size_t * ext_result_shape = new size_t [ext_result_ndim];
254+ if ((input1_ndim == 1 ) || (input2_ndim == 1 ))
255+ {
256+ ext_result_shape[0 ] = ext_input1_shape[0 ];
257+ ext_result_shape[1 ] = ext_input2_shape[1 ];
258+ }
259+ else
260+ {
261+ for (size_t i = 0 ; i < ext_result_ndim; ++i)
262+ {
263+ ext_result_shape[i] = result_shape[i];
264+ }
265+ }
137266
138- // what about reduction??
139- cl::sycl::range<1 > gws (input1_size);
267+ // check if GEMM can be executed (types)
268+ if constexpr ((std::is_same<_DataType_input1, double >::value || std::is_same<_DataType_input1, float >::value) &&
269+ std::is_same<_DataType_input2, _DataType_input1>::value &&
270+ std::is_same<_DataType_output, _DataType_input1>::value)
271+ {
272+ // check if GEMM can be executed (strides)
273+ // TODO: rewrite the condition in general case for ndims > 2
274+ // (looks like there are such another cases)
275+ if ((ext_input1_ndim == 2 && ext_input2_ndim == 2 ) &&
276+ (ext_input1_strides[0 ] == 1 || ext_input1_strides[1 ] == 1 ) &&
277+ (ext_input2_strides[0 ] == 1 || ext_input2_strides[1 ] == 1 )
278+ )
279+ {
280+ // there is a difference of behavior with trans and sizes params in previous version of GEMM
281+ // only new version is supported, in case of old version computation goes in common way
282+ #if INTEL_MKL_VERSION >= 20210004
283+ oneapi::mkl::transpose trans1 = ext_input1_strides[0 ] == 1 ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans;
284+ oneapi::mkl::transpose trans2 = ext_input2_strides[0 ] == 1 ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans;
285+
286+ const size_t size_m = ext_input1_shape[0 ];
287+ const size_t size_n = ext_input2_shape[1 ];
288+ const size_t size_k = ext_input1_shape[1 ];
289+
290+ const std::int64_t lda = trans1 == oneapi::mkl::transpose::nontrans ? ext_input1_strides[0 ] : ext_input1_strides[1 ];
291+ const std::int64_t ldb = trans2 == oneapi::mkl::transpose::nontrans ? ext_input2_strides[0 ] : ext_input2_strides[1 ];;
292+ // the definition of ldc will differ for a result with non-standard (non-c-contiguous) strides
293+ // const std::int64_t ldc = result_strides[0] == 1 ? result_strides[1] : result_strides[0];
294+ const std::int64_t ldc = size_n;
295+
296+ cl::sycl::event event = mkl_blas_rm::gemm (DPNP_QUEUE,
297+ trans1,
298+ trans2,
299+ size_m,
300+ size_n,
301+ size_k,
302+ _DataType_output (1 ), // alpha
303+ input1,
304+ lda,
305+ input2,
306+ ldb,
307+ _DataType_output (0 ), // beta
308+ result,
309+ ldc);
310+ event.wait ();
311+ return ;
312+ #endif
313+ }
314+ }
140315
141- auto kernel_parallel_for_func = [=]( cl::sycl::id< 1 > global_id) {
142- const size_t index = global_id[ 0 ] ;
143- local_mem[index] = input1[index] * input2[index] ;
144- } ;
316+ // deprecated? can be replaced with std::vector< cl::sycl::event>
317+ cl::sycl::vector_class<cl::sycl::event> dot_events ;
318+ // std::vector<cl::sycl::event> dot_events ;
319+ dot_events. reserve (result_size) ;
145320
146- auto kernel_func = [&](cl::sycl::handler& cgh) {
147- cgh.parallel_for <class dpnp_dot_c_kernel <_DataType_output, _DataType_input1, _DataType_input2>>(
148- gws, kernel_parallel_for_func);
149- };
321+ size_t dot_st1 = ext_input1_strides[ext_input1_ndim - 1 ];
322+ size_t dot_st2 = ext_input2_strides[ext_input2_ndim - 2 ];
323+ size_t dot_size = ext_input1_shape[ext_input1_ndim - 1 ];
150324
151- event = DPNP_QUEUE.submit (kernel_func);
325+ size_t * res_coords = new size_t [ext_result_ndim];
326+ size_t * result_offsets = new size_t [ext_result_ndim];
327+ get_shape_offsets_inkernel (ext_result_shape, ext_result_ndim, result_offsets);
152328
153- event.wait ();
329+ for (size_t i = 0 ; i < result_size; ++i)
330+ {
331+ get_xyz_by_id (i, ext_result_ndim, result_offsets, res_coords);
154332
155- auto policy = oneapi::dpl::execution::make_device_policy<
156- class dpnp_dot_c_kernel <_DataType_output, _DataType_input1, _DataType_input2>>(DPNP_QUEUE);
333+ _DataType_output* dot_res = result + i;
157334
158- _DataType_output accumulator = 0 ;
159- accumulator =
160- std::reduce (policy, local_mem, local_mem + input1_size, _DataType_output (0 ), std::plus<_DataType_output>());
161- policy.queue ().wait ();
335+ _DataType_input1* dot_in1 = input1;
336+ for (size_t j = 0 ; j < ext_input1_ndim - 1 ; ++j)
337+ {
338+ dot_in1 = dot_in1 + res_coords[j] * ext_input1_strides[j];
339+ }
162340
163- dpnp_memory_memcpy_c (result, &accumulator, sizeof (_DataType_output)); // result[0] = accumulator;
341+ _DataType_input2* dot_in2 = input2;
342+ for (size_t j = 0 ; j < ext_input2_ndim - 2 ; ++j)
343+ {
344+ dot_in2 = dot_in2 + res_coords[ext_input1_ndim - 1 + j] * ext_input2_strides[j];
345+ }
346+ dot_in2 = dot_in2 + res_coords[ext_input1_ndim + ext_input2_ndim - 3 ] * ext_input2_strides[ext_input2_ndim - 1 ];
164347
165- free (local_mem, DPNP_QUEUE );
348+ dot_events. push_back ( dot (DPNP_QUEUE, dot_res, dot_in1, dot_in2, dot_st1, dot_st2, dot_size) );
166349 }
350+
351+ sycl::event::wait (dot_events);
352+
353+ delete[] res_coords;
354+ delete[] result_offsets;
355+ delete[] ext_input1_shape;
356+ delete[] ext_input1_strides;
357+ delete[] ext_input2_shape;
358+ delete[] ext_input2_strides;
359+ delete[] ext_result_shape;
360+
167361}
168362
169363template <typename _DataType, typename _ResultType>
0 commit comments