@@ -194,7 +194,7 @@ def create_gradient_accumulation_scan(
     A JAX custom_vjp function that executes the `length` pipeline iterations.
   """

-  @functools.partial(jax.custom_vjp)
+  @jax.custom_vjp
   def run_single_microbatch_custom(lightweight_state, bsw, pos_arg, seg_arg):
     return run_single_microbatch_custom_fwd(lightweight_state, bsw, pos_arg, seg_arg)[0]
@@ -203,10 +203,10 @@ def _run(l, b):
       out = model.run_one_iteration(
           l, b, pos_arg, seg_arg, deterministic, model_mode, logical_partition_spec=logical_partition_spec
       )
-      return out, b
+      return out

     # Rematerialize the inner step to save activation memory
-    _run_remat = jax.remat(_run, prevent_cse=False, policy=model.get_pipeline_remat_policy())
+    _run_remat = jax.remat(_run, policy=model.get_pipeline_remat_policy())
     out, vjp_fun = jax.vjp(_run_remat, lightweight_state, bsw)
     return out, vjp_fun
@@ -217,14 +217,14 @@ def run_single_microbatch_custom_bwd(res, g_out):

   run_single_microbatch_custom.defvjp(run_single_microbatch_custom_fwd, run_single_microbatch_custom_bwd)

-  @functools.partial(jax.custom_vjp)
+  @jax.custom_vjp
   def run_pipeline_microbatches_custom(loop_state, bsw, positions, segment_ids):
     return run_pipeline_microbatches_custom_fwd(loop_state, bsw, positions, segment_ids)[0]

   def run_pipeline_microbatches_custom_fwd(loop_state, bsw, positions, segment_ids):
     final_lightweight, scan_vjp_fun = jax.vjp(
         lambda l, b: jax.lax.scan(
-            lambda carry, _: (run_single_microbatch_custom(carry, b, positions, segment_ids)[0], None),
+            lambda carry, _: (run_single_microbatch_custom(carry, b, positions, segment_ids), None),
             l,
             None,
             length=length,
@@ -332,9 +332,8 @@ def execute_pipeline_stage_pure_fwd(loop_state, w_curr, pipeline_weights):
     )
     # Execute the forward pass of the microbatches and generate its VJP.
     # The VJP captures necessary checkpoints to evaluate gradients later.
-    (loop_state, bsw), scan_microbatches_vjp = jax.vjp(scan_microbatches_fn, loop_state, bsw, positions, segment_ids)
+    (loop_state, _), scan_microbatches_vjp = jax.vjp(scan_microbatches_fn, loop_state, bsw, positions, segment_ids)
     # Discard the old weights (w_curr) and advance w_next to act as the current weights in the next iteration
-    _, w_next = bsw
     return (loop_state, w_next), (scan_microbatches_vjp, weight_prefetching_t)
339338
340339 def execute_pipeline_stage_pure_bwd (residuals , g_outputs ):
@@ -354,6 +353,7 @@ def execute_pipeline_stage_pure_bwd(residuals, g_outputs):
354353 return g_loop_state , g_w_curr , g_pipeline_weights
355354
356355 execute_pipeline_stage_pure .defvjp (execute_pipeline_stage_pure_fwd , execute_pipeline_stage_pure_bwd )
356+
357357 # Execute the pure pipeline stage. We unpack the two modified outputs (loop_state, w_next)
358358 # and repack them alongside the unmodified pipeline_weights to maintain a consistent carry shape for nn.scan.
359359 return (* execute_pipeline_stage_pure (loop_state , w_curr , pipeline_weights ), pipeline_weights ), None
0 commit comments