Skip to content

Commit 08e4267

Browse files
committed
add lax.cond and dtype control
1 parent 09ee677 commit 08e4267

2 files changed

Lines changed: 9 additions & 7 deletions

File tree

src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ logical_axis_rules: [
     ['activation_attn_length', ['context', 'expert']],
     ['activation_q_length', ['context', 'expert']],
     ['activation_attn_embed', ['tensor']],
+    ['activation_norm_length', ['context']],
+    ['activation_norm_length_moe', ['context']],
     ['activation_embed', ['tensor']],
     ['activation_embed_moe', ['tensor']],
     ['activation_mlp', ['tensor']],

src/maxtext/utils/pipeline_utils.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def create_gradient_accumulation_scan(
     A JAX custom_vjp function that executes the `length` pipeline iterations.
   """

-  @functools.partial(jax.custom_vjp)
+  @jax.custom_vjp
   def run_single_microbatch_custom(lightweight_state, bsw, pos_arg, seg_arg):
     return run_single_microbatch_custom_fwd(lightweight_state, bsw, pos_arg, seg_arg)[0]

@@ -203,10 +203,10 @@ def _run(l, b):
       out = model.run_one_iteration(
           l, b, pos_arg, seg_arg, deterministic, model_mode, logical_partition_spec=logical_partition_spec
       )
-      return out, b
+      return out

     # Rematerialize the inner step to save activation memory
-    _run_remat = jax.remat(_run, prevent_cse=False, policy=model.get_pipeline_remat_policy())
+    _run_remat = jax.remat(_run, policy=model.get_pipeline_remat_policy())
     out, vjp_fun = jax.vjp(_run_remat, lightweight_state, bsw)
     return out, vjp_fun

@@ -217,14 +217,14 @@ def run_single_microbatch_custom_bwd(res, g_out):

   run_single_microbatch_custom.defvjp(run_single_microbatch_custom_fwd, run_single_microbatch_custom_bwd)

-  @functools.partial(jax.custom_vjp)
+  @jax.custom_vjp
   def run_pipeline_microbatches_custom(loop_state, bsw, positions, segment_ids):
     return run_pipeline_microbatches_custom_fwd(loop_state, bsw, positions, segment_ids)[0]

   def run_pipeline_microbatches_custom_fwd(loop_state, bsw, positions, segment_ids):
     final_lightweight, scan_vjp_fun = jax.vjp(
         lambda l, b: jax.lax.scan(
-            lambda carry, _: (run_single_microbatch_custom(carry, b, positions, segment_ids)[0], None),
+            lambda carry, _: (run_single_microbatch_custom(carry, b, positions, segment_ids), None),
             l,
             None,
             length=length,
@@ -332,9 +332,8 @@ def execute_pipeline_stage_pure_fwd(loop_state, w_curr, pipeline_weights):
       )
       # Execute the forward pass of the microbatches and generate its VJP.
       # The VJP captures necessary checkpoints to evaluate gradients later.
-      (loop_state, bsw), scan_microbatches_vjp = jax.vjp(scan_microbatches_fn, loop_state, bsw, positions, segment_ids)
+      (loop_state, _), scan_microbatches_vjp = jax.vjp(scan_microbatches_fn, loop_state, bsw, positions, segment_ids)
       # Discard the old weights (w_curr) and advance w_next to act as the current weights in the next iteration
-      _, w_next = bsw
       return (loop_state, w_next), (scan_microbatches_vjp, weight_prefetching_t)

     def execute_pipeline_stage_pure_bwd(residuals, g_outputs):
@@ -354,6 +353,7 @@ def execute_pipeline_stage_pure_bwd(residuals, g_outputs):
       return g_loop_state, g_w_curr, g_pipeline_weights

   execute_pipeline_stage_pure.defvjp(execute_pipeline_stage_pure_fwd, execute_pipeline_stage_pure_bwd)
+
   # Execute the pure pipeline stage. We unpack the two modified outputs (loop_state, w_next)
   # and repack them alongside the unmodified pipeline_weights to maintain a consistent carry shape for nn.scan.
   return (*execute_pipeline_stage_pure(loop_state, w_curr, pipeline_weights), pipeline_weights), None

0 commit comments

Comments (0)