11# Data Parallel Control (dpctl)
22#
3- # Copyright 2020-2022 Intel Corporation
3+ # Copyright 2020-2023 Intel Corporation
44#
55# Licensed under the Apache License, Version 2.0 (the "License");
66# you may not use this file except in compliance with the License.
2020from numpy .testing import assert_array_equal
2121
2222import dpctl .tensor as dpt
23+ from dpctl .tensor ._search_functions import _where_result_type
24+ from dpctl .tensor ._type_utils import _all_data_types
25+ from dpctl .utils import ExecutionPlacementError
2326
2427_all_dtypes = [
28+ "?" ,
2529 "u1" ,
2630 "i1" ,
2731 "u2" ,
3842]
3943
4044
class mock_device:
    """Minimal stand-in for a device that only carries fp aspect flags.

    Stores the two flags under the attribute names ``has_aspect_fp16`` and
    ``has_aspect_fp64`` so it can be passed where only those attributes
    are read.
    """

    def __init__(self, fp16, fp64):
        # keep the flags exactly as given; no coercion to bool
        self.has_aspect_fp16, self.has_aspect_fp64 = fp16, fp64
49+
50+
4151def test_where_basic ():
4252 get_queue_or_skip ()
4353
@@ -54,7 +64,16 @@ def test_where_basic():
5464 out_expected = dpt .asarray (
5565 [[1 , 0 , 0 ], [0 , 1 , 0 ], [0 , 0 , 1 ], [0 , 0 , 0 ], [1 , 1 , 1 ]]
5666 )
67+ assert (dpt .asnumpy (out ) == dpt .asnumpy (out_expected )).all ()
5768
69+ out = dpt .where (cond , dpt .ones (cond .shape ), dpt .zeros (cond .shape ))
70+ assert (dpt .asnumpy (out ) == dpt .asnumpy (out_expected )).all ()
71+
72+ out = dpt .where (
73+ cond ,
74+ dpt .ones (cond .shape [0 ], dtype = "i4" )[:, dpt .newaxis ],
75+ dpt .zeros (cond .shape [0 ], dtype = "i4" )[:, dpt .newaxis ],
76+ )
5877 assert (dpt .asnumpy (out ) == dpt .asnumpy (out_expected )).all ()
5978
6079
@@ -72,38 +91,98 @@ def _dtype_all_close(x1, x2):
7291
@pytest.mark.parametrize("dt1", _all_dtypes)
@pytest.mark.parametrize("dt2", _all_dtypes)
@pytest.mark.parametrize("fp16", [True, False])
@pytest.mark.parametrize("fp64", [True, False])
def test_where_result_types(dt1, dt2, fp16, fp64):
    """Check ``_where_result_type`` for each dtype pair and fp aspect combo.

    ``dt1`` and ``dt2`` are type strings drawn from ``_all_dtypes``;
    ``fp16`` and ``fp64`` are the aspect flags handed to ``mock_device``,
    which is passed to ``_where_result_type`` as the device argument.
    """
    dev = mock_device(fp16, fp64)

    dt1 = dpt.dtype(dt1)
    dt2 = dpt.dtype(dt2)
    res_t = _where_result_type(dt1, dt2, dev)

    if fp16 and fp64:
        # with both aspects available the result must match promotion
        assert res_t == dpt.result_type(dt1, dt2)
    elif res_t is not None:
        # compare against None explicitly instead of relying on the
        # truthiness of a dtype object
        assert res_t.kind == dpt.result_type(dt1, dt2).kind
    else:
        # some illegal cases are covered above, but
        # this guarantees that _where_result_type
        # produces None only when one of the dtypes
        # is illegal given fp aspects of device
        all_dts = _all_data_types(fp16, fp64)
        assert dt1 not in all_dts or dt2 not in all_dts
