@@ -284,12 +284,6 @@ def test_fft(func_name):
284284 'unstack' : lambda : xp .unstack (xp .ones ((3 , 3 )), axis = 0 ),
285285}
286286
287- api_version_2024_12_examples = {
288- 'diff' : lambda : xp .diff (xp .asarray ([0 , 1 , 2 ])),
289- 'nextafter' : lambda : xp .nextafter (xp .asarray (0. ), xp .asarray (1. )),
290- 'reciprocal' : lambda : xp .reciprocal (xp .asarray ([2. ])),
291- }
292-
293287@pytest .mark .parametrize ('func_name' , api_version_2023_12_examples .keys ())
294288def test_api_version_2023_12 (func_name ):
295289 func = api_version_2023_12_examples [func_name ]
@@ -308,6 +302,14 @@ def test_api_version_2023_12(func_name):
308302 set_array_api_strict_flags (api_version = '2022.12' )
309303 pytest .raises (RuntimeError , func )
310304
305+ api_version_2024_12_examples = {
306+ 'diff' : lambda : xp .diff (xp .asarray ([0 , 1 , 2 ])),
307+ 'nextafter' : lambda : xp .nextafter (xp .asarray (0. ), xp .asarray (1. )),
308+ 'reciprocal' : lambda : xp .reciprocal (xp .asarray ([2. ])),
309+ 'take_along_axis' : lambda : xp .take_along_axis (xp .zeros ((2 , 3 )),
310+ xp .zeros ((1 , 4 ), dtype = xp .int64 )),
311+ }
312+
311313@pytest .mark .parametrize ('func_name' , api_version_2024_12_examples .keys ())
312314def test_api_version_2024_12 (func_name ):
313315 func = api_version_2024_12_examples [func_name ]
0 commit comments