1- # Copyright 2023–2025 Google LLC
1+ # Copyright 2023–2026 Google LLC
22#
33# Licensed under the Apache License, Version 2.0 (the "License");
44# you may not use this file except in compliance with the License.
@@ -53,13 +53,9 @@ def checkpoint_loop(config, state=None):
5353 """
5454 model = from_config (config )
5555 mesh = model .mesh
56- init_rng , checkpoint_manager , _ , tx = train_utils .create_training_tools (
57- config , model , mesh
58- )
56+ init_rng , checkpoint_manager , _ , tx = train_utils .create_training_tools (config , model , mesh )
5957
60- unboxed_abstract_state , _ , _ = maxtext_utils .get_abstract_state (
61- model , tx , config , init_rng , mesh , is_training = True
62- )
58+ unboxed_abstract_state , _ , _ = maxtext_utils .get_abstract_state (model , tx , config , init_rng , mesh , is_training = True )
6359 # A barrier to sync all hosts before starting to restore checkpoint
6460 jax .experimental .multihost_utils .sync_global_devices ("Barrier before load" )
6561 checkpoint_load_start = datetime .datetime .now ()
@@ -82,30 +78,24 @@ def checkpoint_loop(config, state=None):
8278 if state is not None : # Checkpoint was available for restore
8379 if jax .process_index () == 0 :
8480 max_logging .log (
85- "STANDALONE CHECKPOINTER : Checkpoint restored in :"
86- f" { checkpoint_load_end - checkpoint_load_start } "
81+ "STANDALONE CHECKPOINTER : Checkpoint restored in :" f" { checkpoint_load_end - checkpoint_load_start } "
8782 )
8883 else : # Checkpoint was unavailable, state needs to be initialized
89- state , _ , _ , _ = maxtext_utils .setup_training_state (
90- model , None , tx , config , init_rng , mesh , checkpoint_manager
91- )
84+ state , _ , _ , _ = maxtext_utils .setup_training_state (model , None , tx , config , init_rng , mesh , checkpoint_manager )
9285 state = add_entropy_to_checkpoint (state )
9386
9487 start_step = get_first_step (state ) # this is the start_step for training
9588 for step in np .arange (start_step , config .steps ):
9689 if checkpoint_manager is not None :
9790 start_time = datetime .datetime .now ()
9891 # A barrier to sync all hosts before starting to save checkpoint
99- jax .experimental .multihost_utils .sync_global_devices (
100- "Barrier before save"
101- )
92+ jax .experimental .multihost_utils .sync_global_devices ("Barrier before save" )
10293 if checkpointing .save_checkpoint (checkpoint_manager , int (step ), state ):
10394 checkpoint_manager .wait_until_finished ()
10495 end_time = datetime .datetime .now ()
10596 if jax .process_index () == 0 :
10697 max_logging .log (
107- "STANDALONE CHECKPOINTER : Checkpoint saved in"
108- f" { end_time - start_time } ,step { step } , on host 0"
98+ "STANDALONE CHECKPOINTER : Checkpoint saved in" f" { end_time - start_time } ,step { step } , on host 0"
10999 )
110100
111101 return state
@@ -123,12 +113,8 @@ def add_entropy_to_checkpoint(state):
123113 state: Returns state with entropy added to the optimizer state.
124114 """
125115 opt_0 = state .opt_state [0 ]
126- opt_0 = opt_0 ._replace (
127- mu = jax .tree_util .tree_map (lambda k : jnp .cos (1000 * k ), state .params )
128- )
129- opt_0 = opt_0 ._replace (
130- nu = jax .tree_util .tree_map (lambda k : jnp .sin (1000 * k ), state .params )
131- )
116+ opt_0 = opt_0 ._replace (mu = jax .tree_util .tree_map (lambda k : jnp .cos (1000 * k ), state .params ))
117+ opt_0 = opt_0 ._replace (nu = jax .tree_util .tree_map (lambda k : jnp .sin (1000 * k ), state .params ))
132118 new_opt = [opt_0 ] + list (state .opt_state [1 :])
133119 state = state .replace (opt_state = new_opt )
134120 return state
0 commit comments