
Commit d3b50c8: "using train state instead."

Parent: 00bf1cd

3 files changed: 49 additions & 31 deletions
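
In short: instead of jitting over a split (transformer, optimizer) pair, the trainer now carries a flax.training.train_state.TrainState subclass whose extra fields hold the module's GraphDef and its non-parameter state. _create_optimizer returns the raw optax transformation so it can be handed to TrainState.create, a per-kernel nnx.with_partitioning annotation in the transformer is dropped, and the jitted train step gains explicit in_shardings/out_shardings.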

src/maxdiffusion/checkpointing/wan_checkpointer.py (1 addition, 1 deletion)

@@ -42,7 +42,7 @@ def _create_optimizer(self, model, config, learning_rate):
         learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps
     )
     tx = max_utils.create_optimizer(config, learning_rate_scheduler)
-    return nnx.Optimizer(model, tx), learning_rate_scheduler
+    return tx, learning_rate_scheduler
 
   def load_wan_configs_from_orbax(self, step):
     max_logging.log("Restoring stable diffusion configs")
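
The checkpointer now returns the bare optax GradientTransformation; binding it to the model happens later, when the trainer builds its TrainState. A minimal before/after sketch (the toy model and tx stand in for the real transformer and max_utils.create_optimizer's result):

import optax
from flax import nnx

model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))  # toy stand-in for the transformer
tx = optax.adamw(learning_rate=1e-5)        # stand-in for max_utils.create_optimizer

# Before: the optimizer was bound to the model immediately.
optimizer = nnx.Optimizer(model, tx)

# After: the raw transformation is returned unchanged, so the trainer can
# pass it to TrainState.create(..., tx=tx) instead (see wan_trainer.py below).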

src/maxdiffusion/models/wan/transformers/transformer_wan.py (0 additions, 10 deletions)

@@ -374,16 +374,6 @@ def __init__(
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
-        kernel_init=nnx.with_partitioning(
-            nnx.initializers.xavier_uniform(),
-            (
-                None,
-                None,
-                None,
-                None,
-                "conv_out",
-            ),
-        ),
     )
 
     # 2. Condition embeddings
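
The deleted block pinned this kernel to an explicit Xavier initializer carrying logical axis names (the last dimension tagged "conv_out"); without it the layer falls back to its default initializer with unannotated, i.e. replicated, sharding metadata, which matches the trainer-level sharding flow introduced in wan_trainer.py below. For reference, a toy sketch (2-D kernel, illustrative axis name) of what nnx.with_partitioning does:

from flax import nnx

# nnx.with_partitioning wraps an initializer so the resulting Param carries
# logical axis names that nnx.get_partition_spec can read back later.
init = nnx.with_partitioning(
    nnx.initializers.xavier_uniform(),
    (None, "conv_out"),  # one logical name per kernel dimension
)
layer = nnx.Linear(8, 16, kernel_init=init, rngs=nnx.Rngs(0))
spec = nnx.get_partition_spec(nnx.state(layer))  # kernel -> PartitionSpec(None, 'conv_out')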

src/maxdiffusion/trainers/wan_trainer.py (48 additions, 20 deletions)

@@ -23,6 +23,7 @@
 import tensorflow as tf
 import jax.numpy as jnp
 import jax
+from jax.sharding import PartitionSpec as P
 from flax import nnx
 from maxdiffusion.schedulers import FlaxFlowMatchScheduler
 from flax.linen import partitioning as nn_partitioning
@@ -34,7 +35,16 @@
 from maxdiffusion.video_processor import VideoProcessor
 from maxdiffusion.utils import load_video
 from skimage.metrics import structural_similarity as ssim
+from flax.training import train_state
 
+class TrainState(train_state.TrainState):
+  graphdef: nnx.GraphDef
+  rest_of_state: nnx.State
+
+def _to_array(x):
+  if not isinstance(x, jax.Array):
+    x = jnp.asarray(x)
+  return x
 
 def generate_sample(config, pipeline, filename_prefix):
   """
@@ -85,6 +95,14 @@ def create_scheduler(self):
   def calculate_tflops(self, pipeline):
     max_logging.log("WARNING : Calculting tflops is not implemented in Wan 2.1. Returning 0...")
     return 0
+
+  def get_data_shardings(self, mesh):
+    data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding))
+    data_sharding = {
+        "latents" : data_sharding,
+        "encoder_hidden_states" : data_sharding
+    }
+    return data_sharding
 
   def load_dataset(self, mesh):
     # Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314
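
get_data_shardings returns one NamedSharding per batch key so the jitted step knows how incoming batches are laid out. A standalone sketch with a hypothetical one-axis data mesh and toy batch shapes (config.data_sharding would name mesh axes such as ("data",)):

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 1-D "data" mesh over all local devices.
mesh = Mesh(np.array(jax.devices()), ("data",))
data_sharding = NamedSharding(mesh, P("data"))  # shard the leading (batch) dim

batch = {
    "latents": jnp.zeros((8, 16, 32, 32)),
    "encoder_hidden_states": jnp.zeros((8, 77, 512)),
}
# Place each array so its batch dimension is split across the mesh.
batch = jax.device_put(batch, {k: data_sharding for k in batch})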
@@ -136,24 +154,36 @@ def start_training(self):
     scheduler, scheduler_state = self.create_scheduler()
     pipeline.scheduler = scheduler
     pipeline.scheduler_state = scheduler_state
-
     optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5)
-
     # Returns pipeline with trained transformer state
     pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, data_iterator)
 
     posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
     print_ssim(pretrained_video_path, posttrained_video_path)
 
   def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_iterator):
-
-    graphdef, state = nnx.split((pipeline.transformer, optimizer))
+    mesh = pipeline.mesh
+    graphdef, params, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...)
+
+    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      state = TrainState.create(
+          apply_fn=graphdef.apply,
+          params=params,
+          tx=optimizer,
+          graphdef=graphdef,
+          rest_of_state=rest_of_state
+      )
+      state = jax.tree.map(_to_array, state)
+      state_spec = nnx.get_partition_spec(state)
+      state = jax.lax.with_sharding_constraint(state, state_spec)
+    state_shardings = nnx.get_named_sharding(state, mesh)
+    data_shardings = self.get_data_shardings(mesh)
 
     writer = max_utils.initialize_summary_writer(self.config)
     writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True)
     writer_thread.start()
 
-    num_model_parameters = max_utils.calculate_num_params_from_pytree(state[0])
+    num_model_parameters = max_utils.calculate_num_params_from_pytree(state.params)
     max_utils.add_text_to_summary_writer("number_model_parameters", str(num_model_parameters), writer)
     max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ.get("LIBTPU_INIT_ARGS", ""), writer)
     max_utils.add_config_to_summary_writer(self.config, writer)
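
The sequence above is the standard nnx flow for materializing a sharded train state: read PartitionSpecs from the variables' metadata, constrain the live arrays to that layout inside the mesh context, then extract concrete NamedShardings to hand to jax.jit. A condensed toy version (single-axis mesh; assumes the axis size divides the annotated dimension):

import numpy as np
import jax
from flax import nnx

# Toy layer whose kernel carries sharding metadata, as in the
# nnx.with_partitioning pattern shown in the transformer diff above.
init = nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "model"))
layer = nnx.Linear(8, 8, kernel_init=init, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(layer)

mesh = jax.sharding.Mesh(np.array(jax.devices()), ("model",))
with mesh:
  state_spec = nnx.get_partition_spec(state)            # PartitionSpec tree
  state = jax.lax.with_sharding_constraint(state, state_spec)
state_shardings = nnx.get_named_sharding(state, mesh)   # for jax.jit shardings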
@@ -164,9 +194,10 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_iterator):
     max_logging.log(f" Total train batch size (w. parallel & distributed) = {self.global_batch_size}")
     max_logging.log(f" Total optimization steps = {self.config.max_train_steps}")
 
-    state = state.to_pure_dict()
     p_train_step = jax.jit(
         functools.partial(train_step, scheduler=pipeline.scheduler, config=self.config),
+        in_shardings=(state_shardings, data_shardings, None, None),
+        out_shardings=(state_shardings, None, None, None),
         donate_argnums=(0,),
     )
     rng = jax.random.key(self.config.seed)
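
Note the argument contract: after functools.partial binds scheduler and config by keyword, the jitted callable's positional signature is (state, data, rng, scheduler_state), so in_shardings is a 4-tuple in that order, and out_shardings mirrors the outputs (state, scheduler_state, metrics, rng); the None entries leave placement up to the compiler. donate_argnums=(0,) donates the incoming state's buffers so XLA can reuse them for the updated state rather than allocating a second copy.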
@@ -195,7 +226,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_iterator):
       with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules(
           self.config.logical_axis_rules
       ):
-        state, scheduler_state, train_metric, rng = p_train_step(state, graphdef, scheduler_state, example_batch, rng)
+        state, scheduler_state, train_metric, rng = p_train_step(state, example_batch, rng, scheduler_state)
       train_metric["scalar"]["learning/loss"].block_until_ready()
       last_step_completion = datetime.datetime.now()
 

@@ -215,19 +246,19 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_iterator):
     writer.flush()
 
     # load new state for trained tranformer
-    graphdef, _, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...)
-    pipeline.transformer = nnx.merge(graphdef, state[0], rest_of_state)
+    pipeline.transformer = nnx.merge(state.graphdef, state.params, state.rest_of_state)
     return pipeline
 
 
-def train_step(state, graphdef, scheduler_state, data, rng, scheduler, config):
-  return step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng, config)
+def train_step(state, data, rng, scheduler_state, scheduler, config):
+  return step_optimizer(state, data, rng, scheduler_state, scheduler, config)
 
 
-def step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng, config):
+def step_optimizer(state, data, rng, scheduler_state, scheduler, config):
   _, new_rng, timestep_rng = jax.random.split(rng, num=3)
 
-  def loss_fn(model):
+  def loss_fn(params):
+    model = nnx.merge(state.graphdef, params, state.rest_of_state)
     latents = data["latents"].astype(config.weights_dtype)
     encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
     bsz = latents.shape[0]

@@ -253,11 +284,8 @@ def loss_fn(model):
     loss = jnp.mean(loss)
 
     return loss
-
-  model, optimizer = nnx.merge(graphdef, state)
-  loss, grads = nnx.value_and_grad(loss_fn)(model)
-  optimizer.update(grads)
-  state = nnx.state((model, optimizer))
-  state = state.to_pure_dict()
+  grad_fn = nnx.value_and_grad(loss_fn)
+  loss, grads = grad_fn(state.params)
+  new_state = state.apply_gradients(grads=grads)
   metrics = {"scalar": {"learning/loss": loss}, "scalars": {}}
-  return state, scheduler_state, metrics, new_rng
+  return new_state, scheduler_state, metrics, new_rng
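
Taken together, the new flow is: split the module into (graphdef, params, rest_of_state), wrap everything in the TrainState, differentiate a loss_fn(params) that re-merges a live module inside the trace, and apply_gradients. A self-contained toy sketch of that round trip (the model, loss, and data are placeholders, not the Wan training code):

import jax
import jax.numpy as jnp
import optax
from flax import nnx
from flax.training import train_state

class TrainState(train_state.TrainState):
  graphdef: nnx.GraphDef
  rest_of_state: nnx.State

model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))
graphdef, params, rest = nnx.split(model, nnx.Param, ...)
state = TrainState.create(
    apply_fn=graphdef.apply, params=params, tx=optax.adamw(1e-5),
    graphdef=graphdef, rest_of_state=rest,
)

@jax.jit
def train_step(state, x, y):
  def loss_fn(params):
    # Rebuild a callable module from (graphdef, params, rest) inside the trace.
    model = nnx.merge(state.graphdef, params, state.rest_of_state)
    return jnp.mean((model(x) - y) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(state.params)
  return state.apply_gradients(grads=grads), loss

x, y = jnp.ones((2, 4)), jnp.zeros((2, 4))
state, loss = train_step(state, x, y)
# After the loop, the trained module is recovered exactly as in the diff:
trained = nnx.merge(state.graphdef, state.params, state.rest_of_state)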
