@@ -22,7 +22,7 @@ def test_array_namespace(library, api_version, use_compat):
2222 if use_compat and library not in wrapped_libraries :
2323 pytest .raises (ValueError , lambda : array_namespace (array , use_compat = use_compat ))
2424 return
25- namespace = array_api_compat . array_namespace (array , api_version = api_version , use_compat = use_compat )
25+ namespace = array_namespace (array , api_version = api_version , use_compat = use_compat )
2626
2727 if use_compat is False or use_compat is None and library not in wrapped_libraries :
2828 if library == "jax.numpy" and use_compat is None :
@@ -44,7 +44,7 @@ def test_array_namespace(library, api_version, use_compat):
4444
4545 if library == "numpy" :
4646 # check that the same namespace is returned for NumPy scalars
47- scalar_namespace = array_api_compat . array_namespace (
47+ scalar_namespace = array_namespace (
4848 xp .float64 (0.0 ), api_version = api_version , use_compat = use_compat
4949 )
5050 assert scalar_namespace == namespace
@@ -75,8 +75,7 @@ def test_array_namespace(library, api_version, use_compat):
7575def test_jax_zero_gradient ():
7676 jx = jax .numpy .arange (4 )
7777 jax_zero = jax .vmap (jax .grad (jax .numpy .float32 , allow_int = True ))(jx )
78- assert (array_api_compat .get_namespace (jax_zero ) is
79- array_api_compat .get_namespace (jx ))
78+ assert array_namespace (jax_zero ) is array_namespace (jx )
8079
8180def test_array_namespace_errors ():
8281 pytest .raises (TypeError , lambda : array_namespace ([1 ]))
@@ -91,7 +90,7 @@ def test_array_namespace_errors_torch():
9190 x = np .asarray ([1 , 2 ])
9291 pytest .raises (TypeError , lambda : array_namespace (x , y ))
9392
94- def test_api_version ():
93+ def test_api_version_torch ():
9594 x = torch .asarray ([1 , 2 ])
9695 torch_ = import_ ("torch" , wrapper = True )
9796 assert array_namespace (x , api_version = "2023.12" ) == torch_
@@ -113,7 +112,7 @@ def test_api_version():
113112
114113def test_get_namespace ():
115114 # Backwards compatible wrapper
116- assert array_api_compat .get_namespace is array_api_compat . array_namespace
115+ assert array_api_compat .get_namespace is array_namespace
117116
118117def test_python_scalars ():
119118 a = torch .asarray ([1 , 2 ])
0 commit comments