@@ -1592,3 +1592,94 @@ def test_take_along_axis_validation():
15921592 ind2 = dpt .zeros (1 , dtype = ind_dt , sycl_queue = q2 )
15931593 with pytest .raises (ExecutionPlacementError ):
15941594 dpt .take_along_axis (x , ind2 )
1595+
1596+
1597+ def check__extract_impl_validation (fn ):
1598+ x = dpt .ones (10 )
1599+ ind = dpt .ones (10 , dtype = "?" )
1600+ with pytest .raises (TypeError ):
1601+ fn (list (), ind )
1602+ with pytest .raises (TypeError ):
1603+ fn (x , list ())
1604+ q2 = dpctl .SyclQueue (x .sycl_device , property = "enable_profiling" )
1605+ ind2 = dpt .ones (10 , dtype = "?" , sycl_queue = q2 )
1606+ with pytest .raises (ExecutionPlacementError ):
1607+ fn (x , ind2 )
1608+ with pytest .raises (ValueError ):
1609+ fn (x , ind , 1 )
1610+
1611+
1612+ def check__nonzero_impl_validation (fn ):
1613+ with pytest .raises (TypeError ):
1614+ fn (list ())
1615+
1616+
1617+ def check__take_multi_index (fn ):
1618+ x = dpt .ones (10 )
1619+ x_dev = x .sycl_device
1620+ info_ = dpt .__array_namespace_info__ ()
1621+ def_dtypes = info_ .default_dtypes (device = x_dev )
1622+ ind_dt = def_dtypes ["indexing" ]
1623+ ind = dpt .arange (10 , dtype = ind_dt )
1624+ with pytest .raises (TypeError ):
1625+ fn (list (), tuple (), 1 )
1626+ with pytest .raises (ValueError ):
1627+ fn (x , (ind ,), 0 , mode = 2 )
1628+ with pytest .raises (ValueError ):
1629+ fn (x , (None ,), 1 )
1630+ with pytest .raises (IndexError ):
1631+ fn (x , (x ,), 1 )
1632+ q2 = dpctl .SyclQueue (x .sycl_device , property = "enable_profiling" )
1633+ ind2 = dpt .arange (10 , dtype = ind_dt , sycl_queue = q2 )
1634+ with pytest .raises (ExecutionPlacementError ):
1635+ fn (x , (ind2 ,), 0 )
1636+ m = dpt .ones ((10 , 10 ))
1637+ ind_1 = dpt .arange (10 , dtype = "i8" )
1638+ ind_2 = dpt .arange (10 , dtype = "u8" )
1639+ with pytest .raises (ValueError ):
1640+ fn (m , (ind_1 , ind_2 ), 0 )
1641+
1642+
1643+ def check__place_impl_validation (fn ):
1644+ with pytest .raises (TypeError ):
1645+ fn (list (), list (), list ())
1646+ x = dpt .ones (10 )
1647+ with pytest .raises (TypeError ):
1648+ fn (x , list (), list ())
1649+ q2 = dpctl .SyclQueue (x .sycl_device , property = "enable_profiling" )
1650+ mask2 = dpt .ones (10 , dtype = "?" , sycl_queue = q2 )
1651+ with pytest .raises (ExecutionPlacementError ):
1652+ fn (x , mask2 , 1 )
1653+ x2 = dpt .ones ((5 , 5 ))
1654+ mask2 = dpt .ones ((5 , 5 ), dtype = "?" )
1655+ with pytest .raises (ValueError ):
1656+ fn (x2 , mask2 , x2 , axis = 1 )
1657+
1658+
1659+ def check__put_multi_index_validation (fn ):
1660+ with pytest .raises (TypeError ):
1661+ fn (list (), list (), 0 , list ())
1662+ x = dpt .ones (10 )
1663+ inds = dpt .arange (10 , dtype = "i8" )
1664+ vals = dpt .zeros (10 )
1665+ # test inds which is not a tuple/list
1666+ fn (x , inds , 0 , vals )
1667+ x2 = dpt .ones ((5 , 5 ))
1668+ ind1 = dpt .arange (5 , dtype = "i8" )
1669+ ind2 = dpt .arange (5 , dtype = "u8" )
1670+ with pytest .raises (ValueError ):
1671+ fn (x2 , (ind1 , ind2 ), 0 , x2 )
1672+ with pytest .raises (TypeError ):
1673+ fn (x2 , (ind1 , list ()), 0 , x2 )
1674+
1675+
1676+ def test__copy_utils ():
1677+ import dpctl .tensor ._copy_utils as cu
1678+
1679+ get_queue_or_skip ()
1680+
1681+ check__extract_impl_validation (cu ._extract_impl )
1682+ check__nonzero_impl_validation (cu ._nonzero_impl )
1683+ check__take_multi_index (cu ._take_multi_index )
1684+ check__place_impl_validation (cu ._place_impl )
1685+ check__put_multi_index_validation (cu ._put_multi_index )
0 commit comments