 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 """Utilities for partitioning."""
 
 import abc
 import flax.linen as nn
 import jax
 import numpy as np
+# FIXED: jax.sharding.use_mesh is not available in JAX 0.4.30; use recml's
+# mesh context bridge instead.
+from recml.core.training import mesh_context
 
 
 PyTree = Any
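The added import brings in `recml.core.training.mesh_context`, which the `set_global_mesh(...)` calls below use as a bridge for code paths that still expect a globally registered mesh. The module's contents are not part of this diff; purely as an illustration, such a bridge could be as small as the following sketch (everything except `set_global_mesh` is hypothetical):

```python
# Illustrative sketch only: the real recml.core.training.mesh_context is not
# shown in this diff; `get_global_mesh` and `global_mesh` are hypothetical.
import contextlib
import threading

import jax

_LOCAL = threading.local()


def set_global_mesh(mesh: jax.sharding.Mesh | None) -> None:
  """Registers `mesh` so other code can look it up without plumbing it through."""
  _LOCAL.mesh = mesh


def get_global_mesh() -> jax.sharding.Mesh | None:
  """Returns the last mesh registered via `set_global_mesh`, if any."""
  return getattr(_LOCAL, "mesh", None)


@contextlib.contextmanager
def global_mesh(mesh: jax.sharding.Mesh):
  """Registers `mesh` for the duration of the block, then restores the old one."""
  previous = get_global_mesh()
  set_global_mesh(mesh)
  try:
    yield
  finally:
    set_global_mesh(previous)
```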
@@ -67,7 +70,8 @@ class DataParallelPartitioner(Partitioner):
6770 """Data parallel partitioner."""
6871
6972 def __init__ (self , data_axis : str = "batch" ):
70- self .mesh = jax .make_mesh ((jax .device_count (),), (data_axis ,))
73+ devices = jax .devices ()
74+ self .mesh = jax .sharding .Mesh (devices , (data_axis ,))
7175 self .data_sharding = jax .sharding .NamedSharding (
7276 self .mesh , jax .sharding .PartitionSpec (data_axis )
7377 )
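The removed `jax.make_mesh` call and its replacement build the same one-dimensional mesh; `make_mesh` is just a newer convenience constructor, while `jax.sharding.Mesh` takes the device list explicitly and works on older JAX releases. A standalone sketch of the replacement pattern, assuming the batch dimension divides evenly across the local devices:

```python
import jax
import numpy as np

# One mesh axis named "batch" spanning every local device.
devices = jax.devices()
mesh = jax.sharding.Mesh(devices, ("batch",))

# Shard arrays along their leading dimension over the "batch" axis.
data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("batch")
)

# The leading dimension (8 here) must be divisible by the device count.
x = np.arange(16, dtype=np.float32).reshape(8, 2)
x_sharded = jax.device_put(x, data_sharding)
print(x_sharded.sharding)
```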
@@ -107,8 +111,10 @@ def _shard(x: np.ndarray) -> jax.Array:
   def partition_init(
       self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None
   ) -> CreateStateFn:
-    with jax.sharding.use_mesh(self.mesh):
+    # FIXED: Use 'with self.mesh'
+    with self.mesh:
       if abstract_batch is not None:
+        mesh_context.set_global_mesh(self.mesh)
         abstract_state = jax.eval_shape(init_fn, abstract_batch)
         specs = nn.get_partition_spec(abstract_state)
         self.state_sharding = jax.tree.map(
@@ -117,7 +123,8 @@ def partition_init(
       init_fn = jax.jit(init_fn, out_shardings=self.state_sharding)
 
     def _wrapped_init(batch: PyTree) -> State:
-      with jax.sharding.use_mesh(self.mesh):
+      # FIXED: Use 'with self.mesh'
+      with self.mesh:
         state = init_fn(batch)
         state = _maybe_unbox_state(state)
       return state
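Apart from the mesh handling, `partition_init` keeps the usual Flax flow: trace the initializer with `jax.eval_shape`, pull `PartitionSpec`s out of any `nn.Partitioned` metadata with `nn.get_partition_spec`, map them onto the mesh, and jit the real initializer with those shardings as `out_shardings`. A condensed sketch of that pattern outside the class (the `nn.Dense` model is a stand-in, not recml code):

```python
import flax.linen as nn
import jax
import numpy as np

mesh = jax.sharding.Mesh(jax.devices(), ("batch",))


def init_fn(batch):
  # Stand-in initializer; the partitioner would build the full train state.
  return nn.Dense(4).init(jax.random.key(0), batch)


abstract_batch = jax.ShapeDtypeStruct((8, 2), np.float32)

# 1. Shape-only trace: no parameters are actually allocated.
abstract_state = jax.eval_shape(init_fn, abstract_batch)

# 2. Extract PartitionSpecs from nn.Partitioned leaves (replicated otherwise).
specs = nn.get_partition_spec(abstract_state)

# 3. Turn the specs into concrete shardings on the mesh.
state_sharding = jax.tree.map(
    lambda spec: jax.sharding.NamedSharding(mesh, spec), specs
)

# 4. Jit the initializer so its outputs are created directly in that layout.
jitted_init = jax.jit(init_fn, out_shardings=state_sharding)
```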
@@ -130,15 +137,18 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
       jit_kws["out_shardings"] = (self.state_sharding, None)
       jit_kws["donate_argnums"] = (1,)
 
-    with jax.sharding.use_mesh(self.mesh):
+    # FIXED: Use 'with self.mesh' and legacy bridge
+    with self.mesh:
+      mesh_context.set_global_mesh(self.mesh)
       step_fn = jax.jit(
           fn,
           in_shardings=(self.data_sharding, self.state_sharding),
           **jit_kws,
       )
 
     def _wrapped_step(batch: PyTree, state: State) -> Any:
-      with jax.sharding.use_mesh(self.mesh):
+      # FIXED: Use 'with self.mesh'
+      with self.mesh:
         return step_fn(batch, state)
 
     return _wrapped_step
@@ -190,7 +200,8 @@ def __init__(
     if axis_sizes[0] == -1:
       axis_sizes[0] = len(devices) // math.prod(axis_sizes[1:])
 
-    self.mesh = jax.make_mesh(axis_sizes, axis_names, devices=devices)
+    # self.mesh = jax.make_mesh(axis_sizes, axis_names, devices=devices)
+    # Mesh needs a device array whose rank matches the number of axis names.
+    self.mesh = jax.sharding.Mesh(
+        np.asarray(devices).reshape(axis_sizes), axis_names
+    )
     self.rules = rules
     self.aot_compile = aot_compile
     self.options = options
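Note the reshape above: unlike `jax.make_mesh`, `jax.sharding.Mesh` does not arrange the devices for you; it requires a device array whose rank equals the number of axis names. A small sketch of building such a mesh by hand (the axis sizes here are illustrative and assume they factor the local device count):

```python
import math

import jax
import numpy as np

axis_names = ("data", "model")
axis_sizes = [-1, 1]  # use e.g. 2+ on the model axis on a multi-device host

devices = jax.devices()
if axis_sizes[0] == -1:
  axis_sizes[0] = len(devices) // math.prod(axis_sizes[1:])

# Mesh requires devices.ndim == len(axis_names), so reshape the flat list.
device_grid = np.asarray(devices).reshape(axis_sizes)
mesh = jax.sharding.Mesh(device_grid, axis_names)

data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data")
)
```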
@@ -213,12 +224,6 @@ def __init__(
     self.abstract_batch = None
     self.abstract_state = None
 
-  @property
-  def mesh_context_manager(
-      self,
-  ) -> Callable[[jax.sharding.Mesh], ContextManager[None]]:
-    return jax.sharding.use_mesh
-
   def shard_inputs(self, inputs: PyTree) -> PyTree:
     def _shard(x: np.ndarray) -> jax.Array:
       return jax.make_array_from_process_local_data(self.data_sharding, x)
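`shard_inputs` (unchanged here) relies on `jax.make_array_from_process_local_data`, which assembles a globally sharded `jax.Array` from the numpy shard each process holds. A minimal single-process sketch under the same 1-D "batch" mesh assumed earlier:

```python
import jax
import numpy as np

mesh = jax.sharding.Mesh(jax.devices(), ("batch",))
data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("batch")
)

# Each process passes only its local rows; JAX stitches the global array
# together according to `data_sharding`. With one process, local == global.
local_batch = np.ones((8, 3), dtype=np.float32)
global_batch = jax.make_array_from_process_local_data(data_sharding, local_batch)
print(global_batch.shape, global_batch.sharding)
```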
@@ -234,7 +239,10 @@ def partition_init(
234239 " model parallel partitioner."
235240 )
236241
237- with self .mesh_context_manager (self .mesh ):
242+ # FIXED: Use 'with self.mesh' directly
243+ with self .mesh :
244+ # FIXED: Legacy bridge
245+ mesh_context .set_global_mesh (self .mesh )
238246 abstract_state = jax .eval_shape (init_fn , abstract_batch )
239247 specs = nn .get_partition_spec (abstract_state )
240248
@@ -247,7 +255,8 @@ def partition_init(
     compiled_init_fn = jax.jit(init_fn, out_shardings=state_sharding)
 
     def _init(batch: PyTree) -> State:
-      with self.mesh_context_manager(self.mesh):
+      # FIXED: Use 'with self.mesh' directly
+      with self.mesh:
         state = compiled_init_fn(batch)
         state = _maybe_unbox_state(state)
       return state
@@ -265,7 +274,9 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
     else:
       jit_kws["out_shardings"] = None
 
-    with self.mesh_context_manager(self.mesh):
+    # FIXED: Use 'with self.mesh' directly and legacy bridge
+    with self.mesh:
+      mesh_context.set_global_mesh(self.mesh)
       step_fn = jax.jit(
           fn,
           in_shardings=(self.data_sharding, self.state_sharding),
@@ -286,7 +297,8 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
       )
 
     def _step(batch: PyTree, state: State) -> Any:
-      with self.mesh_context_manager(self.mesh):
+      # FIXED: Use 'with self.mesh' directly
+      with self.mesh:
         return step_fn(batch, state)
 
     return _step
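When the partitioner is constructed with `aot_compile` (stored in `__init__` further up), the jitted step is compiled ahead of time against the abstract batch and state instead of on first call. A sketch of that lower-then-compile flow with plain `jax.jit` and made-up shapes, not the partitioner's exact code:

```python
import jax
import jax.numpy as jnp


def step(batch, state):
  # Toy step function standing in for the partitioned train/eval step.
  return state + batch.sum(), None


abstract_batch = jax.ShapeDtypeStruct((8, 4), jnp.float32)
abstract_state = jax.ShapeDtypeStruct((), jnp.float32)

# Lower against shape/dtype placeholders, then compile before any real data
# arrives; the resulting executable can be inspected or timed up front.
compiled_step = jax.jit(step).lower(abstract_batch, abstract_state).compile()

out, _ = compiled_step(jnp.ones((8, 4), jnp.float32), jnp.array(0.0, jnp.float32))
```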
@@ -302,4 +314,4 @@ def _maybe_unbox(x: Any) -> Any:
       _maybe_unbox,
       x,
       is_leaf=lambda k: isinstance(k, nn.Partitioned),
-  )
+  )
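Taken together, a caller would use the data parallel partitioner from this file roughly as follows. This is a sketch: `make_state` and `train_step` are placeholder user callables, and the module path is assumed from the `recml.core.training` import above.

```python
import jax.numpy as jnp
import numpy as np

# Module path assumed; this file sits next to recml.core.training.mesh_context.
from recml.core.training import partitioning


def make_state(batch):
  # Placeholder state initializer.
  return {"w": jnp.zeros((batch.shape[-1],), jnp.float32)}


def train_step(batch, state):
  # Placeholder step returning (new_state, aux), as training=True expects.
  return {"w": state["w"] - 0.1 * batch.mean(axis=0)}, None


partitioner = partitioning.DataParallelPartitioner(data_axis="batch")

batch = partitioner.shard_inputs(np.ones((8, 4), dtype=np.float32))
init_fn = partitioner.partition_init(make_state, abstract_batch=batch)
step_fn = partitioner.partition_step(train_step, training=True)

state = init_fn(batch)
state, _ = step_fn(batch, state)
```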