@@ -1560,17 +1560,31 @@ def test_take_along_axis():
15601560
15611561
15621562def test_take_along_axis_validation ():
1563+ # type check on the first argument
15631564 with pytest .raises (TypeError ):
15641565 dpt .take_along_axis (tuple (), list ())
15651566 get_queue_or_skip ()
1566- x = dpt .ones (10 )
1567+ n1 , n2 = 2 , 5
1568+ x = dpt .ones (n1 * n2 )
1569+ # type check on the second argument
15671570 with pytest .raises (TypeError ):
15681571 dpt .take_along_axis (x , list ())
1569- ind_dt = dpt .__array_namespace_info__ ().default_dtypes (
1570- device = x .sycl_device
1571- )["indexing" ]
1572+ x_dev = x .sycl_device
1573+ info_ = dpt .__array_namespace_info__ ()
1574+ def_dtypes = info_ .default_dtypes (device = x_dev )
1575+ ind_dt = def_dtypes ["indexing" ]
15721576 ind = dpt .zeros (1 , dtype = ind_dt )
1577+ # axis valudation
15731578 with pytest .raises (ValueError ):
15741579 dpt .take_along_axis (x , ind , axis = 1 )
1580+ # mode validation
15751581 with pytest .raises (ValueError ):
15761582 dpt .take_along_axis (x , ind , axis = 0 , mode = "invalid" )
1583+ # same array-ranks validation
1584+ with pytest .raises (ValueError ):
1585+ dpt .take_along_axis (dpt .reshape (x , (n1 , n2 )), ind )
1586+ # check compute-follows-data
1587+ q2 = dpctl .SyclQueue (x_dev , property = "enable_profiling" )
1588+ ind2 = dpt .zeros (1 , dtype = ind_dt , sycl_queue = q2 )
1589+ with pytest .raises (ExecutionPlacementError ):
1590+ dpt .take_along_axis (x , ind2 )
0 commit comments