@@ -108,9 +108,11 @@ def extend_pspec(pspec: jax.sharding.PartitionSpec | Sequence[str | Sequence[str
   def reshape_for_diloco(arr):
     batch_dim, *example_shape = arr.shape
     diloco_shape = (num_diloco_replicas, batch_dim // num_diloco_replicas, *example_shape)
-    s = arr.sharding
-    s = jax.sharding.NamedSharding(mesh=s.mesh, spec=extend_pspec(s.spec))
-    return jax.lax.with_sharding_constraint(jnp.reshape(arr, shape=diloco_shape), s)
+    if hasattr(arr, "sharding"):
+      s = arr.sharding
+      s = jax.sharding.NamedSharding(mesh=s.mesh, spec=extend_pspec(s.spec))
+      return jax.lax.with_sharding_constraint(jnp.reshape(arr, shape=diloco_shape), s)
+    return jnp.reshape(arr, shape=diloco_shape)

   return jax.tree.map(reshape_for_diloco, pytree)
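
The guard means leaves without a `.sharding` attribute are simply reshaped without a sharding constraint, while committed arrays keep an extended spec. The reshape itself splits the leading batch axis into a replica axis plus a per-replica batch axis. A minimal sketch of that shape arithmetic, assuming 4 DiLoCo replicas and a global batch of 32 (toy values):

import jax.numpy as jnp

num_diloco_replicas = 4
batch = jnp.zeros((32, 128))  # (global_batch, features), illustrative shapes
per_replica = batch.shape[0] // num_diloco_replicas
reshaped = jnp.reshape(batch, (num_diloco_replicas, per_replica, *batch.shape[1:]))
assert reshaped.shape == (4, 8, 128)  # axis 0 now indexes the DiLoCo replica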
@@ -166,9 +168,11 @@ def add_diloco_dim(x):

   # Build shardings
   inner_state_shardings = add_diloco_to_sharding(state_mesh_shardings)
-  outer_opt_state_sharding = jax.tree.map(
-      lambda _: jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()),
-      outer_opt_state,
+  # Sharding for outer_opt_state. For SGD with momentum, it is (TraceState(trace=...), EmptyState()).
+  # We shard the momentum trace the same way as the parameters.
+  outer_opt_state_sharding = (
+      optax.TraceState(trace=state_mesh_shardings.params),
+      optax.EmptyState(),
   )
   diloco_state_shardings = DiLoCoTrainState(
       inner_state=inner_state_shardings,
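
The tuple above mirrors the state optax actually builds: with momentum and a constant learning rate, `optax.sgd(...).init(params)` returns `(TraceState(trace=<params-shaped tree>), EmptyState())`, so the parameter shardings can be reused verbatim for the momentum trace. A standalone check of that structure (toy values):

import jax.numpy as jnp
import optax

params = {"w": jnp.zeros((4, 4))}
state = optax.sgd(learning_rate=0.7, momentum=0.9).init(params)
# The trace mirrors the params pytree, which is why parameter shardings fit it.
print(state)  # (TraceState(trace={'w': Array(...)}), EmptyState())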
@@ -183,6 +187,7 @@ def add_diloco_dim(x):
 def build_diloco_state(
     config: "pyconfig.HyperParameters",
     initialize_state: Callable[[], train_state.TrainState],
+    mesh: jax.sharding.Mesh | None = None,
 ) -> tuple[DiLoCoTrainState, PyTree]:
   """Given a non-DiLoCo train state, construct a DiLoCo training state."""
   outer_optimizer = optax.sgd(
@@ -195,7 +200,10 @@ def build_diloco_state(
   def init_diloco_state() -> tuple[DiLoCoTrainState, PyTree]:
     state = initialize_state()
     # Inner state must be broadcast across clients.
-    inner_state = drjax.broadcast(state)
+    # Pass mesh explicitly because jax.set_mesh() uses a different thread-local
+    # than pxla.thread_resources (which drjax reads), so drjax cannot find the
+    # mesh automatically when jax.set_mesh is used.
+    inner_state = drjax.broadcast(state, mesh=mesh)
     # Outer state retains a single copy of the model parameters and optimizer state.
     outer_params = state.params
     outer_opt_state = outer_optimizer.init(outer_params)
@@ -211,6 +219,7 @@ def init_diloco_state() -> tuple[DiLoCoTrainState, PyTree]:
 def build_diloco_train_step(
     config: pyconfig.HyperParameters,
     train_step: Callable[[train_state.TrainState, Batch, PRNGKey], tuple[train_state.TrainState, Metrics]],
+    mesh: jax.sharding.Mesh | None = None,
 ) -> Callable[[DiLoCoTrainState, Batch, PRNGKey], tuple[DiLoCoTrainState, Metrics]]:
   """Convert a local state and train step into DiLoCo-compatible versions.

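
With both builders now taking a mesh, callers that configure their mesh via `jax.set_mesh` must also thread it through here, for the thread-local reason noted in `init_diloco_state` above. A hedged wiring sketch; `config`, `initialize_state`, and `train_step` stand in for the real objects from the surrounding module:

import jax
from jax.experimental import mesh_utils

devices = mesh_utils.create_device_mesh((len(jax.devices()),))
mesh = jax.sharding.Mesh(devices, axis_names=("data",))

# Second return value is a PyTree per the annotation; its exact contents
# are not shown in this diff.
diloco_state, aux = build_diloco_state(config, initialize_state, mesh=mesh)
diloco_step = build_diloco_train_step(config, train_step, mesh=mesh)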
@@ -234,7 +243,7 @@ def build_diloco_train_step(
   def synchronize(state):
     # Calculate the delta between the current replica's state and the global
     # state (since last synchronization).
-    broadcast_outer_params = drjax.broadcast(state.params)
+    broadcast_outer_params = drjax.broadcast(state.params, mesh=mesh)
     model_delta = jax.tree.map(lambda x, y: y - x, state.inner_state.params, broadcast_outer_params)
     # Treat the average delta as the outer optimizer's gradient and apply to
     # the global (outer) model params.
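
This is DiLoCo's outer update in miniature: the average of (outer - inner) across replicas is treated as a gradient, so one step of outer SGD moves the global parameters toward where the replicas drifted. A self-contained sketch of that arithmetic without drjax (toy arrays; outer hyperparameters illustrative):

import jax.numpy as jnp
import optax

outer_params = jnp.array([1.0, 2.0])
inner_params = [outer_params + d for d in (0.1, 0.3, 0.2)]  # 3 drifted replicas

outer_opt = optax.sgd(learning_rate=0.7, momentum=0.9)
opt_state = outer_opt.init(outer_params)

# Average (outer - inner) plays the role of a gradient; descending it pulls
# the global params toward the replicas' average position.
avg_delta = sum(outer_params - p for p in inner_params) / len(inner_params)
updates, opt_state = outer_opt.update(avg_delta, opt_state, outer_params)
new_outer_params = optax.apply_updates(outer_params, updates)
# First step: new_outer_params == outer_params + 0.7 * 0.2 elementwise.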
@@ -244,7 +253,7 @@ def synchronize(state):
     # Replace inner model params with the new global model params.
     # NOTE: inner optimizer state is retained despite the change in parameters,
     # see section 6.1 in https://arxiv.org/pdf/2311.08105.
-    new_inner_state = drjax.map_fn(lambda state: state.replace(params=new_outer_params), state.inner_state)
+    new_inner_state = drjax.map_fn(lambda state: state.replace(params=new_outer_params), state.inner_state, mesh=mesh)
     return state.replace(
         params=new_outer_params,
         outer_opt_state=new_opt_state,
@@ -259,8 +268,8 @@ def typed_reduce_mean(in_tree):
   @drjax.program(placements={"diloco": config.num_diloco_replicas})
   def diloco_train_step(state, batch, prng):
     # Broadcast the RNG across replicas.
-    broadcast_rng = drjax.broadcast(prng)
-    inner_state, metrics = drjax.map_fn(train_step, (state.inner_state, batch, broadcast_rng))
+    broadcast_rng = drjax.broadcast(prng, mesh=mesh)
+    inner_state, metrics = drjax.map_fn(train_step, (state.inner_state, batch, broadcast_rng), mesh=mesh)
     avg_metrics = typed_reduce_mean(metrics)
     state = state.replace(
         inner_state=inner_state,
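
For intuition, inside a `drjax.program` the "diloco" placement behaves like a leading axis of size `config.num_diloco_replicas`: `broadcast` tiles a value along it, and `map_fn` runs the wrapped step once per replica along it. A rough conceptual stand-in, not drjax's actual implementation (it also ignores the sharding drjax manages):

import jax
import jax.numpy as jnp

num_replicas = 4

def broadcast(x):
  # Tile every leaf along a new leading replica axis.
  return jax.tree.map(lambda v: jnp.broadcast_to(v, (num_replicas, *v.shape)), x)

def map_fn(fn, args):
  # Run fn once per replica along the leading axis of every argument.
  return jax.vmap(fn)(*args)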