Skip to content

Commit b5279c6

Browse files
lukebaumanncopybara-github
authored andcommitted
Expose _split_by_mesh_axis directly in pw_jax.
This change exposes the `_split_by_mesh_axis` function as `pw_jax.split_by_mesh_axis` instead of exposing the entire `jaxlib._pathways` module. This provides a more focused API. PiperOrigin-RevId: 852364503
1 parent 717fd15 commit b5279c6

2 files changed

Lines changed: 17 additions & 11 deletions

File tree

pathwaysutils/experimental/split_by_mesh_axis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def split_by_mesh_axis(
161161
for x in flat_arrays
162162
]
163163

164-
flat_split_arrays = pw_jax.jaxlib_pathways._split_by_mesh_axis( # pylint: disable=protected-access
164+
flat_split_arrays = pw_jax.split_by_mesh_axis(
165165
arrays=flat_arrays,
166166
sharded_dim_idxs=sharded_dim_idxs,
167167
mesh_axis_sizes=mesh.axis_sizes,

pathwaysutils/jax/__init__.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,21 +21,21 @@
2121
import jax
2222

2323

24-
class _FakeJaxModule:
25-
"""A fake module that raises an ImportError when accessed.
24+
class _FakeJaxFunction:
25+
"""An object that raises an ImportError for __getattr__ and __call__.
2626
27-
This is used to provide a placeholder for JAX modules that are not available
28-
in older versions of JAX, raising a helpful error message if they are
29-
inadvertently used.
27+
This is used to provide a placeholder for JAX functions that are not
28+
available in older versions of JAX, raising a helpful error message if they
29+
are inadvertently used.
3030
"""
3131

3232
def __init__(self, name, version):
3333
self.__name__ = name
3434
self.version = version
3535
self.error_message = (
36-
f"Module {self.__name__} does not exist until JAX {self.version}. "
36+
f"Function {self.__name__} does not exist until JAX {self.version}. "
3737
f"The current version of JAX is {jax.__version__}. "
38-
"Using this modules results in this runtime error."
38+
"Using this function results in this runtime error."
3939
)
4040

4141
def __getattr__(self, name):
@@ -77,14 +77,20 @@ def register_backend_cache(cache: Any, name: str, util=util): # pylint: disable
7777

7878
try:
7979
# jax>=0.8.0
80-
from jaxlib import _pathways as jaxlib_pathways # pylint: disable=g-import-not-at-top
80+
from jaxlib import _pathways # pylint: disable=g-import-not-at-top
81+
82+
split_by_mesh_axis = _pathways._split_by_mesh_axis
83+
del _pathways
8184

8285
except ImportError:
8386
# jax<0.8.0
8487

85-
jaxlib_pathways = _FakeJaxModule("jax.jaxlib._pathways", "0.8.0")
88+
split_by_mesh_axis = _FakeJaxFunction(
89+
"jax.jaxlib._pathways._split_by_mesh_axis",
90+
"0.8.0",
91+
)
8692

8793

8894
del jax
8995
del Any
90-
del _FakeJaxModule
96+
del _FakeJaxFunction

0 commit comments

Comments
 (0)