1616import base64
1717import collections
1818import json
19- from typing import Any , Dict , Sequence
19+ from typing import Any , Callable , Dict , Mapping , Sequence
2020
2121import jax
22+ from pathwaysutils import jax as pw_jax
2223from pathwaysutils import lru_cache
2324from pathwaysutils import plugin_executable
2425
@@ -99,37 +100,16 @@ def _get_resharding_plan(
99100_get_resharding_plan_cached = lru_cache .lru_cache ()(_get_resharding_plan )
100101
101102
102- def reshard (
103+ def _reshard (
103104 x : Any ,
104105 sharding : jax .sharding .Sharding | Any ,
105106 * ,
106107 donate : bool = False ,
107- may_alias : bool | None = None , # pylint: disable=unused-argument
108- cache_resharding_plans : bool = False ,
108+ may_alias : bool | None ,
109+ jax_array_reshard_fn : Callable [..., Any ],
110+ ** kwargs ,
109111) -> Any :
110- """Reshards `x` to `sharding`.
111-
112- Args:
113- x: An array, scalar, or (nested) standard Python container thereof.
114- sharding: A `Sharding` or a (nested) `Sharding` in standard Python container
115- (must be a tree prefix of `x`), representing the device(s) and sharding to
116- which `x` should be sharded to. The result will be committed to the
117- device(s) of the sharding.
118- donate: If `True`, donate all input arrays, which may reduce the amount of
119- memory needed for resharding. Buffers donated to resharding should not be
120- reused.
121- may_alias: If `True`, may alias the input array with the output array. May
122- reduce the amount of memory needed for resharding. Not used at the moment.
123- cache_resharding_plans: If `True`, uses a resharding plan cache to avoid
124- recreating plans for the same resharding operation. May improve
125- performance for use cases where the same resharding operation is done many
126- times. May degrade performance if most reshardings operations are
127- different, since the cache will cause Pathways Components to remain loaded
128- for each cached plan. `False` by default.
129-
130- Returns:
131- A copy of `x` whose sharding is `sharding`.
132- """
112+ """Reshards `x` to `sharding`."""
133113 flat_x , tree_def = jax .tree .flatten (x )
134114 flat_sharding = jax .api_util .flatten_axes (
135115 "reshard sharding" , tree_def , sharding
@@ -176,17 +156,9 @@ def reshard(
176156 )
177157
178158 for array_info in jax_arrays .values ():
179- get_resharding_plan_func = (
180- _get_resharding_plan_cached
181- if cache_resharding_plans
182- else _get_resharding_plan
159+ array_info ["arrays" ] = jax_array_reshard_fn (
160+ array_info , donate = donate , ** kwargs
183161 )
184- array_info ["arrays" ] = get_resharding_plan_func (
185- tuple (arr .aval for arr in array_info ["arrays" ]),
186- tuple (arr .sharding for arr in array_info ["arrays" ]),
187- tuple (array_info ["dst_shardings" ]),
188- donate ,
189- ).execute (tuple (array_info ["arrays" ]))
190162
191163 result = [None ] * len (flat_x )
192164 for arr , idx in zip (
@@ -198,3 +170,80 @@ def reshard(
198170 result [idx ] = arr
199171
200172 return jax .tree .unflatten (tree_def , result )
173+
174+
def _sidechannel_jax_array_reshard(
    array_info: Mapping[str, Any], *, donate: bool, cache_resharding_plans: bool
) -> Sequence[jax.Array]:
  """Reshards one group of arrays by building and executing a resharding plan.

  Args:
    array_info: Mapping read for two entries: `"arrays"` (the source arrays)
      and `"dst_shardings"` (the target shardings handed to the plan builder).
    donate: Forwarded unchanged as the last positional argument of the plan
      builder.
    cache_resharding_plans: If true, the plan is obtained through
      `_get_resharding_plan_cached`; otherwise through `_get_resharding_plan`.

  Returns:
    The result of calling `execute` on the plan with the source arrays.
  """
  if cache_resharding_plans:
    build_plan = _get_resharding_plan_cached
  else:
    build_plan = _get_resharding_plan

  src_arrays = array_info["arrays"]
  # The plan is keyed on abstract values and shardings rather than on the
  # arrays themselves; the concrete arrays are only supplied at execute time.
  plan = build_plan(
      tuple(a.aval for a in src_arrays),
      tuple(a.sharding for a in src_arrays),
      tuple(array_info["dst_shardings"]),
      donate,
  )
  return plan.execute(tuple(src_arrays))
189+
190+
def _ifrt_jax_array_reshard(
    array_info: Mapping[str, Any], *, donate: bool
) -> Sequence[jax.Array]:
  """Reshards one group of arrays via `pw_jax.transfer_to_shardings`.

  Args:
    array_info: Mapping read for two entries: `"arrays"` (the source arrays)
      and `"dst_shardings"` (the target shardings).
    donate: Forwarded unchanged as the third positional argument of
      `pw_jax.transfer_to_shardings`.

  Returns:
    Whatever `pw_jax.transfer_to_shardings` returns for the given arrays and
    shardings.
  """
  src_arrays = tuple(array_info["arrays"])
  dst_shardings = tuple(array_info["dst_shardings"])
  return pw_jax.transfer_to_shardings(src_arrays, dst_shardings, donate)
199+
200+
def reshard(
    x: Any,
    sharding: jax.sharding.Sharding | Any,
    *,
    donate: bool = False,
    may_alias: bool | None = None,
    cache_resharding_plans: bool = False,
) -> Any:
  """Returns a copy of `x` resharded to `sharding`.

  The work is delegated to `_reshard`. The per-array backend is chosen at call
  time: the IFRT-based one when `pw_jax.ifrt_reshard_available()` is true,
  otherwise the sidechannel (resharding plan) one.

  Args:
    x: An array, scalar, or (nested) standard Python container thereof.
    sharding: A `Sharding`, or a (nested) standard Python container of
      `Sharding`s that is a tree prefix of `x`, describing the device(s) and
      sharding that `x` should be moved to. The result will be committed to the
      device(s) of the sharding.
    donate: If `True`, all input arrays are donated, which may lower the memory
      needed for resharding. Donated buffers should not be reused afterwards.
    may_alias: If `True`, the output array may alias the input array, which may
      lower the memory needed for resharding. Not used at the moment.
    cache_resharding_plans: If `True`, a resharding plan cache is used so that
      plans are not recreated for a repeated resharding operation. This may
      help when the same resharding is performed many times, and may hurt when
      most reshardings differ, since the cache keeps Pathways Components loaded
      for each cached plan. `False` by default. Only used when IFRT resharding
      is not available.

  Returns:
    A copy of `x` whose sharding is `sharding`.
  """
  if pw_jax.ifrt_reshard_available():
    backend_fn = _ifrt_jax_array_reshard
    # The IFRT backend accepts no plan-cache option, so nothing extra is sent.
    backend_kwargs = {}
  else:
    backend_fn = _sidechannel_jax_array_reshard
    backend_kwargs = {"cache_resharding_plans": cache_resharding_plans}

  return _reshard(
      x,
      sharding,
      donate=donate,
      may_alias=may_alias,
      jax_array_reshard_fn=backend_fn,
      **backend_kwargs,
  )
0 commit comments