 import jax.tree_util as jtu
 from flax import nnx
 from ..schedulers import FlaxEulerDiscreteScheduler
-from .. import max_utils, max_logging, train_utils
+from .. import max_utils, max_logging, train_utils, maxdiffusion_utils
 from ..checkpointing.wan_checkpointer import (
     WanCheckpointer,
     WAN_CHECKPOINT
@@ -64,36 +64,15 @@ def load_dataset(self, pipeline):
     # prompt embeds shape: (1, 512, 4096)
     # For now, we will pass the same latents over and over
     # TODO - create a dataset
-    prompt_embeds = jax.random.normal(jax.random.key(self.config.seed), (self.global_batch_size, 512, 4096))
-    latents = pipeline.prepare_latents(
-        self.global_batch_size,
-        vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal,
-        vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial,
-        height=self.config.height,
-        width=self.config.width,
-        num_frames=self.config.num_frames,
-        num_channels_latents=pipeline.transformer.config.in_channels
-    )
-    return (latents, prompt_embeds)
+    return maxdiffusion_utils.get_dummy_wan_inputs(self.config, pipeline, self.global_batch_size)
 
   def start_training(self):
 
     pipeline = self.load_checkpoint()
-    mesh = pipeline.mesh
-
-    optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, self.config.learning_rate)
-
-    # @nnx.jit
-    # def create_transformer_state(transformer):
-    #     optimizer = self._create_optimizer(transformer, self.config, self.config.learning_rate)
-    #     breakpoint()
-    #     _, state = nnx.split((transformer, optimizer))
-
-    # with mesh:
-    #     create_transformer_state(pipeline.transformer)
-
-    #graphdef, state = nnx.plit((pipeline.transformer, optimizer))
+    del pipeline.vae
     dummy_inputs = self.load_dataset(pipeline)
+    mesh = pipeline.mesh
+    optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5)
     dummy_inputs = tuple([jtu.tree_map_with_path(functools.partial(_form_global_array, global_mesh=mesh), input) for input in dummy_inputs])
     self.training_loop(pipeline, optimizer, learning_rate_scheduler, dummy_inputs)
 
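Note: the body of the new `maxdiffusion_utils.get_dummy_wan_inputs` helper is not part of this diff, and the `del pipeline.vae` presumably frees the VAE's parameters, which are unneeded while training the transformer on pre-shaped dummy latents. A hedged reconstruction of the helper from the inline code it replaces, plus the three-way unpack added to `loss_fn` below; the zero-filled timesteps mirror the removed `jnp.array([0] * bsz)` line and are an assumption:

# Hypothetical sketch of the extracted helper (not shown in this diff),
# reconstructed from the removed load_dataset body; timestep handling assumed.
def get_dummy_wan_inputs(config, pipeline, global_batch_size):
  prompt_embeds = jax.random.normal(jax.random.key(config.seed), (global_batch_size, 512, 4096))
  latents = pipeline.prepare_latents(
      global_batch_size,
      vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal,
      vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial,
      height=config.height,
      width=config.width,
      num_frames=config.num_frames,
      num_channels_latents=pipeline.transformer.config.in_channels,
  )
  timesteps = jnp.array([0] * global_batch_size, dtype=jnp.int32)
  return (latents, prompt_embeds, timesteps)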
@@ -116,7 +95,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data):
     state = state.to_pure_dict()
     p_train_step = jax.jit(
         train_step,
-        donate_argnums=(1,),
+        donate_argnums=(0,),
     )
     rng = jax.random.key(self.config.seed)
     start_step = 0
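Note: `donate_argnums=(0,)` tracks the reordered `train_step(state, graphdef, data, rng)` signature below; the state pytree is now argument 0, and donating it lets XLA reuse those buffers for the updated state instead of allocating a copy. A minimal, self-contained illustration of the donation contract (not the commit's code):

# Minimal sketch of jax.jit buffer donation: the donated input may be aliased
# to an output of matching shape/dtype, and must not be read again by the caller.
import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, donate_argnums=(0,))
def update(state, grad):
  return state - 0.1 * grad  # output can reuse the donated `state` buffer

state = jnp.ones((1024,))
state = update(state, jnp.ones((1024,)))  # rebind; the old `state` is invalidated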
@@ -137,7 +116,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data):
       if self.config.enable_profiler and step == first_profiling_step:
         max_utils.activate_profiler(self.config)
       with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh:
-        state, train_metric, rng = p_train_step(graphdef, state, data, rng)
+        state, train_metric, rng = p_train_step(state, graphdef, data, rng)
 
       new_time = datetime.datetime.now()
 
@@ -151,15 +130,13 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data):
         train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config)
       last_step_completion = new_time
 
-def train_step(graphdef, state, data, rng):
+def train_step(state, graphdef, data, rng):
   return step_optimizer(graphdef, state, data, rng)
 
 def step_optimizer(graphdef, state, data, rng):
   _, new_rng = jax.random.split(rng)
   def loss_fn(model):
-    latents, prompt_embeds = data
-    bsz = latents.shape[0]
-    timesteps = jnp.array([0] * bsz, dtype=jnp.int32)
+    latents, prompt_embeds, timesteps = data
 
     noise = jax.random.normal(
         key=new_rng,
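The hunk cuts off inside the `jax.random.normal` call. For context, a hedged sketch of how the rest of `loss_fn` might consume the unpacked batch, assuming a generic flow-matching objective; the sigma normalization, interpolation, target, and model call signature are all assumptions, not this commit's code:

# Hedged continuation sketch (assumptions, not the commit's code): a generic
# flow-matching loss over 5-D video latents (batch, channels, frames, h, w).
def loss_fn(model):
  latents, prompt_embeds, timesteps = data
  noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
  sigma = (timesteps.astype(jnp.float32) / 1000.0).reshape(-1, 1, 1, 1, 1)
  noisy_latents = (1.0 - sigma) * latents + sigma * noise  # interpolate clean -> noise
  target = noise - latents                                 # velocity target
  pred = model(noisy_latents, timesteps, prompt_embeds)    # call signature assumed
  return jnp.mean((pred - target) ** 2)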