1919
2020import jax
2121import numpy as np
22- from maxdiffusion .checkpointing .checkpointing_utils import (create_orbax_checkpoint_manager , load_params_from_path )
22+ from maxdiffusion .checkpointing .checkpointing_utils import (create_orbax_checkpoint_manager )
2323from ..pipelines .wan .wan_pipeline import WanPipeline
2424from .. import max_logging , max_utils
2525import orbax .checkpoint as ocp
@@ -57,18 +57,16 @@ def load_wan_configs_from_orbax(self, step):
5757 return None
5858 max_logging .log (f"Loading WAN checkpoint from step { step } " )
5959 metadatas = self .checkpoint_manager .item_metadata (step )
60-
60+
6161 transformer_metadata = metadatas .wan_state
62- abstract_tree_structure_params = jax .tree_util .tree_map (
63- ocp .utils .to_shape_dtype_struct , transformer_metadata
64- )
62+ abstract_tree_structure_params = jax .tree_util .tree_map (ocp .utils .to_shape_dtype_struct , transformer_metadata )
6563 params_restore = ocp .args .PyTreeRestore (
6664 restore_args = jax .tree .map (
6765 lambda _ : ocp .RestoreArgs (restore_type = np .ndarray ),
6866 abstract_tree_structure_params ,
6967 )
7068 )
71-
69+
7270 max_logging .log ("Restoring WAN checkpoint" )
7371 restored_checkpoint = self .checkpoint_manager .restore (
7472 directory = epath .Path (self .config .checkpoint_dir ),
@@ -77,7 +75,7 @@ def load_wan_configs_from_orbax(self, step):
7775 wan_state = params_restore ,
7876 # wan_state=params_restore_util_way,
7977 wan_config = ocp .args .JsonRestore (),
80- ),
78+ ),
8179 )
8280 return restored_checkpoint
8381
@@ -96,14 +94,16 @@ def load_checkpoint(self, step=None):
9694 pipeline = self .load_diffusers_checkpoint ()
9795
9896 return pipeline
99-
97+
def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict):
    """Persist the training state and the transformer configuration for `train_step`.

    Saves two items through the Orbax checkpoint manager:
      - "wan_config": the transformer config, round-tripped to a plain JSON dict.
      - "wan_state": the training state pytree.
    """

    def config_to_json(model_or_config):
        # Round-trip through the model's own JSON serialization to obtain a
        # plain, JSON-safe dict.
        return json.loads(model_or_config.to_json_string())

    max_logging.log(f"Saving checkpoint for step {train_step}")
    items = {
        "wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
        "wan_state": ocp.args.PyTreeSave(train_states),
    }
    self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
    max_logging.log(f"Checkpoint for step {train_step} saved.")
114114
def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict):
    """Saves the training state and model configurations (legacy item names).

    Unlike `save_checkpoint`, this variant stores the weights under
    "transformer_states" and a hand-filtered, JSON-safe subset of the
    transformer config under "transformer_config".
    """

    def config_to_json(model_or_config):
        """Return only the config entries that are needed and JSON-serializable."""
        if not hasattr(model_or_config, "config"):
            return None
        source_config = dict(model_or_config.config)

        # 1. configs that can be serialized to JSON
        SAFE_KEYS = [
            "_class_name",
            "_diffusers_version",
            "model_type",
            "patch_size",
            "num_attention_heads",
            "attention_head_dim",
            "in_channels",
            "out_channels",
            "text_dim",
            "freq_dim",
            "ffn_dim",
            "num_layers",
            "cross_attn_norm",
            "qk_norm",
            "eps",
            "image_dim",
            "added_kv_proj_dim",
            "rope_max_seq_len",
            "pos_embed_seq_len",
            "flash_min_seq_length",
            "flash_block_sizes",
            "attention",
            "_use_default_values",
        ]

        # 2. keep only the entries that are in the SAFE_KEYS list
        clean_config = {key: source_config[key] for key in SAFE_KEYS if key in source_config}

        # 3. deal with special data types and precision: store dtypes by name
        if "dtype" in source_config and hasattr(source_config["dtype"], "name"):
            clean_config["dtype"] = source_config["dtype"].name  # e.g. 'bfloat16'

        if "weights_dtype" in source_config and hasattr(source_config["weights_dtype"], "name"):
            clean_config["weights_dtype"] = source_config["weights_dtype"].name

        # BUG FIX: isinstance() was called with a single argument
        # (`isinstance(source_config["precision"])`), which raises TypeError
        # whenever a 'precision' key is present. Restore the intended check
        # against the jax.lax.Precision enum and store the member's name.
        if "precision" in source_config and isinstance(source_config["precision"], jax.lax.Precision):
            clean_config["precision"] = source_config["precision"].name  # e.g. 'HIGHEST'

        return clean_config

    items_to_save = {
        "transformer_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
    }

    items_to_save["transformer_states"] = ocp.args.PyTreeSave(train_states)

    # Create CompositeArgs for Orbax
    save_args = ocp.args.Composite(**items_to_save)

    # Save the checkpoint
    self.checkpoint_manager.save(train_step, args=save_args)
    max_logging.log(f"Checkpoint for step {train_step} saved.")
0 commit comments