 from skimage.metrics import structural_similarity as ssim
 from flax.training import train_state
 
+
 class TrainState(train_state.TrainState):
   graphdef: nnx.GraphDef
   rest_of_state: nnx.State
 
+
 def _to_array(x):
   if not isinstance(x, jax.Array):
     x = jnp.asarray(x)
   return x
 
+
 def generate_sample(config, pipeline, filename_prefix):
   """
   Generates a video to validate training did not corrupt the model
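
Note: the two extra fields let this Linen-style train_state.TrainState carry a Flax NNX model that has been split into a static GraphDef plus state pytrees, while _to_array coerces any leftover non-jax.Array leaves. A minimal sketch of how such a state is typically assembled; the MyModel module and the AdamW optimizer are placeholders, not code from this change:

    import optax
    from flax import nnx

    class MyModel(nnx.Module):  # hypothetical stand-in for the Wan transformer
      def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(4, 4, rngs=rngs)

      def __call__(self, x):
        return self.linear(x)

    # Split into static structure, trainable params, and everything else.
    graphdef, params, rest_of_state = nnx.split(MyModel(nnx.Rngs(0)), nnx.Param, ...)

    state = TrainState.create(
        apply_fn=graphdef.apply,
        params=params,
        tx=optax.adamw(1e-4),
        graphdef=graphdef,
        rest_of_state=rest_of_state,
    )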
@@ -79,7 +82,6 @@ def __init__(self, config):
     if config.train_text_encoder:
       raise ValueError("this script currently doesn't support training text_encoders")
 
-    #self.global_batch_size = self.config.per_device_batch_size * jax.device_count()
     self.global_batch_size = config.per_device_batch_size * jax.device_count()
 
   def post_training_steps(self, pipeline, params, train_states, msg=""):
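
Note: global_batch_size is plain data parallelism, the per-device batch scaled by every device JAX can see. A worked example with illustrative numbers (not this config's values):

    per_device_batch_size = 2  # from config
    device_count = 8           # jax.device_count() on an 8-device slice
    global_batch_size = per_device_batch_size * device_count  # 16 samples per step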
@@ -95,13 +97,10 @@ def create_scheduler(self):
   def calculate_tflops(self, pipeline):
     max_logging.log("WARNING: Calculating tflops is not implemented in Wan 2.1. Returning 0...")
     return 0
-
+
   def get_data_shardings(self, mesh):
     data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding))
-    data_sharding = {
-      "latents": data_sharding,
-      "encoder_hidden_states": data_sharding
-    }
+    data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding}
     return data_sharding
 
   def load_dataset(self, mesh):
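
Note: returning one NamedSharding per batch key lets the input pipeline place each array with its leading (batch) dimension split across the mesh's data axis. A self-contained sketch of the same pattern; the mesh layout, axis name, and shapes are assumptions for illustration:

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
    sharding = NamedSharding(mesh, P("data"))  # split dim 0 across "data"

    batch = {
        "latents": jnp.zeros((8, 16, 32, 32)),
        "encoder_hidden_states": jnp.zeros((8, 512, 4096)),
    }
    # One sharding per key, mirroring get_data_shardings above.
    batch = jax.device_put(batch, {"latents": sharding, "encoder_hidden_states": sharding})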
@@ -167,11 +166,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
 
     with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       state = TrainState.create(
-        apply_fn=graphdef.apply,
-        params=params,
-        tx=optimizer,
-        graphdef=graphdef,
-        rest_of_state=rest_of_state
+        apply_fn=graphdef.apply, params=params, tx=optimizer, graphdef=graphdef, rest_of_state=rest_of_state
       )
       state = jax.tree.map(_to_array, state)
       state_spec = nnx.get_partition_spec(state)
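
Note: jax.tree.map(_to_array, state) normalizes every leaf to a jax.Array, and nnx.get_partition_spec then reads back a PartitionSpec pytree from the sharding metadata attached to each NNX variable. A small sketch of where that metadata comes from; the module and logical axis names are hypothetical:

    from flax import nnx

    class Tiny(nnx.Module):
      def __init__(self, rngs: nnx.Rngs):
        # with_partitioning attaches sharding metadata to the created Param.
        self.w = nnx.Param(
            nnx.with_partitioning(
                nnx.initializers.lecun_normal(), ("embed", "mlp")
            )(rngs.params(), (4, 4))
        )

    graphdef, state = nnx.split(Tiny(nnx.Rngs(0)))
    spec = nnx.get_partition_spec(state)
    # spec carries PartitionSpec("embed", "mlp") for w: logical axes that the
    # axis_rules context in training_loop maps onto physical mesh axes.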
@@ -196,8 +191,8 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
 
     p_train_step = jax.jit(
       functools.partial(train_step, scheduler=pipeline.scheduler, config=self.config),
-      in_shardings=(state_shardings, data_shardings, None, None),
-      out_shardings=(state_shardings, None, None, None),
+      in_shardings=(state_shardings, data_shardings, None, None),
+      out_shardings=(state_shardings, None, None, None),
       donate_argnums=(0,),
     )
     rng = jax.random.key(self.config.seed)
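
Note: donate_argnums=(0,) lets XLA reuse the old state's buffers for the returned state, roughly halving peak memory for params plus optimizer state, and keeping state_shardings identical on input and output avoids a reshard every step. Donation in miniature (illustrative shapes):

    import jax
    import jax.numpy as jnp

    update = jax.jit(lambda state, grad: state - 0.1 * grad, donate_argnums=(0,))
    state = jnp.ones((1024, 1024))
    new_state = update(state, jnp.ones_like(state))
    # On backends that honor donation, state's buffer now backs new_state and
    # reading state after the call raises a deleted-buffer error.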
@@ -284,6 +279,7 @@ def loss_fn(params):
       loss = jnp.mean(loss)
 
       return loss
+
     grad_fn = nnx.value_and_grad(loss_fn)
     loss, grads = grad_fn(state.params)
     new_state = state.apply_gradients(grads=grads)
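
Note: only the tail of loss_fn is visible in this hunk; the usual NNX pattern is to merge the graphdef with the differentiated params inside the loss, so gradients flow only through the Param sub-state. A self-contained version of that step (model, data, and learning rate are placeholders):

    import jax.numpy as jnp
    import optax
    from flax import nnx

    model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))
    graphdef, params = nnx.split(model, nnx.Param)
    tx = optax.adamw(1e-4)
    opt_state = tx.init(params)
    x, y = jnp.ones((2, 4)), jnp.zeros((2, 4))

    def loss_fn(params):
      pred = nnx.merge(graphdef, params)(x)
      return jnp.mean((pred - y) ** 2)

    loss, grads = nnx.value_and_grad(loss_fn)(params)
    # apply_gradients bundles the two optax calls below (plus a step increment).
    updates, opt_state = tx.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)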