 import tensorflow as tf
 import jax.numpy as jnp
 import jax
+from jax.sharding import PartitionSpec as P
 from flax import nnx
 from maxdiffusion.schedulers import FlaxFlowMatchScheduler
 from flax.linen import partitioning as nn_partitioning
 from maxdiffusion.video_processor import VideoProcessor
 from maxdiffusion.utils import load_video
 from skimage.metrics import structural_similarity as ssim
+from flax.training import train_state
 
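+# TrainState extended to carry the nnx GraphDef and the non-parameter nnx
+# state, so the transformer module can be re-merged from pure pytrees inside
+# the jitted train step.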
+class TrainState(train_state.TrainState):
+  graphdef: nnx.GraphDef
+  rest_of_state: nnx.State
+
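+# Coerce any non-jax leaves in the state (e.g. numpy arrays) to jax.Arrays so
+# the sharding constraint below can be applied uniformly.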
+def _to_array(x):
+  if not isinstance(x, jax.Array):
+    x = jnp.asarray(x)
+  return x
 
 def generate_sample(config, pipeline, filename_prefix):
   """
@@ -85,6 +95,14 @@ def create_scheduler(self):
   def calculate_tflops(self, pipeline):
     max_logging.log("WARNING: Calculating tflops is not implemented in Wan 2.1. Returning 0...")
     return 0
+
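+  # Map each batch entry to a NamedSharding built from the data-sharding axis
+  # names in the config; latents and text embeddings share the same layout.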
+  def get_data_shardings(self, mesh):
+    data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding))
+    data_sharding = {
+        "latents": data_sharding,
+        "encoder_hidden_states": data_sharding,
+    }
+    return data_sharding
 
   def load_dataset(self, mesh):
     # Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314
@@ -136,24 +154,36 @@ def start_training(self):
     scheduler, scheduler_state = self.create_scheduler()
     pipeline.scheduler = scheduler
     pipeline.scheduler_state = scheduler_state
-
     optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5)
-
     # Returns pipeline with trained transformer state
     pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, data_iterator)
 
     posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
     print_ssim(pretrained_video_path, posttrained_video_path)
 
   def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_iterator):
-
-    graphdef, state = nnx.split((pipeline.transformer, optimizer))
+    mesh = pipeline.mesh
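+    # Split the transformer into graphdef, trainable Params, and the
+    # remaining (non-trainable) state.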
+    graphdef, params, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...)
+
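+    # Build the train state under the mesh and logical axis rules so the
+    # partition specs attached to the params resolve to mesh axes.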
+    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      state = TrainState.create(
+          apply_fn=graphdef.apply,
+          params=params,
+          tx=optimizer,
+          graphdef=graphdef,
+          rest_of_state=rest_of_state,
+      )
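+      # Checkpoint restores may hand back non-jax leaves; normalize them, then
+      # shard the whole state according to its partition specs.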
+      state = jax.tree.map(_to_array, state)
+      state_spec = nnx.get_partition_spec(state)
+      state = jax.lax.with_sharding_constraint(state, state_spec)
+      state_shardings = nnx.get_named_sharding(state, mesh)
+      data_shardings = self.get_data_shardings(mesh)
 
     writer = max_utils.initialize_summary_writer(self.config)
     writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True)
     writer_thread.start()
 
-    num_model_parameters = max_utils.calculate_num_params_from_pytree(state[0])
+    num_model_parameters = max_utils.calculate_num_params_from_pytree(state.params)
     max_utils.add_text_to_summary_writer("number_model_parameters", str(num_model_parameters), writer)
     max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ.get("LIBTPU_INIT_ARGS", ""), writer)
     max_utils.add_config_to_summary_writer(self.config, writer)
@@ -164,9 +194,10 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
     max_logging.log(f"  Total train batch size (w. parallel & distributed) = {self.global_batch_size}")
     max_logging.log(f"  Total optimization steps = {self.config.max_train_steps}")
 
-    state = state.to_pure_dict()
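+    # Positional args are (state, data, rng, scheduler_state); only the train
+    # state and the batch carry explicit shardings, and the donated state
+    # buffers are reused for the updated state.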
     p_train_step = jax.jit(
         functools.partial(train_step, scheduler=pipeline.scheduler, config=self.config),
+        in_shardings=(state_shardings, data_shardings, None, None),
+        out_shardings=(state_shardings, None, None, None),
         donate_argnums=(0,),
     )
     rng = jax.random.key(self.config.seed)
@@ -195,7 +226,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
       with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules(
           self.config.logical_axis_rules
       ):
-        state, scheduler_state, train_metric, rng = p_train_step(state, graphdef, scheduler_state, example_batch, rng)
+        state, scheduler_state, train_metric, rng = p_train_step(state, example_batch, rng, scheduler_state)
         train_metric["scalar"]["learning/loss"].block_until_ready()
         last_step_completion = datetime.datetime.now()
@@ -215,19 +246,19 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
       writer.flush()
 
     # load new state for trained transformer
-    graphdef, _, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...)
-    pipeline.transformer = nnx.merge(graphdef, state[0], rest_of_state)
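+    # Merge trained params and the untouched non-param state back into the
+    # transformer so sampling below uses the post-training weights.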
+    pipeline.transformer = nnx.merge(state.graphdef, state.params, state.rest_of_state)
     return pipeline
 
 
-def train_step(state, graphdef, scheduler_state, data, rng, scheduler, config):
-  return step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng, config)
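+# Positional order must match p_train_step's in_shardings:
+# (state, data, rng, scheduler_state).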
+def train_step(state, data, rng, scheduler_state, scheduler, config):
+  return step_optimizer(state, data, rng, scheduler_state, scheduler, config)
 
 
-def step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng, config):
+def step_optimizer(state, data, rng, scheduler_state, scheduler, config):
   _, new_rng, timestep_rng = jax.random.split(rng, num=3)
 
-  def loss_fn(model):
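+  # loss_fn takes only the Param pytree; graphdef and the non-param state are
+  # closed over from the TrainState, so nnx.value_and_grad differentiates
+  # with respect to params alone.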
+  def loss_fn(params):
+    model = nnx.merge(state.graphdef, params, state.rest_of_state)
     latents = data["latents"].astype(config.weights_dtype)
     encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
     bsz = latents.shape[0]
@@ -253,11 +284,8 @@ def loss_fn(model):
     loss = jnp.mean(loss)
 
     return loss
-
-  model, optimizer = nnx.merge(graphdef, state)
-  loss, grads = nnx.value_and_grad(loss_fn)(model)
-  optimizer.update(grads)
-  state = nnx.state((model, optimizer))
-  state = state.to_pure_dict()
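+  # Differentiate w.r.t. the param pytree and take one optimizer step through
+  # the standard TrainState.apply_gradients path.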
+  grad_fn = nnx.value_and_grad(loss_fn)
+  loss, grads = grad_fn(state.params)
+  new_state = state.apply_gradients(grads=grads)
   metrics = {"scalar": {"learning/loss": loss}, "scalars": {}}
-  return state, scheduler_state, metrics, new_rng
+  return new_state, scheduler_state, metrics, new_rng