4545
4646from dpnp .dpnp_utils import *
4747from dpnp .fft .dpnp_algo_fft import *
48+ from enum import Enum
4849
4950
5051__all__ = [
6970]
7071
7172
73+ class Norm (Enum ):
74+ backward = 0
75+ forward = 1
76+ ortho = 2
77+
78+ def get_validated_norm (norm ):
79+ if norm is None or norm == "backward" :
80+ return Norm .backward
81+ if norm == "forward" :
82+ return Norm .forward
83+ if norm == "ortho" :
84+ return Norm .ortho
85+ raise ValueError ("Unknown norm value." )
86+
87+
7288def fft (x1 , n = None , axis = - 1 , norm = None ):
7389 """
7490 Compute the one-dimensional discrete Fourier Transform.
@@ -86,10 +102,8 @@ def fft(x1, n=None, axis=-1, norm=None):
86102
87103 x1_desc = dpnp .get_dpnp_descriptor (x1 )
88104 if x1_desc :
89- # if norm is None or norm is 'backward':
90- # norm_val = 0
91- # else:
92- # norm_val = 1
105+ norm_ = get_validated_norm (norm )
106+
93107 if axis is None :
94108 axis_param = - 1 # the most right dimension (default value)
95109 else :
@@ -108,9 +122,11 @@ def fft(x1, n=None, axis=-1, norm=None):
108122 pass
109123 elif axis != - 1 :
110124 pass
125+ elif x1_desc .dtype not in (numpy .complex128 , numpy .complex64 ):
126+ pass
111127 else :
112128 output_boundarie = input_boundarie
113- return dpnp_fft (x1_desc , input_boundarie , output_boundarie , axis_param , False , 0 ).get_pyobj ()
129+ return dpnp_fft (x1_desc , input_boundarie , output_boundarie , axis_param , False , norm_ . value ).get_pyobj ()
114130 return call_origin (numpy .fft .fft , x1 , n , axis , norm )
115131
116132
@@ -219,6 +235,9 @@ def fftshift(x1, axes=None):
219235
220236 x1_desc = dpnp .get_dpnp_descriptor (x1 )
221237 if x1_desc and 0 :
238+
239+ norm_ = Norm .backward
240+
222241 if axis is None :
223242 axis_param = - 1 # the most right dimension (default value)
224243 else :
@@ -227,7 +246,7 @@ def fftshift(x1, axes=None):
227246 if x1_desc .size < 1 :
228247 pass # let fallback to handle exception
229248 else :
230- return dpnp_fft (x1_desc , input_boundarie , output_boundarie , axis_param , False ).get_pyobj ()
249+ return dpnp_fft (x1_desc , input_boundarie , output_boundarie , axis_param , False , norm_ . value ).get_pyobj ()
231250
232251 return call_origin (numpy .fft .fftshift , x1 , axes )
233252
@@ -248,6 +267,8 @@ def hfft(x1, n=None, axis=-1, norm=None):
248267
249268 x1_desc = dpnp .get_dpnp_descriptor (x1 )
250269 if x1_desc and 0 :
270+ norm_ = get_validated_norm (norm )
271+
251272 if axis is None :
252273 axis_param = - 1 # the most right dimension (default value)
253274 else :
@@ -267,7 +288,7 @@ def hfft(x1, n=None, axis=-1, norm=None):
267288 else :
268289 output_boundarie = input_boundarie
269290
270- return dpnp_fft (x1_desc , input_boundarie , output_boundarie , axis_param , False ).get_pyobj ()
291+ return dpnp_fft (x1_desc , input_boundarie , output_boundarie , axis_param , False , norm_ . value ).get_pyobj ()
271292
272293 return call_origin (numpy .fft .hfft , x1 , n , axis , norm )
273294
@@ -287,7 +308,9 @@ def ifft(x1, n=None, axis=-1, norm=None):
287308 """
288309
289310 x1_desc = dpnp .get_dpnp_descriptor (x1 )
290- if x1_desc :
311+ if x1_desc and 0 :
312+ norm_ = get_validated_norm (norm )
313+
291314 if axis is None :
292315 axis_param = - 1 # the most right dimension (default value)
293316 else :
@@ -307,7 +330,7 @@ def ifft(x1, n=None, axis=-1, norm=None):
307330 else :
308331 output_boundarie = input_boundarie
309332
310- return dpnp_fft (x1_desc , input_boundarie , output_boundarie , axis_param , True ).get_pyobj ()
333+ return dpnp_fft (x1_desc , input_boundarie , output_boundarie , axis_param , True , norm_ . value ).get_pyobj ()
311334
312335 return call_origin (numpy .fft .ifft , x1 , n , axis , norm )
313336
@@ -354,6 +377,9 @@ def ifftshift(x1, axes=None):
354377
355378 x1_desc = dpnp .get_dpnp_descriptor (x1 )
356379 if x1_desc and 0 :
380+
381+ norm_ = Norm .backward
382+
357383 if axis is None :
358384 axis_param = - 1 # the most right dimension (default value)
359385 else :
@@ -362,7 +388,7 @@ def ifftshift(x1, axes=None):
362388 if x1_desc .size < 1 :
363389 pass # let fallback to handle exception
364390 else :
365- return dpnp_fft (x1_desc , input_boundarie , output_boundarie , axis_param , False ).get_pyobj ()
391+ return dpnp_fft (x1_desc , input_boundarie , output_boundarie , axis_param , False , norm_ . value ).get_pyobj ()
366392
367393 return call_origin (numpy .fft .ifftshift , x1 , axes )
368394
@@ -384,7 +410,7 @@ def ifftn(x1, s=None, axes=None, norm=None):
384410 """
385411
386412 x1_desc = dpnp .get_dpnp_descriptor (x1 )
387- if x1_desc :
413+ if x1_desc and 0 :
388414 if s is None :
389415 boundaries = tuple ([x1_desc .shape [i ] for i in range (x1_desc .ndim )])
390416 else :
@@ -432,6 +458,8 @@ def ihfft(x1, n=None, axis=-1, norm=None):
432458
433459 x1_desc = dpnp .get_dpnp_descriptor (x1 )
434460 if x1_desc and 0 :
461+ norm_ = get_validated_norm (norm )
462+
435463 if axis is None :
436464 axis_param = - 1 # the most right dimension (default value)
437465 else :
@@ -451,7 +479,7 @@ def ihfft(x1, n=None, axis=-1, norm=None):
451479 else :
452480 output_boundarie = input_boundarie
453481
454- return dpnp_fft (x1_desc , input_boundarie , output_boundarie , axis_param , False ).get_pyobj ()
482+ return dpnp_fft (x1_desc , input_boundarie , output_boundarie , axis_param , False , norm_ . value ).get_pyobj ()
455483
456484 return call_origin (numpy .fft .ihfft , x1 , n , axis , norm )
457485
@@ -472,6 +500,8 @@ def irfft(x1, n=None, axis=-1, norm=None):
472500
473501 x1_desc = dpnp .get_dpnp_descriptor (x1 )
474502 if x1_desc and 0 :
503+ norm_ = get_validated_norm (norm )
504+
475505 if axis is None :
476506 axis_param = - 1 # the most right dimension (default value)
477507 else :
@@ -491,7 +521,7 @@ def irfft(x1, n=None, axis=-1, norm=None):
491521 else :
492522 output_boundarie = 2 * (input_boundarie - 1 )
493523
494- result = dpnp_fft (x1_desc , input_boundarie , output_boundarie , axis_param , True ).get_pyobj ()
524+ result = dpnp_fft (x1_desc , input_boundarie , output_boundarie , axis_param , True , norm_ . value ).get_pyobj ()
495525 # TODO tmp = utils.create_output_array(result_shape, result_c_type, out)
496526 # tmp = dparray(result.shape, dtype=dpnp.float64)
497527 # for it in range(tmp.size):
@@ -592,6 +622,8 @@ def rfft(x1, n=None, axis=-1, norm=None):
592622
593623 x1_desc = dpnp .get_dpnp_descriptor (x1 )
594624 if x1_desc :
625+ norm_ = get_validated_norm (norm )
626+
595627 if axis is None :
596628 axis_param = - 1 # the most right dimension (default value)
597629 else :
@@ -608,10 +640,14 @@ def rfft(x1, n=None, axis=-1, norm=None):
608640 pass # let fallback to handle exception
609641 elif norm is not None :
610642 pass
643+ elif x1_desc .ndim > 1 :
644+ pass
645+ elif x1_desc .dtype not in (numpy .complex128 , numpy .complex64 ):
646+ pass
611647 else :
612648 output_boundarie = input_boundarie // 2 + 1 # rfft specific requirenment
613649
614- return dpnp_fft (x1_desc , input_boundarie , output_boundarie , axis_param , False ).get_pyobj ()
650+ return dpnp_fft (x1_desc , input_boundarie , output_boundarie , axis_param , False , norm_ . value ).get_pyobj ()
615651
616652 return call_origin (numpy .fft .rfft , x1 , n , axis , norm )
617653
@@ -674,7 +710,7 @@ def rfftn(x1, s=None, axes=None, norm=None):
674710 """
675711
676712 x1_desc = dpnp .get_dpnp_descriptor (x1 )
677- if x1_desc :
713+ if x1_desc and 0 :
678714 if s is None :
679715 boundaries = tuple ([x1_desc .shape [i ] for i in range (x1_desc .ndim )])
680716 else :
0 commit comments