
Commit b90584c

wan training for single frame + bug fixes.
1 parent a60d235 commit b90584c

5 files changed

Lines changed: 52 additions & 53 deletions

File tree

src/maxdiffusion/checkpointing/wan_checkpointer.py
src/maxdiffusion/generate_wan.py
src/maxdiffusion/maxdiffusion_utils.py
src/maxdiffusion/pipelines/wan/wan_pipeline.py
src/maxdiffusion/trainers/wan_trainer.py

src/maxdiffusion/checkpointing/wan_checkpointer.py

Lines changed: 2 additions & 7 deletions
@@ -15,6 +15,7 @@
 """

 from abc import ABC
+import jax
 from flax import nnx
 from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
 from ..pipelines.wan.wan_pipeline import WanPipeline

@@ -35,20 +36,14 @@ def __init__(self, config, checkpoint_type):
         dataset_type=config.dataset_type
     )

-  # @nnx.jit
   def _create_optimizer(self, model, config, learning_rate):
     learning_rate_scheduler = max_utils.create_learning_rate_schedule(
         learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps
     )
     tx = max_utils.create_optimizer(config, learning_rate_scheduler)
-    # tx = nnx.Optimizer(model, tx)
-
-    # _, state, rest_of_state = nnx.split((model, tx), ...)
-    # nnx.update((model, tx), state, rest_of_state)
-
-
     return nnx.Optimizer(model, tx), learning_rate_scheduler

+
   def load_wan_configs_from_orbax(self, step):
     max_logging.log("Restoring stable diffusion configs")
     if step is None:
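Note: `_create_optimizer` now builds the optax transform and wraps it in `nnx.Optimizer` in one step; the commented-out split/update experiments are dropped. A minimal sketch of the same wrapping, using a toy module in place of the WAN transformer:

import optax
from flax import nnx

# Toy stand-in for the WAN transformer; any nnx.Module works the same way.
model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))

# Wrap an optax gradient transform so optimizer state is tracked as nnx
# state alongside the module, mirroring the return value above.
tx = optax.adamw(learning_rate=1e-5)
optimizer = nnx.Optimizer(model, tx)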

src/maxdiffusion/generate_wan.py

Lines changed: 10 additions & 7 deletions
@@ -29,10 +29,13 @@ def run(config):
   slg_layers = config.slg_layers
   slg_start = config.slg_start
   slg_end = config.slg_end
-
+
+  prompt = [config.prompt] * jax.device_count()
+  negative_prompt = [config.negative_prompt] * jax.device_count()
+
   videos = pipeline(
-      prompt=config.prompt,
-      negative_prompt=config.negative_prompt,
+      prompt=prompt,
+      negative_prompt=negative_prompt,
       height=config.height,
       width=config.width,
       num_frames=config.num_frames,
@@ -45,12 +48,12 @@ def run(config):

   print("compile time: ", (time.perf_counter() - s0))
   for i in range(len(videos)):
-    export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=16)
+    export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps)
   s0 = time.perf_counter()
   with jax.profiler.trace("/tmp/trace/"):
     videos = pipeline(
-        prompt=config.prompt,
-        negative_prompt=config.negative_prompt,
+        prompt=prompt,
+        negative_prompt=negative_prompt,
         height=config.height,
         width=config.width,
         num_frames=config.num_frames,
@@ -62,7 +65,7 @@ def run(config):
     )
   print("generation time: ", (time.perf_counter() - s0))
   for i in range(len(videos)):
-    export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=16)
+    export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps)


 def main(argv: Sequence[str]) -> None:
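Note: the prompts are now replicated once per device so the batch dimension lines up with the data-parallel device count, and the frame rate comes from config rather than a hardcoded 16. A small illustration of the replication (the prompt text is made up):

import jax

# One prompt per device: with the batch sharded across devices, the batch
# dimension must be divisible by jax.device_count().
prompt = ["a cat surfing a wave"] * jax.device_count()
negative_prompt = [""] * jax.device_count()
assert len(prompt) % jax.device_count() == 0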

src/maxdiffusion/maxdiffusion_utils.py

Lines changed: 26 additions & 0 deletions
@@ -286,6 +286,32 @@ def get_dummy_flux_inputs(config, pipeline, batch_size):

   return (latents, timesteps, latents_ids, guidance_vec, t5_hidden_states, t5_ids, clip_hidden_states)

+def get_dummy_wan_inputs(config, pipeline, batch_size):
+  latents = pipeline.prepare_latents(
+      batch_size,
+      vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal,
+      vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial,
+      height=config.height,
+      width=config.width,
+      num_frames=config.num_frames,
+      num_channels_latents=pipeline.transformer.config.in_channels
+  )
+  bsz = latents.shape[0]
+  prompt_embeds = jax.random.normal(jax.random.key(config.seed), (batch_size, 512, 4096))
+  timesteps = jnp.array([0] * bsz, dtype=jnp.int32)
+  return (latents, prompt_embeds, timesteps)
+
+def calculate_wan_tflops(config, pipeline, batch_size, rngs, train):
+  """
+  Calculates wan tflops.
+  batch_size should be per_device_batch_size * jax.local_device_count() or attention's shard_map won't
+  cache the compilation when flash is enabled.
+  """
+  (latents, prompt_embeds, timesteps) = get_dummy_wan_inputs(config, pipeline, batch_size)
+  return max_utils.calculate_model_tflops(
+      pipeline.transformer,
+
+  )

 def calculate_flux_tflops(config, pipeline, batch_size, rngs, train):
   """

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 5 additions & 7 deletions
@@ -58,7 +58,7 @@ def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.VariableState:
   return vs


-partial(nnx.jit, static_argnums=(3,))
+# For some reason, jitting this function increases the memory significantly, so instead manually move weights to device.
 def create_sharded_logical_transformer(devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):

   def create_model(rngs: nnx.Rngs, wan_config: dict):
@@ -106,16 +106,15 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
   wan_transformer = nnx.merge(graphdef, state, rest_of_state)
   return wan_transformer

-
-partial(nnx.jit, static_argnums=(1,))
+@nnx.jit(static_argnums=(1,), donate_argnums=(0,))
 def create_sharded_logical_model(model, logical_axis_rules):
   graphdef, state, rest_of_state = nnx.split(model, nnx.Param, ...)
   p_add_sharding_rule = partial(_add_sharding_rule, logical_axis_rules=logical_axis_rules)
   state = jax.tree.map(p_add_sharding_rule, state, is_leaf=lambda x: isinstance(x, nnx.VariableState))
   pspecs = nnx.get_partition_spec(state)
   sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
-  wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
-  return wan_transformer
+  model = nnx.merge(graphdef, sharded_state, rest_of_state)
+  return model


 class WanPipeline:

@@ -473,9 +472,8 @@ def transformer_forward_pass(
       encoder_hidden_states=prompt_embeds,
       is_uncond=is_uncond,
       slg_mask=slg_mask
-  )[0]
+  )

-#@partial(jax.jit, static_argnums=(6, 7, 8))
 def run_inference(
     graphdef,
     sharded_state,
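Note: `create_sharded_logical_model` is now jitted with `donate_argnums=(0,)`, so the unsharded weights are donated to the resharded output instead of briefly existing twice, while `create_sharded_logical_transformer` stays unjitted because jitting it increased memory. A standalone sketch of the donation pattern with plain `jax.jit` (mesh and shapes are illustrative):

import functools
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("data",))

# Donating argument 0 lets XLA reuse the input buffers for the resharded
# output, avoiding a transient second copy of the weights.
@functools.partial(jax.jit, donate_argnums=(0,),
                   out_shardings=NamedSharding(mesh, P("data")))
def reshard(weights):
  return weights

weights = jnp.zeros((jax.device_count(), 1024))
weights = reshard(weights)  # original buffer is donated and invalidated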

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 9 additions & 32 deletions
@@ -23,7 +23,7 @@
 import jax.tree_util as jtu
 from flax import nnx
 from ..schedulers import FlaxEulerDiscreteScheduler
-from .. import max_utils, max_logging, train_utils
+from .. import max_utils, max_logging, train_utils, maxdiffusion_utils
 from ..checkpointing.wan_checkpointer import (
     WanCheckpointer,
     WAN_CHECKPOINT

@@ -64,36 +64,15 @@ def load_dataset(self, pipeline):
     # prompt embeds shape: (1, 512, 4096)
     # For now, we will pass the same latents over and over
     # TODO - create a dataset
-    prompt_embeds = jax.random.normal(jax.random.key(self.config.seed), (self.global_batch_size, 512, 4096))
-    latents = pipeline.prepare_latents(
-        self.global_batch_size,
-        vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal,
-        vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial,
-        height=self.config.height,
-        width=self.config.width,
-        num_frames=self.config.num_frames,
-        num_channels_latents=pipeline.transformer.config.in_channels
-    )
-    return (latents, prompt_embeds)
+    return maxdiffusion_utils.get_dummy_wan_inputs(self.config, pipeline, self.global_batch_size)

   def start_training(self):

     pipeline = self.load_checkpoint()
-    mesh = pipeline.mesh
-
-    optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, self.config.learning_rate)
-
-    # @nnx.jit
-    # def create_transformer_state(transformer):
-    #   optimizer = self._create_optimizer(transformer, self.config, self.config.learning_rate)
-    #   breakpoint()
-    #   _, state = nnx.split((transformer, optimizer))
-
-    # with mesh:
-    #   create_transformer_state(pipeline.transformer)
-
-    #graphdef, state = nnx.plit((pipeline.transformer, optimizer))
+    del pipeline.vae
     dummy_inputs = self.load_dataset(pipeline)
+    mesh = pipeline.mesh
+    optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5)
     dummy_inputs = tuple([jtu.tree_map_with_path(functools.partial(_form_global_array, global_mesh=mesh), input) for input in dummy_inputs])
     self.training_loop(pipeline, optimizer, learning_rate_scheduler, dummy_inputs)

@@ -116,7 +95,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data):
     state = state.to_pure_dict()
     p_train_step = jax.jit(
         train_step,
-        donate_argnums=(1,),
+        donate_argnums=(0,),
     )
     rng = jax.random.key(self.config.seed)
     start_step = 0

@@ -137,7 +116,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data):
       if self.config.enable_profiler and step == first_profiling_step:
         max_utils.activate_profiler(self.config)
       with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh:
-        state, train_metric, rng = p_train_step(graphdef, state, data, rng)
+        state, train_metric, rng = p_train_step(state, graphdef, data, rng)

       new_time = datetime.datetime.now()

@@ -151,15 +130,13 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data):
       train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config)
       last_step_completion = new_time

-def train_step(graphdef, state, data, rng):
+def train_step(state, graphdef, data, rng):
   return step_optimizer(graphdef, state, data, rng)

 def step_optimizer(graphdef, state, data, rng):
   _, new_rng = jax.random.split(rng)
   def loss_fn(model):
-    latents, prompt_embeds = data
-    bsz = latents.shape[0]
-    timesteps = jnp.array([0] * bsz, dtype=jnp.int32)
+    latents, prompt_embeds, timesteps = data

     noise = jax.random.normal(
         key=new_rng,
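Note: `train_step` now takes `state` first so it lines up with `donate_argnums=(0,)` in the `jax.jit` call above, letting each step's output state reuse the donated input buffers. A self-contained sketch of the pattern (the flat-array "state" is illustrative, not the trainer's real pytree):

import jax
import jax.numpy as jnp

# The donated argument must sit at positional index 0 to match
# donate_argnums=(0,).
def train_step(state, data):
  grads = jax.grad(lambda w: jnp.mean((w * data) ** 2))(state)
  return state - 1e-5 * grads  # same shape/dtype, so the buffer is reused

p_train_step = jax.jit(train_step, donate_argnums=(0,))

state = jnp.ones((1024,))
data = jnp.ones((1024,))
for _ in range(3):
  state = p_train_step(state, data)  # previous state buffer is invalidated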
