@@ -169,12 +169,16 @@ void dpnp_fft_fft_sycl_c(const void* array1_in,
169169template <typename _DataType_input, typename _DataType_output, typename _Descriptor_type>
170170void dpnp_fft_fft_mathlib_compute_c (const void * array1_in,
171171 void * result1,
172+ const shape_elem_type* input_shape,
172173 const size_t shape_size,
173174 const size_t result_size,
174175 _Descriptor_type& desc,
175176 const size_t norm)
176177{
177- sycl::event event;
178+ if (!shape_size)
179+ {
180+ return ;
181+ }
178182
179183 DPNPC_ptr_adapter<_DataType_input> input1_ptr (array1_in, result_size);
180184 DPNPC_ptr_adapter<_DataType_output> result_ptr (result1, result_size);
@@ -187,9 +191,19 @@ void dpnp_fft_fft_mathlib_compute_c(const void* array1_in,
187191 desc.set_value (mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
188192 desc.commit (DPNP_QUEUE);
189193
190- event = mkl_dft::compute_forward (desc, array_1, result);
194+ const size_t n_iter =
195+ std::accumulate (input_shape, input_shape + shape_size - 1 , 1 , std::multiplies<shape_elem_type>());
191196
192- event.wait ();
197+ const size_t shift = input_shape[shape_size - 1 ];
198+
199+ std::vector<sycl::event> fft_events;
200+ fft_events.reserve (n_iter);
201+
202+ for (size_t i = 0 ; i < n_iter; ++i) {
203+ fft_events.push_back (mkl_dft::compute_forward (desc, array_1 + i * shift, result + i * shift));
204+ }
205+
206+ sycl::event::wait (fft_events);
193207
194208 return ;
195209}
@@ -207,39 +221,24 @@ void dpnp_fft_fft_mathlib_c(const void* array1_in,
207221 {
208222 return ;
209223 }
210- std::vector<std::int64_t > dimensions (input_shape, input_shape + shape_size);
224+ // will be used with strides
225+ // std::vector<std::int64_t> dimensions(input_shape, input_shape + shape_size);
211226
212227 if constexpr (std::is_same<_DataType_input, std::complex <double >>::value &&
213228 std::is_same<_DataType_output, std::complex <double >>::value)
214229 {
215- if (shape_size == 1 )
216- {
217- desc_dp_cmplx_t desc (result_size);
218- dpnp_fft_fft_mathlib_compute_c<_DataType_input, _DataType_output, desc_dp_cmplx_t >(
219- array1_in, result1, shape_size, result_size, desc, norm);
220- }
221- else
222- {
223- desc_dp_cmplx_t desc (dimensions);
224- dpnp_fft_fft_mathlib_compute_c<_DataType_input, _DataType_output, desc_dp_cmplx_t >(
225- array1_in, result1, shape_size, result_size, desc, norm);
226- }
230+ desc_dp_cmplx_t desc (input_shape[shape_size - 1 ]);
231+
232+ dpnp_fft_fft_mathlib_compute_c<_DataType_input, _DataType_output, desc_dp_cmplx_t >(
233+ array1_in, result1, input_shape, shape_size, result_size, desc, norm);
227234 }
228235 else if (std::is_same<_DataType_input, std::complex <float >>::value &&
229236 std::is_same<_DataType_output, std::complex <float >>::value)
230237 {
231- if (shape_size == 1 )
232- {
233- desc_sp_cmplx_t desc (result_size);
234- dpnp_fft_fft_mathlib_compute_c<_DataType_input, _DataType_output, desc_sp_cmplx_t >(
235- array1_in, result1, shape_size, result_size, desc, norm);
236- }
237- else
238- {
239- desc_sp_cmplx_t desc (dimensions);
240- dpnp_fft_fft_mathlib_compute_c<_DataType_input, _DataType_output, desc_sp_cmplx_t >(
241- array1_in, result1, shape_size, result_size, desc, norm);
242- }
238+ desc_sp_cmplx_t desc (input_shape[shape_size - 1 ]);
239+
240+ dpnp_fft_fft_mathlib_compute_c<_DataType_input, _DataType_output, desc_sp_cmplx_t >(
241+ array1_in, result1, input_shape, shape_size, result_size, desc, norm);
243242 }
244243 return ;
245244}
@@ -270,11 +269,10 @@ void dpnp_fft_fft_c(const void* array1_in,
270269 return ;
271270 }
272271
273- if ((( std::is_same<_DataType_input, std::complex <double >>::value &&
272+ if ((std::is_same<_DataType_input, std::complex <double >>::value &&
274273 std::is_same<_DataType_output, std::complex <double >>::value) ||
275274 (std::is_same<_DataType_input, std::complex <float >>::value &&
276- std::is_same<_DataType_output, std::complex <float >>::value)) &&
277- (shape_size <= 3 ))
275+ std::is_same<_DataType_output, std::complex <float >>::value))
278276 {
279277 dpnp_fft_fft_mathlib_c<_DataType_input, _DataType_output>(
280278 array1_in, result1, input_shape, shape_size, result_size, norm);
0 commit comments