3939
4040
4141def test_where_basic ():
42- get_queue_or_skip
42+ get_queue_or_skip ()
4343
4444 cond = dpt .asarray (
4545 [
@@ -58,27 +58,87 @@ def test_where_basic():
5858 assert (dpt .asnumpy (out ) == dpt .asnumpy (out_expected )).all ()
5959
6060
61+ def _dtype_all_close (x1 , x2 ):
62+ if np .issubdtype (x2 .dtype , np .floating ) or np .issubdtype (
63+ x2 .dtype , np .complexfloating
64+ ):
65+ x2_dtype = x2 .dtype
66+ return np .allclose (
67+ x1 , x2 , atol = np .finfo (x2_dtype ).eps , rtol = np .finfo (x2_dtype ).eps
68+ )
69+ else :
70+ return np .allclose (x1 , x2 )
71+
72+
6173@pytest .mark .parametrize ("dt1" , _all_dtypes )
6274@pytest .mark .parametrize ("dt2" , _all_dtypes )
6375def test_where_all_dtypes (dt1 , dt2 ):
6476 q = get_queue_or_skip ()
6577 skip_if_dtype_not_supported (dt1 , q )
6678 skip_if_dtype_not_supported (dt2 , q )
6779
68- cond_np = np .arange (5 ) > 2
69- x1_np = np .asarray (2 , dtype = dt1 )
70- x2_np = np .asarray (3 , dtype = dt2 )
71-
72- cond = dpt .asarray (cond_np , sycl_queue = q )
73- x1 = dpt .asarray (x1_np , sycl_queue = q )
74- x2 = dpt .asarray (x2_np , sycl_queue = q )
80+ cond = dpt .asarray ([False , False , False , True , True ], sycl_queue = q )
81+ x1 = dpt .asarray (2 , sycl_queue = q )
82+ x2 = dpt .asarray (3 , sycl_queue = q )
7583
7684 res = dpt .where (cond , x1 , x2 )
77- res_np = np .where ( cond_np , x1_np , x2_np )
85+ res_check = np .asarray ([ 3 , 3 , 3 , 2 , 2 ], dtype = res . dtype )
7886
79- if res .dtype != res_np .dtype :
80- assert res .dtype .kind == res_np .dtype .kind
81- assert_array_equal (dpt .asnumpy (res ).astype (res_np .dtype ), res_np )
87+ dev = q .sycl_device
8288
83- else :
84- assert_array_equal (dpt .asnumpy (res ), res_np )
89+ if not dev .has_aspect_fp16 or not dev .has_aspect_fp64 :
90+ assert res .dtype .kind == dpt .result_type (x1 .dtype , x2 .dtype ).kind
91+
92+ assert _dtype_all_close (dpt .asnumpy (res ), res_check )
93+
94+
95+ def test_where_empty ():
96+ # check that numpy returns same results when
97+ # handling empty arrays
98+ get_queue_or_skip ()
99+
100+ empty = dpt .empty (0 )
101+ m = dpt .asarray (True )
102+ x1 = dpt .asarray (1 )
103+ x2 = dpt .asarray (2 )
104+ res = dpt .where (empty , x1 , x2 )
105+
106+ empty_np = np .empty (0 )
107+ m_np = dpt .asnumpy (m )
108+ x1_np = dpt .asnumpy (x1 )
109+ x2_np = dpt .asnumpy (x2 )
110+ res_np = np .where (empty_np , x1_np , x2_np )
111+
112+ assert_array_equal (dpt .asnumpy (res ), res_np )
113+
114+ res = dpt .where (m , empty , x2 )
115+ res_np = np .where (m_np , empty_np , x2_np )
116+
117+ assert_array_equal (dpt .asnumpy (res ), res_np )
118+
119+
120+ @pytest .mark .parametrize ("dt" , _all_dtypes )
121+ @pytest .mark .parametrize ("order" , ["C" , "F" ])
122+ def test_where_contiguous (dt , order ):
123+ q = get_queue_or_skip ()
124+ skip_if_dtype_not_supported (dt , q )
125+
126+ cond = dpt .asarray (
127+ [
128+ [[True , False , False ], [False , True , True ]],
129+ [[False , True , False ], [True , False , True ]],
130+ [[False , False , True ], [False , False , True ]],
131+ [[False , False , False ], [True , False , True ]],
132+ [[True , True , True ], [True , False , True ]],
133+ ],
134+ sycl_queue = q ,
135+ order = order ,
136+ )
137+
138+ x1 = dpt .full (cond .shape , 2 , dtype = dt , order = order , sycl_queue = q )
139+ x2 = dpt .full (cond .shape , 3 , dtype = dt , order = order , sycl_queue = q )
140+
141+ expected = np .where (dpt .asnumpy (cond ), dpt .asnumpy (x1 ), dpt .asnumpy (x2 ))
142+ res = dpt .where (cond , x1 , x2 )
143+
144+ assert _dtype_all_close (dpt .asnumpy (res ), expected )
0 commit comments