
Commit 9354355

Formatting
Signed-off-by: Kunjan Patel <kunjan@ucla.edu>
1 parent db66db1 commit 9354355

5 files changed: 98 additions & 76 deletions

Note: this is a formatting pass, so several hunks below differ only in whitespace; those appear as paired -/+ lines with identical visible text.


src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 4 additions & 8 deletions
@@ -17,11 +17,11 @@
 
 """Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""
 
-from typing import Optional, Tuple
+from typing import Optional, Tuple
 import jax
 import numpy as np
 import os
-from jaxtyping import PyTree
+from jaxtyping import PyTree
 import orbax.checkpoint
 from maxdiffusion import max_logging
 from etils import epath
@@ -137,7 +137,7 @@ def load_params_from_path(
     unboxed_abstract_params,
     checkpoint_item: str,
     step: Optional[int] = None,
-    checkpoint_item_config: Optional[str] = None
+    checkpoint_item_config: Optional[str] = None,
 ):
   ckptr = ocp.PyTreeCheckpointer()
 
@@ -153,11 +153,7 @@ def load_params_from_path(
 
   restore_args = ocp.checkpoint_utils.construct_restore_args(unboxed_abstract_params)
   restored = ckptr.restore(
-      ckpt_path,
-      item={"params": unboxed_abstract_params},
-      transforms={},
-      restore_args={
-          "params": restore_args}
+      ckpt_path, item={"params": unboxed_abstract_params}, transforms={}, restore_args={"params": restore_args}
   )
   return restored["params"]
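For context, the collapsed `restore(...)` call is Orbax's PyTree restore pattern: build restore args from an abstract params tree, then restore a `{"params": ...}` item. A minimal runnable sketch of that pattern, with an invented checkpoint path and params tree:

import tempfile

import jax
import jax.numpy as jnp
import orbax.checkpoint as ocp

# Hypothetical checkpoint location and a tiny params tree, for illustration.
ckpt_path = tempfile.mkdtemp() + "/example_ckpt"
params = {"dense": {"kernel": jnp.ones((4, 8))}}

ckptr = ocp.PyTreeCheckpointer()
ckptr.save(ckpt_path, {"params": params})  # write something to restore

# An abstract tree carries structure/shape/dtype only; restore_args are
# derived from it, as in load_params_from_path above.
abstract_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, params)
restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_params)
restored = ckptr.restore(
    ckpt_path, item={"params": abstract_params}, transforms={}, restore_args={"params": restore_args}
)
assert restored["params"]["dense"]["kernel"].shape == (4, 8)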

src/maxdiffusion/checkpointing/wan_checkpointer.py

Lines changed: 77 additions & 59 deletions
@@ -19,7 +19,7 @@
 
 import jax
 import numpy as np
-from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager, load_params_from_path)
+from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
 from ..pipelines.wan.wan_pipeline import WanPipeline
 from .. import max_logging, max_utils
 import orbax.checkpoint as ocp
@@ -57,18 +57,16 @@ def load_wan_configs_from_orbax(self, step):
       return None
     max_logging.log(f"Loading WAN checkpoint from step {step}")
     metadatas = self.checkpoint_manager.item_metadata(step)
-
+
     transformer_metadata = metadatas.wan_state
-    abstract_tree_structure_params = jax.tree_util.tree_map(
-        ocp.utils.to_shape_dtype_struct, transformer_metadata
-    )
+    abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata)
     params_restore = ocp.args.PyTreeRestore(
         restore_args=jax.tree.map(
             lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
             abstract_tree_structure_params,
         )
     )
-
+
     max_logging.log("Restoring WAN checkpoint")
     restored_checkpoint = self.checkpoint_manager.restore(
         directory=epath.Path(self.config.checkpoint_dir),
@@ -77,7 +75,7 @@ def load_wan_configs_from_orbax(self, step):
             wan_state=params_restore,
             # wan_state=params_restore_util_way,
             wan_config=ocp.args.JsonRestore(),
-        ),
+        ),
     )
     return restored_checkpoint
 
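The hunk above restores a composite checkpoint: a `PyTreeRestore` for the transformer state (forced back to host `np.ndarray`s) alongside a `JsonRestore` for the config. A minimal sketch of the same round trip, with an invented directory, step, and contents; only the item names mirror the diff:

import tempfile

import jax
import numpy as np
import orbax.checkpoint as ocp

ckpt_dir = tempfile.mkdtemp()  # hypothetical checkpoint root
mngr = ocp.CheckpointManager(ckpt_dir)

state = {"w": np.ones((2, 3), dtype=np.float32)}
config = {"num_layers": 2}

# Save weights as a PyTree and the config as JSON under one step.
mngr.save(0, args=ocp.args.Composite(
    wan_state=ocp.args.PyTreeSave(state),
    wan_config=ocp.args.JsonSave(config),
))
mngr.wait_until_finished()

# restore_type=np.ndarray pulls weights back as host numpy arrays rather
# than device arrays, matching the RestoreArgs in the hunk above.
restored = mngr.restore(0, args=ocp.args.Composite(
    wan_state=ocp.args.PyTreeRestore(
        restore_args=jax.tree.map(lambda _: ocp.RestoreArgs(restore_type=np.ndarray), state)
    ),
    wan_config=ocp.args.JsonRestore(),
))
assert restored.wan_config == {"num_layers": 2}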

@@ -96,14 +94,16 @@ def load_checkpoint(self, step=None):
       pipeline = self.load_diffusers_checkpoint()
 
     return pipeline
-
+
   def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict):
     """Saves the training state and model configurations."""
+
     def config_to_json(model_or_config):
       return json.loads(model_or_config.to_json_string())
+
     max_logging.log(f"Saving checkpoint for step {train_step}")
     items = {
-      "wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
+        "wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
     }
 
     items["wan_state"] = ocp.args.PyTreeSave(train_states)
@@ -112,54 +112,72 @@ def config_to_json(model_or_config):
     self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
     max_logging.log(f"Checkpoint for step {train_step} saved.")
 
-  def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict):
-    """Saves the training state and model configurations."""
-    def config_to_json(model_or_config):
-      """
-      only save the config that is needed and can be serialized to JSON.
-      """
-      if not hasattr(model_or_config, "config"):
-        return None
-      source_config = dict(model_or_config.config)
-
-      # 1. configs that can be serialized to JSON
-      SAFE_KEYS = [
-          '_class_name', '_diffusers_version', 'model_type', 'patch_size',
-          'num_attention_heads', 'attention_head_dim', 'in_channels',
-          'out_channels', 'text_dim', 'freq_dim', 'ffn_dim', 'num_layers',
-          'cross_attn_norm', 'qk_norm', 'eps', 'image_dim',
-          'added_kv_proj_dim', 'rope_max_seq_len', 'pos_embed_seq_len',
-          'flash_min_seq_length', 'flash_block_sizes', 'attention',
-          '_use_default_values'
-      ]
-
-      # 2. save the config that are in the SAFE_KEYS list
-      clean_config = {}
-      for key in SAFE_KEYS:
-        if key in source_config:
-          clean_config[key] = source_config[key]
-
-      # 3. deal with special data type and precision
-      if 'dtype' in source_config and hasattr(source_config['dtype'], 'name'):
-        clean_config['dtype'] = source_config['dtype'].name  # e.g 'bfloat16'
-
-      if 'weights_dtype' in source_config and hasattr(source_config['weights_dtype'], 'name'):
-        clean_config['weights_dtype'] = source_config['weights_dtype'].name
-
-      if 'precision' in source_config and isinstance(source_config['precision'], Precision):
-        clean_config['precision'] = source_config['precision'].name  # e.g. 'HIGHEST'
-
-      return clean_config
-
-    items_to_save = {
-        "transformer_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
-    }
-
-    items_to_save["transformer_states"] = ocp.args.PyTreeSave(train_states)
-
-    # Create CompositeArgs for Orbax
-    save_args = ocp.args.Composite(**items_to_save)
 
-    # Save the checkpoint
-    self.checkpoint_manager.save(train_step, args=save_args)
-    max_logging.log(f"Checkpoint for step {train_step} saved.")
+  def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict):
+    """Saves the training state and model configurations."""
+
+    def config_to_json(model_or_config):
+      """
+      only save the config that is needed and can be serialized to JSON.
+      """
+      if not hasattr(model_or_config, "config"):
+        return None
+      source_config = dict(model_or_config.config)
+
+      # 1. configs that can be serialized to JSON
+      SAFE_KEYS = [
+          "_class_name",
+          "_diffusers_version",
+          "model_type",
+          "patch_size",
+          "num_attention_heads",
+          "attention_head_dim",
+          "in_channels",
+          "out_channels",
+          "text_dim",
+          "freq_dim",
+          "ffn_dim",
+          "num_layers",
+          "cross_attn_norm",
+          "qk_norm",
+          "eps",
+          "image_dim",
+          "added_kv_proj_dim",
+          "rope_max_seq_len",
+          "pos_embed_seq_len",
+          "flash_min_seq_length",
+          "flash_block_sizes",
+          "attention",
+          "_use_default_values",
+      ]
+
+      # 2. save the config that are in the SAFE_KEYS list
+      clean_config = {}
+      for key in SAFE_KEYS:
+        if key in source_config:
+          clean_config[key] = source_config[key]
+
+      # 3. deal with special data type and precision
+      if "dtype" in source_config and hasattr(source_config["dtype"], "name"):
+        clean_config["dtype"] = source_config["dtype"].name  # e.g 'bfloat16'
+
+      if "weights_dtype" in source_config and hasattr(source_config["weights_dtype"], "name"):
+        clean_config["weights_dtype"] = source_config["weights_dtype"].name
+
+      if "precision" in source_config and isinstance(source_config["precision"], Precision):
+        clean_config["precision"] = source_config["precision"].name  # e.g. 'HIGHEST'
+
+      return clean_config
+
+    items_to_save = {
+        "transformer_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
+    }
+
+    items_to_save["transformer_states"] = ocp.args.PyTreeSave(train_states)
+
+    # Create CompositeArgs for Orbax
+    save_args = ocp.args.Composite(**items_to_save)
+
+    # Save the checkpoint
+    self.checkpoint_manager.save(train_step, args=save_args)
+    max_logging.log(f"Checkpoint for step {train_step} saved.")
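The reformatted `save_checkpoint_orig` keeps only whitelisted, JSON-serializable config entries and converts enum-like values to their names before saving. A condensed sketch of that sanitization, with an invented config (`Precision` is assumed to be `jax.lax.Precision`, and the whitelist is abbreviated):

import json

import jax.numpy as jnp
import numpy as np
from jax.lax import Precision

SAFE_KEYS = ["num_layers", "patch_size"]  # abbreviated whitelist, for illustration

def sanitize_config(source_config: dict) -> dict:
  # Keep only whitelisted keys, then replace enum-like values with names.
  clean = {k: source_config[k] for k in SAFE_KEYS if k in source_config}
  if "dtype" in source_config and hasattr(source_config["dtype"], "name"):
    clean["dtype"] = source_config["dtype"].name  # e.g. "bfloat16"
  if "precision" in source_config and isinstance(source_config["precision"], Precision):
    clean["precision"] = source_config["precision"].name  # e.g. "HIGHEST"
  return clean

cfg = {"num_layers": 2, "dtype": np.dtype(jnp.bfloat16), "precision": Precision.HIGHEST}
print(json.dumps(sanitize_config(cfg)))  # every value is now JSON-serializable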

src/maxdiffusion/generate_wan.py

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@
 def run(config, pipeline=None, filename_prefix=""):
   print("seed: ", config.seed)
   from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
+
   checkpoint_loader = WanCheckpointer(config, "WAN_CHECKPOINT")
   pipeline = checkpoint_loader.load_checkpoint()
   if pipeline is None:
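For orientation, the surrounding code follows a load-or-fallback pattern. A sketch under stated assumptions: only the loader construction and the `if pipeline is None:` check are visible in this diff, and the fallback call is an assumption modeled on `load_diffusers_checkpoint` from the checkpointer above:

from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer

def load_pipeline(config):
  # Try the local Orbax checkpoint first.
  checkpoint_loader = WanCheckpointer(config, "WAN_CHECKPOINT")
  pipeline = checkpoint_loader.load_checkpoint()
  if pipeline is None:
    # Hypothetical fallback: load published diffusers weights instead.
    pipeline = checkpoint_loader.load_diffusers_checkpoint()
  return pipeline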

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 14 additions & 6 deletions
@@ -66,7 +66,9 @@ def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.Variabl
 
 
 # For some reason, jitting this function increases the memory significantly, so instead manually move weights to device.
-def create_sharded_logical_transformer(devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None):
+def create_sharded_logical_transformer(
+    devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None
+):
 
   def create_model(rngs: nnx.Rngs, wan_config: dict):
     wan_transformer = WanModel(**wan_config, rngs=rngs)
@@ -110,7 +112,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
   )
   params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
   for path, val in flax.traverse_util.flatten_dict(params).items():
-    if restored_checkpoint:
+    if restored_checkpoint:
       path = path[:-1]
     sharding = logical_state_sharding[path].value
     state[path].value = device_put_replicated(val, sharding)
@@ -303,9 +305,13 @@ def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline
     return quantized_model
 
   @classmethod
-  def load_transformer(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None):
+  def load_transformer(
+      cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None
+  ):
     with mesh:
-      wan_transformer = create_sharded_logical_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint)
+      wan_transformer = create_sharded_logical_transformer(
+          devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint
+      )
     return wan_transformer
 
   @classmethod
@@ -331,7 +337,9 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_
     if not vae_only:
       if load_transformer:
         with mesh:
-          transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint)
+          transformer = cls.load_transformer(
+              devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint
+          )
 
     text_encoder = cls.load_text_encoder(config=config)
     tokenizer = cls.load_tokenizer(config=config)
@@ -353,7 +361,7 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_
         mesh=mesh,
         config=config,
     )
-
+
   @classmethod
   def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True):
     devices_array = max_utils.create_device_mesh(config)
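The rewrapped signatures above all feed the manual device-placement path: per the comment, the state-creation function is deliberately not jitted, and weights are cast and placed explicitly instead. A minimal sketch of explicit placement with an explicit sharding, assuming a single "data" mesh axis and an invented params tree (the repo's `device_put_replicated` helper is not shown in this diff; plain `jax.device_put` stands in for it):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D mesh over all available devices.
devices_array = np.array(jax.devices())
mesh = Mesh(devices_array, axis_names=("data",))

params = {"proj": {"kernel": jnp.ones((8, 16))}}  # hypothetical weights

# An empty PartitionSpec replicates each leaf across the mesh; placing
# weights this way avoids staging the whole creation function through jit.
replicated = NamedSharding(mesh, PartitionSpec())
params = jax.tree_util.tree_map(lambda x: jax.device_put(x, replicated), params)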

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 2 additions & 3 deletions
@@ -149,8 +149,7 @@ def start_training(self):
 
     pipeline = self.load_checkpoint()
     # Generate a sample before training to compare against generated sample after training.
-    # UNCOMMENT
-    # pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
+    pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
 
     # save some memory.
     del pipeline.vae
@@ -168,7 +167,7 @@ def start_training(self):
     pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator)
 
     posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
-    # print_ssim(pretrained_video_path, posttrained_video_path)
+    print_ssim(pretrained_video_path, posttrained_video_path)
 
   def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator):
     mesh = pipeline.mesh
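With the pre-training sample re-enabled, `print_ssim` can compare the pre- and post-training videos. Its implementation is not shown in this diff; a minimal sketch of a frame-wise SSIM comparison using scikit-image (an assumption, not the repository's code):

import numpy as np
from skimage.metrics import structural_similarity

def mean_video_ssim(video_a: np.ndarray, video_b: np.ndarray) -> float:
  """Average SSIM over frames; inputs are (frames, H, W, C) uint8 arrays."""
  scores = [
      structural_similarity(fa, fb, channel_axis=-1)
      for fa, fb in zip(video_a, video_b)
  ]
  return float(np.mean(scores))

# Example with synthetic frames: identical videos score 1.0.
a = np.random.randint(0, 255, (4, 64, 64, 3), dtype=np.uint8)
print(mean_video_ssim(a, a))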
