
Commit 5478bad

Merge pull request #3471 from AI-Hypercomputer:mohit/diloco_fixes
PiperOrigin-RevId: 891891432
2 parents c30ada0 + 2acabe5 commit 5478bad

8 files changed

Lines changed: 86 additions & 33 deletions


src/maxtext/common/data_loader.py

Lines changed: 4 additions & 5 deletions
@@ -71,13 +71,10 @@ def load_next_batch_pre_sharding(self):
 
   def load_next_batch(self, *args, **kwargs):
     """Loads the next batch with sharding hint"""
-    example_batch = jax.device_put(
-        self.load_next_batch_pre_sharding(),
-        self.input_data_shardings,
-    )
+    example_batch = self.load_next_batch_pre_sharding()
     if self.config.enable_diloco:
       example_batch = diloco.reshape_first_axis_with_diloco(self.config.num_diloco_replicas, example_batch)
-    return example_batch
+    return jax.device_put(example_batch, self.input_data_shardings)
 
   def check_example_batch(self):
     if self.config.max_checkify:
@@ -157,6 +154,8 @@ def _slice(data):
     self.buffer_start = slice_end
     output = jax.tree.map(_slice, self.batch_buffer)
     self.rampup_active = rampup_manager.update()
+    if self.config.enable_diloco:
+      output = diloco.reshape_first_axis_with_diloco(self.config.num_diloco_replicas, output)
     return jax.device_put(output, self.input_data_shardings)
 
 
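
For orientation, a minimal standalone sketch (not the repo's implementation) of the reshape that reshape_first_axis_with_diloco applies before the batch is placed on devices: the leading global-batch axis is split into (num_diloco_replicas, per_replica_batch), assuming the batch is a pytree of arrays whose first axis is divisible by the replica count.

import numpy as np

def reshape_first_axis(num_replicas, batch):
  # Split the leading batch axis into (num_replicas, batch_size // num_replicas, ...).
  def _reshape(arr):
    b, *rest = arr.shape
    return np.reshape(arr, (num_replicas, b // num_replicas, *rest))
  return {k: _reshape(v) for k, v in batch.items()}

batch = {"inputs": np.zeros((8, 128)), "targets": np.zeros((8, 128))}
out = reshape_first_axis(2, batch)
assert out["inputs"].shape == (2, 4, 128)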

src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml

Lines changed: 2 additions & 1 deletion
@@ -56,7 +56,7 @@ rope_truncate: True
 rope_attention_scaling: False
 
 override_logical_axis_rules: True
-mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']
+mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']
 data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']]
 logical_axis_rules: [
   ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
@@ -83,4 +83,5 @@ logical_axis_rules: [
   ['mlp', ['fsdp_transpose', 'expert']],
   ['mlp_only_fsdp_transpose', ['fsdp_transpose']],
   ['mlp_only_tensor', ['expert']],
+  ['diloco', 'diloco'],
 ]
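
As a rough illustration of what the new leading 'diloco' mesh axis means at the JAX level (hypothetical sizes, not part of this config), a mesh whose first axis enumerates DiLoCo replicas can back a NamedSharding like so, assuming 8 host CPU devices split into 2 replicas of 4 FSDP shards:

import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # CPU-only illustration

import jax
import numpy as np

devices = np.array(jax.devices()).reshape(2, 4)  # 2 diloco replicas x 4 fsdp shards
mesh = jax.sharding.Mesh(devices, ("diloco", "fsdp"))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("diloco", "fsdp"))
x = jax.device_put(np.arange(2 * 4 * 16).reshape(2, 4, 16), sharding)
print(x.sharding)  # NamedSharding with spec ('diloco', 'fsdp')

The ['diloco', 'diloco'] rule added above simply maps the logical 'diloco' axis onto that mesh axis, in the same way as the other logical_axis_rules entries.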

src/maxtext/configs/types.py

Lines changed: 25 additions & 0 deletions
@@ -2679,8 +2679,33 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
     self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]
 
     # Diloco params
+    # Resolve dcn_diloco_parallelism=-1 if left unspecified, using the same convention as dcn_data_parallelism.
+    # num_diloco_replicas must be computed after this resolution, so we resolve it here rather than
+    # relying on fill_unspecified_mesh_axes (which runs later during mesh creation).
+    if self.dcn_diloco_parallelism == -1:
+      other_dcn_product = prod(v for v in self.dcn_parallelism if v != -1)
+      assert other_dcn_product > 0 and self.num_slices % other_dcn_product == 0, (
+          f"Cannot resolve dcn_diloco_parallelism=-1: num_slices={self.num_slices} is not divisible "
+          f"by the product of other DCN parallelism values ({other_dcn_product})."
+      )
+      self.dcn_diloco_parallelism = self.num_slices // other_dcn_product
+      # Keep dcn_parallelism list consistent with the resolved value.
+      diloco_idx = self.dcn_parallelism.index(-1)
+      self.dcn_parallelism[diloco_idx] = self.dcn_diloco_parallelism
     self.num_diloco_replicas = int(self.ici_diloco_parallelism * self.dcn_diloco_parallelism)
 
+    # (b/496973624) use_tokamax_gmm is incompatible with enable_diloco: drjax.map_fn wraps
+    # the train step in jax.vmap over the diloco axis, which causes JAX to batch through
+    # lax.scan (layer scan).
+    # Tokamax's vmap_rule then tries to reconstruct GroupSizes with a batched 2-D value, but
+    # GroupSizes.__post_init__ requires exactly a 1-D shape.
+    if self.enable_diloco and self.use_tokamax_gmm:
+      raise ValueError(
+          "use_tokamax_gmm=True is not compatible with enable_diloco=True due to a known "
+          "incompatibility between tokamax's GroupSizes vmap_rule and JAX's scan batching. "
+          "Please set use_tokamax_gmm=False."
+      )
+
     # Final string-to-enum conversions if they haven't been coerced by pydantic yet.
     if isinstance(self.decoder_block, str):
       self.decoder_block = DecoderBlockType(self.decoder_block.lower())
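
To make the -1 resolution concrete, a small worked example (hypothetical values, same arithmetic as the check above): with num_slices=4 and dcn_data_parallelism=2, an unspecified dcn_diloco_parallelism=-1 resolves to 4 // 2 = 2.

from math import prod

num_slices = 4
dcn_parallelism = [-1, 2, 1, 1, 1, 1]  # 'diloco' left as -1; data parallelism is 2

other_dcn_product = prod(v for v in dcn_parallelism if v != -1)
assert other_dcn_product > 0 and num_slices % other_dcn_product == 0
dcn_diloco_parallelism = num_slices // other_dcn_product  # -> 2
dcn_parallelism[dcn_parallelism.index(-1)] = dcn_diloco_parallelism
print(dcn_diloco_parallelism, dcn_parallelism)  # 2 [2, 2, 1, 1, 1, 1]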

src/maxtext/trainers/diloco/diloco.py

Lines changed: 20 additions & 11 deletions
@@ -108,9 +108,11 @@ def extend_pspec(pspec: jax.sharding.PartitionSpec | Sequence[str | Sequence[str
   def reshape_for_diloco(arr):
     batch_dim, *example_shape = arr.shape
     diloco_shape = (num_diloco_replicas, batch_dim // num_diloco_replicas, *example_shape)
-    s = arr.sharding
-    s = jax.sharding.NamedSharding(mesh=s.mesh, spec=extend_pspec(s.spec))
-    return jax.lax.with_sharding_constraint(jnp.reshape(arr, shape=diloco_shape), s)
+    if hasattr(arr, "sharding"):
+      s = arr.sharding
+      s = jax.sharding.NamedSharding(mesh=s.mesh, spec=extend_pspec(s.spec))
+      return jax.lax.with_sharding_constraint(jnp.reshape(arr, shape=diloco_shape), s)
+    return jnp.reshape(arr, shape=diloco_shape)
 
   return jax.tree.map(reshape_for_diloco, pytree)
 
@@ -166,9 +168,11 @@ def add_diloco_dim(x):
 
   # Build shardings
   inner_state_shardings = add_diloco_to_sharding(state_mesh_shardings)
-  outer_opt_state_sharding = jax.tree.map(
-      lambda _: jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()),
-      outer_opt_state,
+  # Sharding for outer_opt_state. For SGD with momentum, it is (TraceState(trace=...), EmptyState())
+  # We shard the momentum trace the same way as the parameters.
+  outer_opt_state_sharding = (
+      optax.TraceState(trace=state_mesh_shardings.params),
+      optax.EmptyState(),
   )
   diloco_state_shardings = DiLoCoTrainState(
       inner_state=inner_state_shardings,
@@ -183,6 +187,7 @@ def add_diloco_dim(x):
 def build_diloco_state(
     config: "pyconfig.HyperParameters",
     initialize_state: Callable[[], train_state.TrainState],
+    mesh: jax.sharding.Mesh | None = None,
 ) -> tuple[DiLoCoTrainState, PyTree]:
   """Given a non-DiLoCo train state, construct a DiLoCo training state."""
   outer_optimizer = optax.sgd(
@@ -195,7 +200,10 @@
   def init_diloco_state() -> tuple[DiLoCoTrainState, PyTree]:
     state = initialize_state()
     # Inner state must be broadcast across clients.
-    inner_state = drjax.broadcast(state)
+    # Pass mesh explicitly because jax.set_mesh() uses a different thread-local
+    # than pxla.thread_resources (which drjax reads), so drjax cannot find the
+    # mesh automatically when jax.set_mesh is used.
+    inner_state = drjax.broadcast(state, mesh=mesh)
     # Outer state retains a single copy of the model parameters and optimizer state.
     outer_params = state.params
     outer_opt_state = outer_optimizer.init(outer_params)
@@ -211,6 +219,7 @@ def init_diloco_state() -> tuple[DiLoCoTrainState, PyTree]:
 def build_diloco_train_step(
     config: pyconfig.HyperParameters,
     train_step: Callable[[train_state.TrainState, Batch, PRNGKey], tuple[train_state.TrainState, Metrics]],
+    mesh: jax.sharding.Mesh | None = None,
 ) -> Callable[[DiLoCoTrainState, Batch, PRNGKey], tuple[DiLoCoTrainState, Metrics]]:
   """Convert a local state and train step into DiLoCo-compatible versions.
 
@@ -234,7 +243,7 @@
   def synchronize(state):
     # Calculate the delta between the current replica's state and the global
     # state (since last synchronization).
-    broadcast_outer_params = drjax.broadcast(state.params)
+    broadcast_outer_params = drjax.broadcast(state.params, mesh=mesh)
     model_delta = jax.tree.map(lambda x, y: y - x, state.inner_state.params, broadcast_outer_params)
     # Treat the average delta as the outer optimizer's gradient and apply to
     # the global (outer) model params.
@@ -244,7 +253,7 @@
     # Replace inner model params with the new global model params.
     # NOTE: inner optimizer state is retained despite the change in parameters,
     # see section 6.1 in https://arxiv.org/pdf/2311.08105.
-    new_inner_state = drjax.map_fn(lambda state: state.replace(params=new_outer_params), state.inner_state)
+    new_inner_state = drjax.map_fn(lambda state: state.replace(params=new_outer_params), state.inner_state, mesh=mesh)
     return state.replace(
         params=new_outer_params,
         outer_opt_state=new_opt_state,
@@ -259,8 +268,8 @@ def typed_reduce_mean(in_tree):
   @drjax.program(placements={"diloco": config.num_diloco_replicas})
   def diloco_train_step(state, batch, prng):
     # Broadcast the RNG across replicas.
-    broadcast_rng = drjax.broadcast(prng)
-    inner_state, metrics = drjax.map_fn(train_step, (state.inner_state, batch, broadcast_rng))
+    broadcast_rng = drjax.broadcast(prng, mesh=mesh)
+    inner_state, metrics = drjax.map_fn(train_step, (state.inner_state, batch, broadcast_rng), mesh=mesh)
     avg_metrics = typed_reduce_mean(metrics)
     state = state.replace(
         inner_state=inner_state,
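
The new outer_opt_state_sharding relies on the structure optax gives SGD-with-momentum state, and synchronize() treats the averaged replica delta as the outer optimizer's gradient. A self-contained sketch (toy shapes, constant learning rate assumed; not the repo's code) showing both pieces:

import jax.numpy as jnp
import optax

outer_params = {"w": jnp.zeros(4)}
outer_opt = optax.sgd(learning_rate=0.7, momentum=0.9, nesterov=True)
opt_state = outer_opt.init(outer_params)
# For SGD with momentum the state is a 2-tuple: (TraceState(trace=...), EmptyState()).
print(type(opt_state[0]).__name__, type(opt_state[1]).__name__)

# Two toy replicas' inner params after local training, stacked on a leading axis.
inner_params = {"w": jnp.stack([0.5 * jnp.ones(4), 1.5 * jnp.ones(4)])}
# DiLoCo pseudo-gradient: mean over replicas of (outer - inner), as in synchronize().
pseudo_grad = {"w": jnp.mean(outer_params["w"] - inner_params["w"], axis=0)}
updates, opt_state = outer_opt.update(pseudo_grad, opt_state, outer_params)
new_outer_params = optax.apply_updates(outer_params, updates)
print(new_outer_params["w"])  # outer params move in the direction of the replica average

Because the momentum trace mirrors the parameter pytree, sharding it as TraceState(trace=state_mesh_shardings.params) keeps the momentum buffer partitioned exactly like the parameters.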

src/maxtext/trainers/pre_train/train.py

Lines changed: 12 additions & 13 deletions
@@ -507,19 +507,18 @@ def train_loop(config, recorder, state=None):
 
   params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings)
 
-  p_train_step, p_eval_step = train_utils.jit_train_and_eval_step(
-      config,
-      model,
-      mesh,
-      state,
-      state_mesh_shardings,
-      train_step,
-      eval_step,
-      eval_data_iterator,
-      params_shardings,
-  )
-
-  with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
+  with jax.set_mesh(mesh), mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+    p_train_step, p_eval_step = train_utils.jit_train_and_eval_step(
+        config,
+        model,
+        mesh,
+        state,
+        state_mesh_shardings,
+        train_step,
+        eval_step,
+        eval_data_iterator,
+        params_shardings,
+    )
     shaped_batch = maxtext_utils.get_shaped_batch(config)
     if config.shard_optimizer_over_data:
       state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode)

src/maxtext/trainers/pre_train/train_compile.py

Lines changed: 3 additions & 1 deletion
@@ -132,7 +132,9 @@ def jit_and_compile(
     logical_axis_rules,
 ):
   """Jit, lower, and compile func."""
-  with jax.set_mesh(mesh), logical_axis_rules:
+  # Use both jax.set_mesh (new API) and `with mesh:` (old API) so that drjax,
+  # which reads from pxla.thread_resources.env.physical_mesh, can find the mesh.
+  with jax.set_mesh(mesh), mesh, logical_axis_rules:
     jitted = jax.jit(
         func,
         in_shardings=in_shardings,
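
A minimal standalone illustration of the dual-context pattern above (toy CPU mesh, hypothetical sizes), assuming a JAX version where jax.set_mesh works as a context manager, as this change relies on: jax.set_mesh drives the new-style mesh context, while `with mesh:` also populates the legacy thread-local mesh that libraries such as drjax still read.

import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"  # CPU-only illustration

import jax
import numpy as np

mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(2, 2), ("diloco", "fsdp"))

with jax.set_mesh(mesh), mesh:
  # Both the new mesh context (jax.set_mesh) and the legacy physical-mesh
  # context manager (`with mesh:`) are active inside this block.
  sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("diloco", "fsdp"))
  x = jax.device_put(np.zeros((2, 2)), sharding)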

src/maxtext/utils/train_utils.py

Lines changed: 2 additions & 2 deletions
@@ -162,7 +162,7 @@ def jit_train_and_eval_step(
   """Returns a JIT-compiled train and eval step function."""
   if config.enable_diloco:
     train_step_partial = functools.partial(train_step, model, config, state_mesh_shardings, params_shardings)
-    train_step = diloco.build_diloco_train_step(config, train_step_partial)
+    train_step = diloco.build_diloco_train_step(config, train_step_partial, mesh=mesh)
     data_sharding = sharding.get_input_data_sharding(config, mesh)
     p_train_step = jit_train_step(config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings)
     p_eval_step = None
@@ -229,7 +229,7 @@ def setup_train_loop(config, recorder, devices=None):
 
   if config.enable_diloco:
     with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
-      state, outer_opt_state_sharding = diloco.build_diloco_state(config, lambda: state)
+      state, outer_opt_state_sharding = diloco.build_diloco_state(config, lambda: state, mesh=mesh)
 
     # create state_mesh_shardings for the DilocoState
     inner_state_shardings = diloco.add_diloco_to_sharding(state_mesh_shardings)

tests/unit/diloco_test.py

Lines changed: 18 additions & 0 deletions
@@ -268,6 +268,24 @@ def loss_fn(params, batch):
     # synchronization).
     chex.assert_trees_all_equal(diloco_test_state.params, step_three_outer_params)
 
+  @pytest.mark.cpu_only
+  def test_diloco_qwen3_moe_two_slices(self):
+    temp_dir = gettempdir()
+    compiled_trainstep_file = os.path.join(temp_dir, "test_compiled_diloco_qwen3_moe.pickle")
+    train_compile_main(
+        (
+            None,
+            get_test_config_path(),
+            f"compiled_trainstep_file={compiled_trainstep_file}",
+            "compile_topology=tpu7x-16",
+            "compile_topology_num_slices=2",
+            "ici_fsdp_parallelism=-1",
+            "dcn_diloco_parallelism=2",
+            "enable_diloco=true",
+            "model_name=qwen3-30b-a3b",
+        )
+    )
+
   @pytest.mark.tpu_only
   def test_diloco_two_slices(self):
     temp_dir = gettempdir()
