Skip to content

Commit f6c1870

Browse files
lukebaumanncopybara-github
authored andcommitted
Use Pathways _transfer_to_sharding for resharding in experimental reshard API.
The experimental `reshard` function now leverages `pw_jax.jaxlib_pathways._transfer_to_shardings` for array resharding (if available), removing the custom `ReshardingPlanWrapper` and its dependency on `plugin_executable`. This simplifies the implementation and aligns with internal Pathways mechanisms. The `cache_resharding_plans` option has also been removed. This requires the latest JAX version because the ifrt_proxy API for ReshardArrays was recently added. To ensure backwards compatibility, if `_transfer_to_sharding` is not supported, the sidechannel API will be used. PiperOrigin-RevId: 852643657
1 parent b5279c6 commit f6c1870

2 files changed

Lines changed: 125 additions & 37 deletions

File tree

pathwaysutils/experimental/reshard.py

Lines changed: 86 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
import base64
1717
import collections
1818
import json
19-
from typing import Any, Dict, Sequence
19+
from typing import Any, Callable, Dict, Mapping, Sequence
2020

2121
import jax
22+
from pathwaysutils import jax as pw_jax
2223
from pathwaysutils import lru_cache
2324
from pathwaysutils import plugin_executable
2425

@@ -99,37 +100,16 @@ def _get_resharding_plan(
99100
_get_resharding_plan_cached = lru_cache.lru_cache()(_get_resharding_plan)
100101

101102

102-
def reshard(
103+
def _reshard(
103104
x: Any,
104105
sharding: jax.sharding.Sharding | Any,
105106
*,
106107
donate: bool = False,
107-
may_alias: bool | None = None, # pylint: disable=unused-argument
108-
cache_resharding_plans: bool = False,
108+
may_alias: bool | None,
109+
jax_array_reshard_fn: Callable[..., Any],
110+
**kwargs,
109111
) -> Any:
110-
"""Reshards `x` to `sharding`.
111-
112-
Args:
113-
x: An array, scalar, or (nested) standard Python container thereof.
114-
sharding: A `Sharding` or a (nested) `Sharding` in standard Python container
115-
(must be a tree prefix of `x`), representing the device(s) and sharding to
116-
which `x` should be sharded to. The result will be committed to the
117-
device(s) of the sharding.
118-
donate: If `True`, donate all input arrays, which may reduce the amount of
119-
memory needed for resharding. Buffers donated to resharding should not be
120-
reused.
121-
may_alias: If `True`, may alias the input array with the output array. May
122-
reduce the amount of memory needed for resharding. Not used at the moment.
123-
cache_resharding_plans: If `True`, uses a resharding plan cache to avoid
124-
recreating plans for the same resharding operation. May improve
125-
performance for use cases where the same resharding operation is done many
126-
times. May degrade performance if most reshardings operations are
127-
different, since the cache will cause Pathways Components to remain loaded
128-
for each cached plan. `False` by default.
129-
130-
Returns:
131-
A copy of `x` whose sharding is `sharding`.
132-
"""
112+
"""Reshards `x` to `sharding`."""
133113
flat_x, tree_def = jax.tree.flatten(x)
134114
flat_sharding = jax.api_util.flatten_axes(
135115
"reshard sharding", tree_def, sharding
@@ -176,17 +156,9 @@ def reshard(
176156
)
177157

178158
for array_info in jax_arrays.values():
179-
get_resharding_plan_func = (
180-
_get_resharding_plan_cached
181-
if cache_resharding_plans
182-
else _get_resharding_plan
159+
array_info["arrays"] = jax_array_reshard_fn(
160+
array_info, donate=donate, **kwargs
183161
)
184-
array_info["arrays"] = get_resharding_plan_func(
185-
tuple(arr.aval for arr in array_info["arrays"]),
186-
tuple(arr.sharding for arr in array_info["arrays"]),
187-
tuple(array_info["dst_shardings"]),
188-
donate,
189-
).execute(tuple(array_info["arrays"]))
190162

191163
result = [None] * len(flat_x)
192164
for arr, idx in zip(
@@ -198,3 +170,80 @@ def reshard(
198170
result[idx] = arr
199171

200172
return jax.tree.unflatten(tree_def, result)
173+
174+
175+
def _sidechannel_jax_array_reshard(
176+
array_info: Mapping[str, Any], *, donate: bool, cache_resharding_plans: bool
177+
) -> Sequence[jax.Array]:
178+
get_resharding_plan_func = (
179+
_get_resharding_plan_cached
180+
if cache_resharding_plans
181+
else _get_resharding_plan
182+
)
183+
return get_resharding_plan_func(
184+
tuple(arr.aval for arr in array_info["arrays"]),
185+
tuple(arr.sharding for arr in array_info["arrays"]),
186+
tuple(array_info["dst_shardings"]),
187+
donate,
188+
).execute(tuple(array_info["arrays"]))
189+
190+
191+
def _ifrt_jax_array_reshard(
192+
array_info: Mapping[str, Any], *, donate: bool
193+
) -> Sequence[jax.Array]:
194+
return pw_jax.transfer_to_shardings(
195+
tuple(arr for arr in array_info["arrays"]),
196+
tuple(array_info["dst_shardings"]),
197+
donate,
198+
)
199+
200+
201+
def reshard(
202+
x: Any,
203+
sharding: jax.sharding.Sharding | Any,
204+
*,
205+
donate: bool = False,
206+
may_alias: bool | None = None,
207+
cache_resharding_plans: bool = False,
208+
) -> Any:
209+
"""Reshards `x` to `sharding`.
210+
211+
Args:
212+
x: An array, scalar, or (nested) standard Python container thereof.
213+
sharding: A `Sharding` or a (nested) `Sharding` in standard Python container
214+
(must be a tree prefix of `x`), representing the device(s) and sharding to
215+
which `x` should be sharded to. The result will be committed to the
216+
device(s) of the sharding.
217+
donate: If `True`, donate all input arrays, which may reduce the amount of
218+
memory needed for resharding. Buffers donated to resharding should not be
219+
reused.
220+
may_alias: If `True`, may alias the input array with the output array. May
221+
reduce the amount of memory needed for resharding. Not used at the moment.
222+
cache_resharding_plans: If `True`, uses a resharding plan cache to avoid
223+
recreating plans for the same resharding operation. May improve
224+
performance for use cases where the same resharding operation is done many
225+
times. May degrade performance if most reshardings operations are
226+
different, since the cache will cause Pathways Components to remain loaded
227+
for each cached plan. `False` by default. Only used when IFRT resharding
228+
is not available.
229+
230+
Returns:
231+
A copy of `x` whose sharding is `sharding`.
232+
"""
233+
if pw_jax.ifrt_reshard_available():
234+
return _reshard(
235+
x,
236+
sharding,
237+
donate=donate,
238+
may_alias=may_alias,
239+
jax_array_reshard_fn=_ifrt_jax_array_reshard,
240+
)
241+
else:
242+
return _reshard(
243+
x,
244+
sharding,
245+
donate=donate,
246+
may_alias=may_alias,
247+
jax_array_reshard_fn=_sidechannel_jax_array_reshard,
248+
cache_resharding_plans=cache_resharding_plans,
249+
)

pathwaysutils/jax/__init__.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
`pathwaysutils`'s compatibility window.
1818
"""
1919

20+
import functools
2021
from typing import Any
22+
2123
import jax
2224

2325

@@ -91,6 +93,43 @@ def register_backend_cache(cache: Any, name: str, util=util): # pylint: disable
9193
)
9294

9395

96+
try:
97+
# jax>=0.8.3
98+
# The import may fail if the JAX version is not new enough.
99+
from jaxlib import _pathways as jaxlib_pathways # pylint: disable=g-import-not-at-top
100+
101+
transfer_to_shardings = jaxlib_pathways._transfer_to_shardings
102+
103+
del jaxlib_pathways
104+
105+
except ImportError:
106+
# jax<0.8.3
107+
transfer_to_shardings = _FakeJaxFunction(
108+
"jax.jaxlib._pathways._transfer_to_shardings",
109+
"0.8.3",
110+
)
111+
112+
113+
@functools.lru_cache(maxsize=1)
114+
def ifrt_reshard_available() -> bool:
115+
"""Checks if transfer_to_shardings is available."""
116+
try:
117+
import jax # pylint: disable=g-import-not-at-top
118+
119+
transfer_to_shardings(
120+
[jax.numpy.array([0])],
121+
[jax.sharding.SingleDeviceSharding(jax.devices()[0])],
122+
)
123+
124+
except (ImportError, NameError, jax.errors.JaxRuntimeError):
125+
return False
126+
else:
127+
return True
128+
finally:
129+
del jax
130+
131+
94132
del jax
95133
del Any
96134
del _FakeJaxFunction
135+
del functools

0 commit comments

Comments
 (0)