@@ -1520,28 +1520,56 @@ def test_clip(device):
15201520 assert_sycl_queue_equal (x .sycl_queue , y .sycl_queue )
15211521
15221522
1523- @pytest .mark .parametrize ("func" , ["take" , "take_along_axis" ])
15241523@pytest .mark .parametrize (
15251524 "device" ,
15261525 valid_devices ,
15271526 ids = [device .filter_string for device in valid_devices ],
15281527)
1529- def test_take (func , device ):
1528+ def test_take (device ):
15301529 numpy_data = numpy .arange (5 )
15311530 dpnp_data = dpnp .array (numpy_data , device = device )
15321531
15331532 dpnp_ind = dpnp .array ([0 , 2 , 4 ], device = device )
15341533 np_ind = dpnp_ind .asnumpy ()
15351534
1536- result = getattr ( dpnp , func ) (dpnp_data , dpnp_ind , axis = None )
1537- expected = getattr ( numpy , func ) (numpy_data , np_ind , axis = None )
1535+ result = dpnp . take (dpnp_data , dpnp_ind , axis = None )
1536+ expected = numpy . take (numpy_data , np_ind , axis = None )
15381537 assert_allclose (expected , result )
15391538
15401539 expected_queue = dpnp_data .get_array ().sycl_queue
15411540 result_queue = result .get_array ().sycl_queue
15421541 assert_sycl_queue_equal (result_queue , expected_queue )
15431542
15441543
1544+ @pytest .mark .parametrize (
1545+ "data, ind, axis" ,
1546+ [
1547+ (numpy .arange (6 ), numpy .array ([0 , 2 , 4 ]), None ),
1548+ (
1549+ numpy .arange (6 ).reshape ((2 , 3 )),
1550+ numpy .array ([0 , 1 ]).reshape ((2 , 1 )),
1551+ 1 ,
1552+ ),
1553+ ],
1554+ )
1555+ @pytest .mark .parametrize (
1556+ "device" ,
1557+ valid_devices ,
1558+ ids = [device .filter_string for device in valid_devices ],
1559+ )
1560+ def test_take_along_axis (data , ind , axis , device ):
1561+ dp_data = dpnp .array (data , device = device )
1562+ dp_ind = dpnp .array (ind , device = device )
1563+
1564+ result = dpnp .take_along_axis (dp_data , dp_ind , axis = axis )
1565+ expected = numpy .take_along_axis (data , ind , axis = axis )
1566+ assert_allclose (expected , result )
1567+
1568+ expected_queue = dp_data .get_array ().sycl_queue
1569+ result_queue = result .get_array ().sycl_queue
1570+ assert_sycl_queue_equal (result_queue , expected_queue )
1571+
1572+
15451573@pytest .mark .parametrize (
15461574 "device" ,
15471575 valid_devices ,
0 commit comments