@@ -114,11 +114,11 @@ def test_complex_order(np_call, dpt_call, dtype):
114114 X [..., 0 ::2 ] = np .pi / 6 + 1j * np .pi / 3
115115 X [..., 1 ::2 ] = np .pi / 3 + 1j * np .pi / 6
116116
117- for ord in ["C" , "F" , "A" , "K" ]:
118- for perms in itertools .permutations (range (4 )):
119- U = dpt .permute_dims (X [:, ::- 1 , ::- 1 , :], perms )
117+ for perms in itertools .permutations (range (4 )):
118+ U = dpt .permute_dims (X [:, ::- 1 , ::- 1 , :], perms )
119+ expected_Y = np_call (dpt .asnumpy (U ))
120+ for ord in ["C" , "F" , "A" , "K" ]:
120121 Y = dpt_call (U , order = ord )
121- expected_Y = np_call (dpt .asnumpy (U ))
122122 assert np .allclose (dpt .asnumpy (Y ), expected_Y )
123123
124124
@@ -164,35 +164,51 @@ def test_projection(dtype):
164164 "np_call, dpt_call" ,
165165 [(np .real , dpt .real ), (np .imag , dpt .imag ), (np .conj , dpt .conj )],
166166)
167- @pytest .mark .parametrize ("dtype" , ["f4" , "f8" ])
168- @pytest .mark .parametrize ("stride" , [- 1 , 1 , 2 , 4 , 5 ])
169- def test_complex_strided (np_call , dpt_call , dtype , stride ):
167+ @pytest .mark .parametrize ("dtype" , ["c8" , "c16" ])
168+ def test_complex_strided (np_call , dpt_call , dtype ):
170169 q = get_queue_or_skip ()
171170 skip_if_dtype_not_supported (dtype , q )
172171
173- N = 100
174- rng = np .random .default_rng (42 )
175- x1 = rng .standard_normal (N , dtype )
176- x2 = 1j * rng .standard_normal (N , dtype )
177- x = x1 + x2
178- y = np_call (x [::stride ])
179- z = dpt_call (dpt .asarray (x [::stride ]))
172+ np .random .seed (42 )
173+ strides = np .array ([- 4 , - 3 , - 2 , - 1 , 1 , 2 , 3 , 4 ])
174+ sizes = [2 , 4 , 6 , 8 , 9 , 24 , 72 ]
175+ tol = 8 * dpt .finfo (dtype ).resolution
180176
181- tol = 8 * dpt .finfo (y .dtype ).resolution
182- assert_allclose (y , dpt .asnumpy (z ), atol = tol , rtol = tol )
177+ low = - 1000.0
178+ high = 1000.0
179+ for ii in sizes :
180+ x1 = np .random .uniform (low = low , high = high , size = ii )
181+ x2 = np .random .uniform (low = low , high = high , size = ii )
182+ Xnp = np .array ([complex (v1 , v2 ) for v1 , v2 in zip (x1 , x2 )], dtype = dtype )
183+ X = dpt .asarray (Xnp )
184+ Ynp = np_call (Xnp )
185+ for jj in strides :
186+ assert_allclose (
187+ dpt .asnumpy (dpt_call (X [::jj ])),
188+ Ynp [::jj ],
189+ atol = tol ,
190+ rtol = tol ,
191+ )
183192
184193
185- @pytest .mark .parametrize ("dtype" , ["f2 " , "f4" , "f8 " ])
194+ @pytest .mark .parametrize ("dtype" , ["c8 " , "c16 " ])
186195def test_complex_special_cases (dtype ):
187196 q = get_queue_or_skip ()
188197 skip_if_dtype_not_supported (dtype , q )
189198
190- x = [np .nan , - np .nan , np .inf , - np .inf ]
191- with np .errstate (all = "ignore" ):
192- Xnp = 1j * np .array (x , dtype = dtype )
193- X = dpt .asarray (Xnp , dtype = Xnp .dtype )
199+ x = [np .nan , - np .nan , np .inf , - np .inf , + 0.0 , - 0.0 ]
200+ xc = [complex (* val ) for val in itertools .product (x , repeat = 2 )]
201+
202+ Xc_np = np .array (xc , dtype = dtype )
203+ Xc = dpt .asarray (Xc_np , dtype = dtype , sycl_queue = q )
194204
195205 tol = 8 * dpt .finfo (dtype ).resolution
196- assert_allclose (dpt .asnumpy (dpt .real (X )), np .real (Xnp ), atol = tol , rtol = tol )
197- assert_allclose (dpt .asnumpy (dpt .imag (X )), np .imag (Xnp ), atol = tol , rtol = tol )
198- assert_allclose (dpt .asnumpy (dpt .conj (X )), np .conj (Xnp ), atol = tol , rtol = tol )
206+ assert_allclose (
207+ dpt .asnumpy (dpt .real (Xc )), np .real (Xc_np ), atol = tol , rtol = tol
208+ )
209+ assert_allclose (
210+ dpt .asnumpy (dpt .imag (Xc )), np .imag (Xc_np ), atol = tol , rtol = tol
211+ )
212+ assert_allclose (
213+ dpt .asnumpy (dpt .conj (Xc )), np .conj (Xc_np ), atol = tol , rtol = tol
214+ )
0 commit comments