@@ -398,6 +398,7 @@ def __init__(
398398 continuous_eval_timeout : int = 30 ,
399399 rng_seed : int = core .DEFAULT_RNG_SEED ,
400400 rng_impl : str | None = None ,
401+ enable_checkpointing : bool = True ,
401402 ):
402403 """Initializes the instance.
403404
@@ -436,6 +437,7 @@ def __init__(
436437 rng_impl: The implementation of the PRNG key. By default this is set to
437438 None which means that the default implementation (generally
438439 partitionable threefry) will be used.
440+ enable_checkpointing: Whether to enable checkpointing. Defaults to True.
439441 """
440442
441443 if not isinstance (steps_per_loop , int ) or steps_per_loop < 1 :
@@ -453,6 +455,7 @@ def __init__(
453455 self ._max_checkpoints_to_keep = max_checkpoints_to_keep
454456 self ._rng_impl = rng_impl
455457 self ._rng_seed = rng_seed
458+ self ._enable_checkpointing = enable_checkpointing
456459
457460 @functools .cached_property
458461 def checkpoint_manager (self ) -> ocp .CheckpointManager :
@@ -467,14 +470,19 @@ def checkpoint_manager(self) -> ocp.CheckpointManager:
467470 save_on_steps .append (self ._train_steps - 1 )
468471
469472 save_on_steps = set (save_on_steps )
470-
471- return ocp .CheckpointManager (
472- directory = os .path .join (self ._model_dir , core .CHECKPOINT_DIR ),
473- options = ocp .CheckpointManagerOptions (
474- should_save_fn = lambda step , _ : step in save_on_steps ,
475- max_to_keep = self ._max_checkpoints_to_keep ,
476- ),
477- )
473+
474+ if self ._enable_checkpointing :
475+
476+ return ocp .CheckpointManager (
477+ directory = os .path .join (self ._model_dir , core .CHECKPOINT_DIR ),
478+ options = ocp .CheckpointManagerOptions (
479+ should_save_fn = lambda step , _ : step in save_on_steps ,
480+ max_to_keep = self ._max_checkpoints_to_keep ,
481+ ),
482+ )
483+ else :
484+
485+ return None
478486
479487 @functools .cached_property
480488 def train_summary_writer (self ) -> metrics_tools .AsyncMultiWriter :
@@ -510,6 +518,9 @@ def _maybe_save_checkpoint(
510518 metrics : Mapping [str , Any ] | None = None ,
511519 ):
512520 """Saves a checkpoint and returns a bool indicating whether it was saved."""
521+ if not self ._enable_checkpointing :
522+ return
523+
513524 items = {core .STATE_CHECKPOINT_KEY : ocp .args .StandardSave (state )}
514525 with self .report_progress .timed ("checkpointing" ):
515526 self .checkpoint_manager .save (
@@ -564,7 +575,7 @@ def _train_n_steps(
564575 state , metrics_update = train_step (inputs , state )
565576 metrics_accum .accumulate (metrics_update , step )
566577 self .report_progress (step )
567- if step != start_step + num_steps - 1 :
578+ if ( step != start_step + num_steps - 1 ) and self . _enable_checkpointing :
568579 self ._maybe_save_checkpoint (step , state )
569580
570581 metrics = metrics_accum .compute_and_log_scalars (start_step + num_steps - 1 )
@@ -651,6 +662,7 @@ def _eval_step(
651662
652663 if (
653664 check_for_checkpoints
665+ and self ._enable_checkpointing
654666 and self .checkpoint_manager .latest_step () is not None
655667 ):
656668 step_to_resume_from = self .checkpoint_manager .latest_step ()
@@ -674,7 +686,7 @@ def _eval_step(
674686 def train (self , task : JaxTask ) -> core .Logs :
675687 """Trains the model."""
676688 train_iter , _ , state , train_step , _ , step = self .process_task (
677- task , training = True , check_for_checkpoints = True
689+ task , training = True , check_for_checkpoints = False
678690 )
679691
680692 logging .info (
@@ -698,25 +710,27 @@ def train(self, task: JaxTask) -> core.Logs:
698710 f" { _format_output (train_metrics )} "
699711 )
700712 metrics [core .TRAIN_LOG_DIRNAME ] = train_metrics
701-
702- self ._maybe_save_checkpoint (curr_step , state , metrics = metrics )
713+ if self . _enable_checkpointing :
714+ self ._maybe_save_checkpoint (curr_step , state , metrics = metrics )
703715 step = curr_step + 1
704716
705- self .checkpoint_manager .wait_until_finished ()
717+ if self ._enable_checkpointing :
718+ self .checkpoint_manager .wait_until_finished ()
706719
707720 if jax .process_index () == 0 :
708721 self ._write_marker_file ()
709722 task .export_model (state , self ._model_dir )
710723
711- self .checkpoint_manager .close ()
712- del self .checkpoint_manager
724+ if self ._enable_checkpointing :
725+ self .checkpoint_manager .close ()
726+ del self .checkpoint_manager
713727
714728 return metrics
715729
716730 def evaluate (self , task : JaxTask ) -> core .Logs :
717731 """Evaluates the model."""
718732 _ , eval_iters , state , _ , eval_step , step = self .process_task (
719- task , training = False , check_for_checkpoints = True
733+ task , training = False , check_for_checkpoints = False
720734 )
721735 eval_summary_writers = self ._create_eval_summary_writers (eval_iters )
722736
@@ -749,7 +763,7 @@ def evaluate(self, task: JaxTask) -> core.Logs:
749763 def train_and_evaluate (self , task : JaxTask ) -> core .Logs :
750764 """Trains and evaluates the model."""
751765 train_iter , eval_iters , state , train_step , eval_step , step = (
752- self .process_task (task , training = True , check_for_checkpoints = True )
766+ self .process_task (task , training = True , check_for_checkpoints = False )
753767 )
754768 eval_summary_writers = self ._create_eval_summary_writers (eval_iters )
755769
@@ -794,18 +808,20 @@ def train_and_evaluate(self, task: JaxTask) -> core.Logs:
794808 f" { _format_output (eval_metrics )} "
795809 )
796810 metrics [_val_logdir (key )] = eval_metrics
797-
798- self ._maybe_save_checkpoint (curr_step , state , metrics = metrics )
811+ if self . _enable_checkpointing :
812+ self ._maybe_save_checkpoint (curr_step , state , metrics = metrics )
799813 step = curr_step + 1
800814
801- self .checkpoint_manager .wait_until_finished ()
815+ if self ._enable_checkpointing :
816+ self .checkpoint_manager .wait_until_finished ()
802817
803818 if jax .process_index () == 0 :
804819 self ._write_marker_file ()
805820 task .export_model (state , self ._model_dir )
806821
807- self .checkpoint_manager .close ()
808- del self .checkpoint_manager
822+ if self ._enable_checkpointing :
823+ self .checkpoint_manager .close ()
824+ del self .checkpoint_manager
809825
810826 return metrics
811827
@@ -833,7 +849,8 @@ def timeout_fn() -> bool:
833849 timeout_fn = timeout_fn ,
834850 ):
835851 try :
836- state = self ._maybe_restore_checkpoint (state , step )
852+ if self ._enable_checkpointing :
853+ state = self ._maybe_restore_checkpoint (state , step )
837854 logging .info (f"eval | step: { step : 6d} | { steps_msg } " )
838855 with self .report_progress .timed ("eval" ):
839856 for key , eval_iter in eval_iters .items ():
@@ -930,3 +947,4 @@ def _format_output(output: Any, indent: int = 4, width: int = 80) -> str:
930947 return formatted
931948 lines = [" " * indent + line for line in lines ]
932949 return "\n " + "\n " .join (lines )
950+
0 commit comments