115+
116+
@pytest.mark.parametrize("dt", _all_dtypes)
def test_where_all_dtypes(dt):
    """Run ``dpt.where`` with ``dt`` as the mask dtype, then the input dtype."""
    q = get_queue_or_skip()
    skip_if_dtype_not_supported(dt, q)

    # where(mask, 0, 1) gives 1 exactly where the mask is falsy
    pattern = [1, 0, 0, 1, 0]

    # mask dtype changes; inputs are float arrays built from a scalar
    mask = dpt.asarray([0, 1, 3, 0, 10], dtype=dt, sycl_queue=q)
    res = dpt.where(
        mask,
        dpt.asarray(0, dtype="f", sycl_queue=q),
        dpt.asarray(1, dtype="f", sycl_queue=q),
    )
    res_check = np.asarray(pattern, dtype=res.dtype)
    assert _dtype_all_close(dpt.asnumpy(res), res_check)

    # contiguous cases: inputs have the same shape as the mask
    res = dpt.where(
        mask,
        dpt.full(mask.shape, 0, dtype="f4", sycl_queue=q),
        dpt.full(mask.shape, 1, dtype="f4", sycl_queue=q),
    )
    assert _dtype_all_close(dpt.asnumpy(res), res_check)

    # input array dtype changes; the mask is now plain boolean
    mask = dpt.asarray([False, True, True, False, True], sycl_queue=q)
    res = dpt.where(
        mask,
        dpt.asarray(0, dtype=dt, sycl_queue=q),
        dpt.asarray(1, dtype=dt, sycl_queue=q),
    )
    res_check = np.asarray(pattern, dtype=res.dtype)
    assert _dtype_all_close(dpt.asnumpy(res), res_check)

    # contiguous cases: inputs have the same shape as the mask
    res = dpt.where(
        mask,
        dpt.full(mask.shape, 0, dtype=dt, sycl_queue=q),
        dpt.full(mask.shape, 1, dtype=dt, sycl_queue=q),
    )
    assert _dtype_all_close(dpt.asnumpy(res), res_check)
93151
94152
def test_where_nan_inf():
    """Compare ``dpt.where`` with ``np.where`` on data holding NaN and inf."""
    get_queue_or_skip()

    cond = dpt.asarray([True, False, True, False], dtype="?")
    x1 = dpt.asarray([np.nan, 2.0, np.inf, 3.0], dtype="f4")
    x2 = dpt.asarray([2.0, np.nan, 3.0, np.inf], dtype="f4")

    # host copies used to compute the NumPy reference results
    cond_np, x1_np, x2_np = (dpt.asnumpy(arr) for arr in (cond, x1, x2))

    # boolean mask choosing between the two float arrays
    actual = dpt.asnumpy(dpt.where(cond, x1, x2))
    reference = np.where(cond_np, x1_np, x2_np)
    assert np.allclose(actual, reference, equal_nan=True)

    # float array (with NaN and inf entries) used as the mask
    actual = dpt.asnumpy(dpt.where(x1, cond, x2))
    reference = np.where(x1_np, cond_np, x2_np)
    assert _dtype_all_close(actual, reference)
172+
173+
95174def test_where_empty ():
96175 # check that numpy returns same results when
97176 # handling empty arrays
98177 get_queue_or_skip ()
99178
100- empty = dpt .empty (0 )
179+ empty = dpt .empty (0 , dtype = "i2" )
101180 m = dpt .asarray (True )
102- x1 = dpt .asarray (1 )
103- x2 = dpt .asarray (2 )
181+ x1 = dpt .asarray (1 , dtype = "i2" )
182+ x2 = dpt .asarray (2 , dtype = "i2" )
104183 res = dpt .where (empty , x1 , x2 )
105184
106- empty_np = np .empty (0 )
185+ empty_np = np .empty (0 , dtype = "i2" )
107186 m_np = dpt .asnumpy (m )
108187 x1_np = dpt .asnumpy (x1 )
109188 x2_np = dpt .asnumpy (x2 )
@@ -116,12 +195,14 @@ def test_where_empty():
116195
117196 assert_array_equal (dpt .asnumpy (res ), res_np )
118197
198+ # check that broadcasting is performed
199+ with pytest .raises (ValueError ):
200+ dpt .where (empty , x1 , dpt .empty ((1 , 2 )))
201+
119202
120- @pytest .mark .parametrize ("dt" , _all_dtypes )
121203@pytest .mark .parametrize ("order" , ["C" , "F" ])
122- def test_where_contiguous (dt , order ):
123- q = get_queue_or_skip ()
124- skip_if_dtype_not_supported (dt , q )
204+ def test_where_contiguous (order ):
205+ get_queue_or_skip ()
125206
126207 cond = dpt .asarray (
127208 [
@@ -131,14 +212,100 @@ def test_where_contiguous(dt, order):
131212 [[False , False , False ], [True , False , True ]],
132213 [[True , True , True ], [True , False , True ]],
133214 ],
134- sycl_queue = q ,
135215 order = order ,
136216 )
137217
138- x1 = dpt .full (cond .shape , 2 , dtype = dt , order = order , sycl_queue = q )
139- x2 = dpt .full (cond .shape , 3 , dtype = dt , order = order , sycl_queue = q )
218+ x1 = dpt .full (cond .shape , 2 , dtype = "i4" , order = order )
219+ x2 = dpt .full (cond .shape , 3 , dtype = "i4" , order = order )
220+ expected = np .where (dpt .asnumpy (cond ), dpt .asnumpy (x1 ), dpt .asnumpy (x2 ))
221+ res = dpt .where (cond , x1 , x2 )
222+
223+ assert _dtype_all_close (dpt .asnumpy (res ), expected )
224+
225+
def test_where_contiguous1D():
    """Check ``dpt.where`` on 1-D inputs with integer, then complex, data."""
    get_queue_or_skip()

    cond = dpt.asarray([True, False, True, False, False, True])
    cond_np = dpt.asnumpy(cond)

    x1 = dpt.full(cond.shape, 2, dtype="i4")
    x2 = dpt.full(cond.shape, 3, dtype="i4")
    expected = np.where(cond_np, dpt.asnumpy(x1), dpt.asnumpy(x2))
    assert_array_equal(dpt.asnumpy(dpt.where(cond, x1, x2)), expected)

    # test with complex dtype (branch in kernel)
    x1_c = dpt.astype(x1, dpt.complex64)
    x2_c = dpt.astype(x2, dpt.complex64)
    expected = np.where(cond_np, dpt.asnumpy(x1_c), dpt.asnumpy(x2_c))
    assert _dtype_all_close(dpt.asnumpy(dpt.where(cond, x1_c, x2_c)), expected)
