@@ -286,18 +286,8 @@ def test_properties(dt):
286286@pytest .mark .parametrize ("shape" , [tuple (), (1 ,), (1 , 1 ), (1 , 1 , 1 )])
287287@pytest .mark .parametrize ("dtype" , ["|b1" , "|u2" , "|f4" , "|i8" ])
288288class TestCopyScalar :
289- def test_copy_bool_scalar_with_func (self , shape , dtype ):
290- try :
291- X = dpt .usm_ndarray (shape , dtype = dtype )
292- except dpctl .SyclDeviceCreationError :
293- pytest .skip ("No SYCL devices available" )
294- Y = np .arange (1 , X .size + 1 , dtype = dtype )
295- X .usm_data .copy_from_host (Y .view ("|u1" ))
296- Y .shape = tuple ()
297- assert bool (X ) == bool (Y )
298-
299- @pytest .mark .parametrize ("func" , [float , int , complex ])
300- def test_copy_numeric_scalar_with_func (self , func , shape , dtype ):
289+ @pytest .mark .parametrize ("func" , [bool , float , int , complex ])
290+ def test_copy_scalar_with_func (self , func , shape , dtype ):
301291 try :
302292 X = dpt .usm_ndarray (shape , dtype = dtype )
303293 except dpctl .SyclDeviceCreationError :
@@ -312,18 +302,10 @@ def test_copy_numeric_scalar_with_func(self, func, shape, dtype):
312302 # 0D arrays are allowed to convert
313303 assert func (X ) == func (Y )
314304
315- def test_copy_bool_scalar_with_method (self , shape , dtype ):
316- try :
317- X = dpt .usm_ndarray (shape , dtype = dtype )
318- except dpctl .SyclDeviceCreationError :
319- pytest .skip ("No SYCL devices available" )
320- Y = np .arange (1 , X .size + 1 , dtype = dtype )
321- X .usm_data .copy_from_host (Y .view ("|u1" ))
322- Y = Y .reshape (())
323- assert getattr (X , "__bool__" )() == getattr (Y , "__bool__" )()
324-
325- @pytest .mark .parametrize ("method" , ["__float__" , "__int__" , "__complex__" ])
326- def test_copy_numeric_scalar_with_method (self , method , shape , dtype ):
305+ @pytest .mark .parametrize (
306+ "method" , ["__bool__" , "__float__" , "__int__" , "__complex__" ]
307+ )
308+ def test_copy_scalar_with_method (self , method , shape , dtype ):
327309 try :
328310 X = dpt .usm_ndarray (shape , dtype = dtype )
329311 except dpctl .SyclDeviceCreationError :
0 commit comments