3333
3434namespace mkl_dft = oneapi::mkl::dft;
3535
36+ typedef mkl_dft::descriptor<mkl_dft::precision::DOUBLE, mkl_dft::domain::COMPLEX> desc_dp_cmplx_t ;
37+ typedef mkl_dft::descriptor<mkl_dft::precision::SINGLE, mkl_dft::domain::COMPLEX> desc_sp_cmplx_t ;
38+ typedef mkl_dft::descriptor<mkl_dft::precision::DOUBLE, mkl_dft::domain::REAL> desc_dp_real_t ;
39+ typedef mkl_dft::descriptor<mkl_dft::precision::SINGLE, mkl_dft::domain::REAL> desc_sp_real_t ;
40+
3641#ifdef _WIN32
3742#ifndef M_PI // Windows compatibility
3843#define M_PI 3.14159265358979323846
@@ -43,23 +48,24 @@ template <typename _KernelNameSpecialization1, typename _KernelNameSpecializatio
4348class dpnp_fft_fft_c_kernel ;
4449
4550template <typename _DataType_input, typename _DataType_output>
46- void dpnp_fft_fft_c (const void * array1_in,
47- void * result1,
48- const long * input_shape,
49- const long * output_shape,
50- size_t shape_size,
51- long axis,
52- long input_boundarie,
53- size_t inverse)
51+ void dpnp_fft_fft_sycl_c (const void * array1_in,
52+ void * result1,
53+ const long * input_shape,
54+ const long * output_shape,
55+ size_t shape_size,
56+ const size_t result_size,
57+ const size_t input_size,
58+ long axis,
59+ long input_boundarie,
60+ size_t inverse)
5461{
55- const size_t input_size = std::accumulate (input_shape, input_shape + shape_size, 1 , std::multiplies<size_t >());
56- const size_t result_size = std::accumulate (output_shape, output_shape + shape_size, 1 , std::multiplies<size_t >());
5762 if (!(input_size && result_size && shape_size))
5863 {
5964 return ;
6065 }
6166
6267 cl::sycl::event event;
68+
6369 const double kernel_pi = inverse ? -M_PI : M_PI;
6470
6571 DPNPC_ptr_adapter<_DataType_input> input1_ptr (array1_in, input_size);
@@ -148,21 +154,139 @@ void dpnp_fft_fft_c(const void* array1_in,
148154 };
149155
150156 event = DPNP_QUEUE.submit (kernel_func);
157+ event.wait ();
158+
159+ dpnp_memory_free_c (input_shape_offsets);
160+ dpnp_memory_free_c (output_shape_offsets);
161+ dpnp_memory_free_c (axis_iterator);
162+
163+ return ;
164+ }
151165
152- #if 0 // keep this code
153- oneapi::mkl::dft::descriptor<mkl_dft::precision::DOUBLE, mkl_dft::domain::COMPLEX> desc(result_size);
154- desc.set_value(mkl_dft::config_param::FORWARD_SCALE, static_cast<double>(result_size));
155- desc.set_value(mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE); // enum value from math library C interface
166+ template <typename _DataType_input, typename _DataType_output, typename _Descriptor_type>
167+ void dpnp_fft_fft_mathlib_compute_c (const void * array1_in,
168+ void * result1,
169+ const size_t shape_size,
170+ const size_t result_size,
171+ _Descriptor_type& desc,
172+ const size_t norm)
173+ {
174+ cl::sycl::event event;
175+
176+ DPNPC_ptr_adapter<_DataType_input> input1_ptr (array1_in, result_size);
177+ DPNPC_ptr_adapter<_DataType_output> result_ptr (result1, result_size);
178+ _DataType_input* array_1 = input1_ptr.get_ptr ();
179+ _DataType_output* result = result_ptr.get_ptr ();
180+
181+ desc.set_value (mkl_dft::config_param::BACKWARD_SCALE, (1.0 / result_size));
182+ // enum value from math library C interface
183+ // instead of mkl_dft::config_value::NOT_INPLACE
184+ desc.set_value (mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
156185 desc.commit (DPNP_QUEUE);
157186
158187 event = mkl_dft::compute_forward (desc, array_1, result);
159- #endif
160188
161189 event.wait ();
162190
163- dpnp_memory_free_c (input_shape_offsets);
164- dpnp_memory_free_c (output_shape_offsets);
165- dpnp_memory_free_c (axis_iterator);
191+ return ;
192+ }
193+
194+ // norm: backward - 0, forward is 1
195+ template <typename _DataType_input, typename _DataType_output>
196+ void dpnp_fft_fft_mathlib_c (const void * array1_in,
197+ void * result1,
198+ const long * input_shape,
199+ const size_t shape_size,
200+ const size_t result_size,
201+ const size_t norm)
202+ {
203+ if (!shape_size || !result_size || !array1_in || !result1)
204+ {
205+ return ;
206+ }
207+ std::vector<std::int64_t > dimensions (input_shape, input_shape + shape_size);
208+
209+ if constexpr (std::is_same<_DataType_input, std::complex <double >>::value &&
210+ std::is_same<_DataType_output, std::complex <double >>::value)
211+ {
212+ if (shape_size == 1 )
213+ {
214+ desc_dp_cmplx_t desc (result_size);
215+ dpnp_fft_fft_mathlib_compute_c<_DataType_input, _DataType_output, desc_dp_cmplx_t >(
216+ array1_in, result1, shape_size, result_size, desc, norm);
217+ }
218+ else
219+ {
220+ desc_dp_cmplx_t desc (dimensions);
221+ dpnp_fft_fft_mathlib_compute_c<_DataType_input, _DataType_output, desc_dp_cmplx_t >(
222+ array1_in, result1, shape_size, result_size, desc, norm);
223+ }
224+ }
225+ else if (std::is_same<_DataType_input, std::complex <float >>::value &&
226+ std::is_same<_DataType_output, std::complex <float >>::value)
227+ {
228+ if (shape_size == 1 )
229+ {
230+ desc_sp_cmplx_t desc (result_size);
231+ dpnp_fft_fft_mathlib_compute_c<_DataType_input, _DataType_output, desc_sp_cmplx_t >(
232+ array1_in, result1, shape_size, result_size, desc, norm);
233+ }
234+ else
235+ {
236+ desc_sp_cmplx_t desc (dimensions);
237+ dpnp_fft_fft_mathlib_compute_c<_DataType_input, _DataType_output, desc_sp_cmplx_t >(
238+ array1_in, result1, shape_size, result_size, desc, norm);
239+ }
240+ }
241+ return ;
242+ }
243+
244+ template <typename _DataType_input, typename _DataType_output>
245+ void dpnp_fft_fft_c (const void * array1_in,
246+ void * result1,
247+ const long * input_shape,
248+ const long * output_shape,
249+ size_t shape_size,
250+ long axis,
251+ long input_boundarie,
252+ size_t inverse,
253+ const size_t norm)
254+ {
255+ if (!shape_size)
256+ {
257+ return ;
258+ }
259+
260+ const size_t result_size = std::accumulate (output_shape, output_shape + shape_size, 1 , std::multiplies<size_t >());
261+ const size_t input_size = std::accumulate (input_shape, input_shape + shape_size, 1 , std::multiplies<size_t >());
262+
263+ if (!input_size || !result_size || !array1_in || !result1)
264+ {
265+ return ;
266+ }
267+
268+ if (((std::is_same<_DataType_input, std::complex <double >>::value &&
269+ std::is_same<_DataType_output, std::complex <double >>::value) ||
270+ (std::is_same<_DataType_input, std::complex <float >>::value &&
271+ std::is_same<_DataType_output, std::complex <float >>::value)) &&
272+ (shape_size <= 3 ))
273+ {
274+ dpnp_fft_fft_mathlib_c<_DataType_input, _DataType_output>(
275+ array1_in, result1, input_shape, shape_size, result_size, norm);
276+ }
277+ else
278+ {
279+ dpnp_fft_fft_sycl_c<_DataType_input, _DataType_output>(array1_in,
280+ result1,
281+ input_shape,
282+ output_shape,
283+ shape_size,
284+ result_size,
285+ input_size,
286+ axis,
287+ input_boundarie,
288+ inverse);
289+ }
166290
167291 return ;
168292}
@@ -173,12 +297,12 @@ void func_map_init_fft_func(func_map_t& fmap)
173297 (void *)dpnp_fft_fft_c<int , std::complex <double >>};
174298 fmap[DPNPFuncName::DPNP_FN_FFT_FFT][eft_LNG][eft_LNG] = {eft_C128,
175299 (void *)dpnp_fft_fft_c<long , std::complex <double >>};
176- fmap[DPNPFuncName::DPNP_FN_FFT_FFT][eft_FLT][eft_FLT] = {eft_C128 ,
177- (void *)dpnp_fft_fft_c<float , std::complex <double >>};
300+ fmap[DPNPFuncName::DPNP_FN_FFT_FFT][eft_FLT][eft_FLT] = {eft_C64 ,
301+ (void *)dpnp_fft_fft_c<float , std::complex <float >>};
178302 fmap[DPNPFuncName::DPNP_FN_FFT_FFT][eft_DBL][eft_DBL] = {eft_C128,
179303 (void *)dpnp_fft_fft_c<double , std::complex <double >>};
180304 fmap[DPNPFuncName::DPNP_FN_FFT_FFT][eft_C64][eft_C64] = {
181- eft_C128 , (void *)dpnp_fft_fft_c<std::complex <float >, std::complex <double >>};
305+ eft_C64 , (void *)dpnp_fft_fft_c<std::complex <float >, std::complex <float >>};
182306 fmap[DPNPFuncName::DPNP_FN_FFT_FFT][eft_C128][eft_C128] = {
183307 eft_C128, (void *)dpnp_fft_fft_c<std::complex <double >, std::complex <double >>};
184308 return ;
0 commit comments