243+
244+
def test_where_strided():
    """Check ``dpt.where`` on strided slices and on flipped operands."""
    get_queue_or_skip()

    s0, s1 = 4, 9
    row = [True, False, False, False, True, True, False, True, False]
    # keep every third column of the (s0, s1) mask
    cond = dpt.reshape(dpt.asarray(row * s0), (s0, s1))[:, ::3]
    n_rows, n_cols = cond.shape[0], cond.shape[1]

    # inputs are every 2nd / every 3rd column of wider arange-filled arrays
    x1 = dpt.reshape(
        dpt.arange(n_rows * n_cols * 2, dtype="i4"), (n_rows, n_cols * 2)
    )[:, ::2]
    x2 = dpt.reshape(
        dpt.arange(n_rows * n_cols * 3, dtype="i4"), (n_rows, n_cols * 3)
    )[:, ::3]

    cond_np = dpt.asnumpy(cond)
    x1_np = dpt.asnumpy(x1)
    x2_np = dpt.asnumpy(x2)

    expected = np.where(cond_np, x1_np, x2_np)
    res = dpt.where(cond, x1, x2)
    assert_array_equal(dpt.asnumpy(res), expected)

    # negative strides
    res = dpt.where(cond, dpt.flip(x1), x2)
    expected = np.where(cond_np, np.flip(x1_np), x2_np)
    assert_array_equal(dpt.asnumpy(res), expected)

    res = dpt.where(dpt.flip(cond), x1, x2)
    expected = np.where(np.flip(cond_np), x1_np, x2_np)
    assert_array_equal(dpt.asnumpy(res), expected)
281+
282+
def test_where_arg_validation():
    """Expect ``TypeError`` when any ``dpt.where`` argument is a dict."""
    get_queue_or_skip()

    bad = {}
    x1 = dpt.empty((1,), dtype="i4")
    x2 = dpt.empty((1,), dtype="i4")

    # put the invalid object in each of the three positions in turn
    for args in ((bad, x1, x2), (x1, bad, x2), (x1, x2, bad)):
        with pytest.raises(TypeError):
            dpt.where(*args)
296+
297+
def test_where_compute_follows_data():
    """Expect ``ExecutionPlacementError`` when operands use different queues."""
    q1 = get_queue_or_skip()
    q2 = get_queue_or_skip()
    q3 = get_queue_or_skip()

    x1 = dpt.empty((1,), dtype="i4", sycl_queue=q1)
    x2 = dpt.empty((1,), dtype="i4", sycl_queue=q2)

    # condition allocated on q1 (same as x1), then on a third queue
    for cond_queue in (q1, q3):
        with pytest.raises(ExecutionPlacementError):
            dpt.where(
                dpt.empty((1,), dtype="i4", sycl_queue=cond_queue), x1, x2
            )

    # x1 reused as the condition; x2 still lives on q2
    with pytest.raises(ExecutionPlacementError):
        dpt.where(x1, x1, x2)
0 commit comments