Skip to content

Commit 09ee677

Browse files
committed
add cse remat
1 parent 552621b commit 09ee677

4 files changed

Lines changed: 95 additions & 88 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -968,7 +968,7 @@ xprof_e2e_enable_fw_power_level_event: False
968968
xprof_e2e_enable_fw_thermal_event: False
969969
profile_power_events: False # Set to True to enable TPU-specific power/thermal profiling events. Defaults to False to avoid breaking GPU xplane tracing.
970970

971-
log_config: False # Prints the config (after defaults have been set by pyconfig logic)
971+
log_config: True # Prints the config (after defaults have been set by pyconfig logic)
972972
debug_sharding: False # Prints model weights sharding info
973973

974974
# Checkpoint Structured logging

src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,48 +20,47 @@
2020
# The `data` axis is preserved for two reasons: first, the pipeline stage acts as a
2121
# data parallel (DP) domain externally, making the `data` axis a necessary reference;
2222
# second, it may be required for DCN communication.
23+
#
24+
# The `context` axis is used to support a fractional per-device batch size
2325
#
2426
# Finally, the `tensor` axis is used to shard weights when `pipeline_fsdp_ag_once` or
2527
# `pipeline_fsdp_ag_per_repeat` is enabled, ensuring we have sufficient memory to
2628
# store prefetched weights.
27-
mesh_axes: ['data', 'stage', 'fsdp', 'tensor', 'expert']
28-
data_sharding: [['data', 'stage', 'fsdp', 'tensor', 'expert']]
29+
mesh_axes: ['data', 'stage', 'fsdp', 'context', 'tensor', 'expert']
30+
data_sharding: [['data', 'stage', 'fsdp', 'context', 'tensor', 'expert']]
2931
logical_axis_rules: [
3032
['activation_batch', ['data', 'fsdp', 'expert']],
3133
['activation_batch_moe', ['data', 'fsdp', 'expert']],
3234
['activation_batch_no_exp', ['data', 'fsdp']],
3335
['activation_batch_no_exp_moe', ['data', 'fsdp']],
3436
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'expert']],
35-
['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'expert']],
37+
['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'context', 'expert']],
3638
['activation_heads', ['tensor']],
3739
['activation_kv_heads', ['tensor']],
38-
['activation_length', ['expert']],
39-
['activation_attn_length', ['expert']],
40-
['activation_q_length', ['expert']],
40+
['activation_length', ['context', 'expert']],
41+
['activation_attn_length', ['context', 'expert']],
42+
['activation_q_length', ['context', 'expert']],
4143
['activation_attn_embed', ['tensor']],
4244
['activation_embed', ['tensor']],
4345
['activation_embed_moe', ['tensor']],
4446
['activation_mlp', ['tensor']],
4547
['activation_kv', ['tensor']],
46-
['activation_prefill_kv_batch', ['data', 'fsdp', 'expert']],
4748
['activation_kv_batch', ['data', 'fsdp', 'expert']],
4849
['activation_kv_batch_no_exp', ['data', 'fsdp']],
4950
['activation_kv_head_dim', ['tensor']],
5051
['activation_vocab', ['tensor']],
5152
['activation_stage', 'stage'],
5253
['activation_exp', ['expert']],
53-
['decode_batch', ['data', 'fsdp', 'expert']],
5454
['mlp', ['tensor']],
5555
['mlp_no_fsdp', ['tensor']],
5656
['vocab', ['tensor']],
5757
['heads', ['tensor']],
5858
['q_heads', ['tensor']],
5959
['kv_heads', ['tensor']],
60-
['embed', ['fsdp', 'expert']],
60+
['embed', ['fsdp', 'expert']], # remove context from embed sharding
6161
['embed_moe', ['fsdp', 'expert']],
6262
['embed_no_exp', ['fsdp']],
6363
['embed_no_exp_moe', ['fsdp']],
64-
['embed_moe', ['fsdp']],
6564
['q_lora', ['fsdp']],
6665
['kv_lora', ['fsdp']],
6766
['norm', ['tensor']],

src/maxtext/layers/pipeline.py

Lines changed: 11 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,11 +1173,12 @@ def from_repeat_weights_to_bsw(
11731173
self,
11741174
repeat_weights,
11751175
physical_partition_spec,
1176-
axes_to_gather=("fsdp", "fsdp_transpose", "expert"), # three major FSDP-like axes
1176+
axes_to_gather=("fsdp", "fsdp_transpose", "context", "expert"),
1177+
# TODO (chengnuojin) set use_shardmap=true after JAX >= 10.0.0 and use all_gather(..., to='invarying')
11771178
use_shardmap=False, # using shardmap produces additional reduce-scatter in backward pass
11781179
):
11791180
"""Executes the FSDP-like all-gathers to fully materialize a block of weights for the BSW."""
1180-
axes_to_remove = ["fsdp", "fsdp_transpose"]
1181+
axes_to_remove = ["fsdp", "fsdp_transpose", "context"]
11811182
bsw_pps = pipeline_utils.derive_stage_weight_partition_specs(physical_partition_spec, axes_to_remove)
11821183

11831184
def _from_repeat_weights_to_bsw_shardmap(
@@ -1244,20 +1245,7 @@ def _apply_sharding_hint(weight, pspec):
12441245
return _from_repeat_weights_to_bsw_shardmap(repeat_weights, physical_partition_spec, axes_to_gather=axes_to_gather)
12451246
return _from_repeat_weights_to_bsw_hint(repeat_weights)
12461247

1247-
def both_weight_prefetching(self, weights, physical_partition_spec, loop_iteration):
1248-
"""Triggers asynchronous FSDP-like all-gathers for the current and next pipeline steps.
1249-
1250-
By gathering weights for `loop_iteration + 1` right now, the network communication
1251-
can overlap with the compute happening in `loop_iteration`. The dual-buffers
1252-
are returned grouped in an explicit `jax.ad_checkpoint` to strictly control memory.
1253-
"""
1254-
cur_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration)
1255-
nxt_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration + 1)
1256-
bsw_0 = self.from_repeat_weights_to_bsw(cur_repeat_weights, physical_partition_spec)
1257-
bsw_1 = self.from_repeat_weights_to_bsw(nxt_repeat_weights, physical_partition_spec)
1258-
return bsw_0, bsw_1
1259-
1260-
def one_weight_prefetching(self, weights, physical_partition_spec, loop_iteration):
1248+
def weight_prefetching(self, weights, physical_partition_spec, loop_iteration):
12611249
"""Triggers asynchronous FSDP-like all-gathers for the next pipeline steps.
12621250
12631251
By gathering weights for `loop_iteration + 1` right now, the network communication
@@ -1351,7 +1339,6 @@ def __call__(
13511339
segment_idx = None
13521340

13531341
loop_state, bsw = self.init_states(inputs)
1354-
weights = self.layers.variables
13551342
physical_partition_spec = logical_to_mesh(
13561343
logical_partition_spec, mesh=self.mesh, rules=self.config.logical_axis_rules
13571344
)
@@ -1388,41 +1375,34 @@ def run_iteration_scannable(model, loop_state, bsw):
13881375

13891376
# base scannable function used twice for real and bubble runs
13901377
base_scannable = functools.partial(
1391-
pipeline_utils.create_rematerialized_pipeline_stage,
1378+
pipeline_utils.create_pipeline_stage,
13921379
deterministic=deterministic,
13931380
model_mode=model_mode,
13941381
logical_partition_spec=logical_partition_spec,
13951382
physical_partition_spec=physical_partition_spec,
13961383
positions=positions,
13971384
segment_ids=segment_ids,
1398-
pipeline_weights=weights,
13991385
)
14001386

14011387
run_one_repeat_scannable = base_scannable(length=self.config.num_pipeline_microbatches)
1402-
# run_one_repeat_scannable = nn.remat(
1403-
# run_one_repeat_scannable,
1404-
# prevent_cse=True,
1405-
# policy=self.get_pipeline_remat_policy()
1406-
# )
14071388
run_bubbles_scannable = base_scannable(length=bubble_iterations)
1408-
# run_bubbles_scannable = nn.remat(
1409-
# run_bubbles_scannable,
1410-
# prevent_cse=True,
1411-
# policy=self.get_pipeline_remat_policy()
1412-
# )
14131389

14141390
run_repeats_scanned = pipeline_utils.create_flax_pipeline_scan(
14151391
pipeline_stage_fn=run_one_repeat_scannable,
14161392
length=self.config.num_pipeline_repeats,
1393+
remat_policy=self.get_pipeline_remat_policy(),
14171394
use_scan=self.config.scan_pipeline_repeats,
14181395
)
14191396
run_bubbles_scanned = pipeline_utils.create_flax_pipeline_scan(
14201397
pipeline_stage_fn=run_bubbles_scannable,
14211398
length=1,
1399+
remat_policy=self.get_pipeline_remat_policy(),
14221400
use_scan=self.config.scan_pipeline_repeats,
14231401
)
1424-
(loop_state, w_curr), _ = run_repeats_scanned(self, (loop_state, bsw[0]))
1425-
(loop_state, _), _ = run_bubbles_scanned(self, (loop_state, w_curr))
1402+
initial_carry_repeats = (loop_state, bsw[0], self.layers.variables)
1403+
(loop_state, w_curr, pipeline_weights), _ = run_repeats_scanned(self, initial_carry_repeats)
1404+
initial_carry_bubbles = (loop_state, w_curr, pipeline_weights)
1405+
(loop_state, _, pipeline_weights), _ = run_bubbles_scanned(self, initial_carry_bubbles)
14261406

14271407
final_output = self.realign_output_microbatches(loop_state["state_io"])
14281408
final_output = jnp.reshape(

src/maxtext/utils/pipeline_utils.py

Lines changed: 74 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -248,24 +248,21 @@ def run_pipeline_microbatches_custom_bwd(residuals, g_final_state):
248248
return run_pipeline_microbatches_custom
249249

250250

251-
def create_rematerialized_pipeline_stage(
251+
def create_pipeline_stage(
252252
length,
253253
deterministic,
254254
model_mode,
255255
logical_partition_spec,
256256
physical_partition_spec,
257257
positions,
258258
segment_ids,
259-
pipeline_weights,
260259
):
261-
"""Builds a memory-checkpointed execution block for a single pipeline stage.
260+
"""Builds an execution block for a single pipeline stage.
262261
263262
This function prepares the state for a specific chunk of pipeline execution by:
264-
1. Prefetching the required weights for the current stage/loop iteration.
265-
2. Executing `length` microbatches using either a memory-efficient `jax.lax.scan`
266-
(if `scan_pipeline_iterations` is True) or an unrolled Python `for` loop.
267-
3. Wrapping the entire stage block in `flax.linen.remat` to discard and recompute
268-
activations during the backward pass based on the model's policy.
263+
1. Prefetching the required weights (e.g., FSDP-gathered) for the current stage/loop iteration.
264+
2. Executing `length` microbatches using a memory-efficient `jax.lax.scan` via a custom VJP
265+
that manages collective communication overlap.
269266
270267
Args:
271268
length: The number of microbatches to process in this stage.
@@ -275,14 +272,27 @@ def create_rematerialized_pipeline_stage(
275272
physical_partition_spec: Rules for physical device mesh mappings (used in prefetching).
276273
positions: Position IDs for the sequence.
277274
segment_ids: Segment/Attention routing IDs for the sequence.
278-
pipeline_weights: The fully gathered pipeline weights explicitly passed via closure.
279275
280276
Returns:
281-
A function decorated with `nn.remat` that takes `(model, loop_state)` and returns
282-
the updated `loop_state`.
277+
A function that takes `(model, carry)` and returns the updated `carry` and `None` for the scan outputs.
283278
"""
284279

285-
def execute_pipeline_stage_outer(model, loop_state_and_bsw):
280+
def execute_pipeline_stage_flax(model, carry):
281+
"""
282+
A non-pure Flax closure of the pipeline stage.
283+
284+
This function bridges the pure JAX custom VJP logic with Flax's object-oriented
285+
lifting mechanisms. It unpacks the carry state and routes it through the pure VJP function.
286+
287+
Args:
288+
model: CircularPipeline Flax linen model instance.
289+
carry: A tuple containing (loop_state, w_curr, pipeline_weights).
290+
- loop_state: The current execution state of the pipeline.
291+
- w_curr: The gathered weights used for the current pipeline step.
292+
- pipeline_weights: The fully sharded baseline weights.
293+
"""
294+
295+
loop_state, w_curr, pipeline_weights = carry
286296

287297
scan_microbatches_fn = create_gradient_accumulation_scan(
288298
model=model,
@@ -292,71 +302,89 @@ def execute_pipeline_stage_outer(model, loop_state_and_bsw):
292302
logical_partition_spec=logical_partition_spec,
293303
)
294304

295-
remat_weight_prefetching = model.one_weight_prefetching
296-
305+
# Establish a pure function boundary to allow for custom VJP definition
297306
@jax.custom_vjp
298-
def execute_pipeline_stage(loop_state_and_bsw, pipeline_weights):
299-
return execute_pipeline_stage_custom_fwd(loop_state_and_bsw, pipeline_weights)[0]
307+
def execute_pipeline_stage_pure(loop_state, w_curr, pipeline_weights):
308+
return execute_pipeline_stage_pure_fwd(loop_state, w_curr, pipeline_weights)[0]
300309

301-
def execute_pipeline_stage_custom_fwd(loop_state_and_bsw, pipeline_weights):
302-
loop_state, w_curr = loop_state_and_bsw
303-
# # Retrieve the specific weights needed for this pipeline chunk
304-
w_next = remat_weight_prefetching(
310+
def execute_pipeline_stage_pure_fwd(loop_state, w_curr, pipeline_weights):
311+
# Prefetch FSDP-sharded weights for the upcoming pipeline repeat
312+
w_next = model.weight_prefetching(
305313
pipeline_weights,
306314
physical_partition_spec,
307315
loop_state["loop_iteration"],
308316
)
317+
# Construct a buffered sliding window (BSW) of weights.
318+
# w_curr: Weights actively used for the current microbatch steps.
319+
# w_next: Newly gathered weights that will be carried forward as the new w_curr.
309320
bsw = (w_curr, w_next)
310-
p_remat_weight_prefetching = functools.partial(
311-
remat_weight_prefetching,
321+
# Bind arguments to the weight prefetching function to prepare it for linear transpose
322+
p_weight_prefetching = functools.partial(
323+
model.weight_prefetching,
312324
physical_partition_spec=physical_partition_spec,
313325
loop_iteration=loop_state["loop_iteration"],
314326
)
315-
remat_weight_prefetching_t = jax.linear_transpose(
316-
p_remat_weight_prefetching,
327+
# Since weight gathering (all-gather) is a linear operation, we can derive its dual
328+
# (reduce-scatter) via jax.linear_transpose. This avoids redundant forward passes
329+
weight_prefetching_t = jax.linear_transpose(
330+
p_weight_prefetching,
317331
pipeline_weights,
318332
)
319-
(loop_state, bsw), scan_fn_vjp = jax.vjp(scan_microbatches_fn, loop_state, bsw, positions, segment_ids)
320-
w_curr, w_next = bsw
321-
return (loop_state, w_next), (scan_fn_vjp, remat_weight_prefetching_t)
322-
323-
def execute_pipeline_stage_custom_bwd(residuals, g_outputs):
333+
# Execute the forward pass of the microbatches and generate its VJP.
334+
# The VJP captures necessary checkpoints to evaluate gradients later.
335+
(loop_state, bsw), scan_microbatches_vjp = jax.vjp(scan_microbatches_fn, loop_state, bsw, positions, segment_ids)
336+
# Discard the old weights (w_curr) and advance w_next to act as the current weights in the next iteration
337+
_, w_next = bsw
338+
return (loop_state, w_next), (scan_microbatches_vjp, weight_prefetching_t)
339+
340+
def execute_pipeline_stage_pure_bwd(residuals, g_outputs):
341+
# Unpack forward pass residuals (VJP closures) and the incoming output gradients
324342
g_loop_state, g_w_next = g_outputs
325-
scan_fn_vjp, remat_weight_prefetching_t = residuals
343+
scan_microbatches_vjp, weight_prefetching_t = residuals
344+
# Initialize zero cotangents for w_curr, as it was consumed in the forward pass
326345
g_w_curr = jax.tree.map(jnp.zeros_like, g_w_next)
327346
g_bsw = (g_w_curr, g_w_next)
328-
g_loop_state, g_bsw, _, _ = scan_fn_vjp((g_loop_state, g_bsw))
347+
# Backpropagate gradients through the dual microbatch execution block
348+
g_loop_state, g_bsw, _, _ = scan_microbatches_vjp((g_loop_state, g_bsw))
349+
# Apply the linear transpose of the weight prefetch to execute the reduce-scatter
350+
# This maps the gradients of the gathered weights back to the FSDP-sharded parameter space
329351
g_w_curr, g_w_next = g_bsw
330-
(g_pipeline_weights,) = remat_weight_prefetching_t(g_w_next)
331-
return (g_loop_state, g_w_curr), g_pipeline_weights
332-
333-
execute_pipeline_stage.defvjp(execute_pipeline_stage_custom_fwd, execute_pipeline_stage_custom_bwd)
352+
(g_pipeline_weights,) = weight_prefetching_t(g_w_next)
353+
# Return gradients corresponding to the three original inputs of execute_pipeline_stage_pure
354+
return g_loop_state, g_w_curr, g_pipeline_weights
334355

335-
return execute_pipeline_stage(loop_state_and_bsw, pipeline_weights), None
356+
execute_pipeline_stage_pure.defvjp(execute_pipeline_stage_pure_fwd, execute_pipeline_stage_pure_bwd)
357+
# Execute the pure pipeline stage. We unpack the two modified outputs (loop_state, w_next)
358+
# and repack them alongside the unmodified pipeline_weights to maintain a consistent carry shape for nn.scan.
359+
return (*execute_pipeline_stage_pure(loop_state, w_curr, pipeline_weights), pipeline_weights), None
336360

337-
return execute_pipeline_stage_outer
361+
return execute_pipeline_stage_flax
338362

339363

340-
def create_flax_pipeline_scan(pipeline_stage_fn, length, use_scan=True):
341-
"""Wraps the pipeline stage execution in a `flax.linen.scan`.
364+
def create_flax_pipeline_scan(pipeline_stage_fn, length, remat_policy, use_scan=True):
365+
"""Wraps the pipeline stage execution in `flax.linen.remat` and `flax.linen.scan`.
342366
343-
This lifts the pipeline stage function so it can be repeated sequentially over
344-
the specified length. It safely handles Flax-specific state collections, ensuring
345-
that metrics, intermediate values, and PRNG keys do not collide or overwrite
346-
each other across the loop iterations.
367+
This explicitly wraps the pipeline step in a gradient checkpointing policy
368+
and then lifts it so it can be repeated sequentially over the specified length.
369+
It safely handles Flax-specific state collections, ensuring that metrics, intermediate
370+
values, and PRNG keys do not collide or overwrite each other across loop iterations.
347371
348372
Args:
349373
pipeline_stage_fn: The function representing a single pipeline stage
350-
(usually created by `create_rematerialized_pipeline_stage`).
374+
(usually created by `create_pipeline_stage`).
375+
remat_policy: The checkpointing policy used by `nn.remat` to manage activation memory.
351376
length: The total number of pipeline stages/repeats to scan over.
352-
use_scan: Either scan over repeats or unroll the scan.
377+
use_scan: Whether to use `jax.lax.scan` (True) or unroll the loop (False).
353378
354379
Returns:
355380
A Flax scanned function that executes the full pipeline schedule.
356381
"""
357382
unroll_length = 1 if use_scan else length
358383
return nn.scan(
359-
pipeline_stage_fn,
384+
nn.remat(
385+
pipeline_stage_fn,
386+
policy=remat_policy,
387+
),
360388
variable_axes={
361389
"summaries": 0,
362390
"aux_loss": 0,

0 commit comments

Comments
 (0)