Skip to content

Commit 3296ec1

Browse files
committed
add unroll repeat scan
1 parent df09a7a commit 3296ec1

1 file changed

Lines changed: 8 additions & 5 deletions

File tree

src/maxtext/utils/pipeline_utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -311,11 +311,13 @@ def execute_pipeline_stage(model, loop_state_and_bsw):
311311
del w_curr
312312
return (loop_state, w_next), None
313313

314-
return nn.remat(
315-
execute_pipeline_stage,
316-
prevent_cse=not model.config.scan_pipeline_iterations,
317-
policy=model.get_pipeline_remat_policy(),
318-
)
314+
return execute_pipeline_stage
315+
316+
# return nn.remat(
317+
# execute_pipeline_stage,
318+
# prevent_cse=not model.config.scan_pipeline_iterations,
319+
# policy=model.get_pipeline_remat_policy(),
320+
# )
319321

320322

321323
def create_flax_pipeline_scan(pipeline_stage_fn, length):
@@ -344,4 +346,5 @@ def create_flax_pipeline_scan(pipeline_stage_fn, length):
344346
},
345347
split_rngs={"random": True},
346348
length=length,
349+
unroll=length,
347350
)

0 commit comments

Comments
 (0)