Commit f0efb5c
test for multihost without scan.
1 parent: deb686d

4 files changed
Lines changed: 54 additions & 52 deletions

src/maxdiffusion/models/attention_flax.py

Lines changed: 12 additions & 8 deletions
@@ -665,7 +665,7 @@ def __init__(
     # None axes corresponds to the stacked weights across all blocks
     # because of the use of nnx.vmap and nnx.scan.
     # Dims are [num_blocks, embed, heads]
-    kernel_axes = (None, "embed", "heads")
+    kernel_axes = ("embed", "heads")
     qkv_init_kernel = nnx.with_partitioning(nnx.initializers.lecun_normal(), kernel_axes)

     self.query = nnx.Linear(
@@ -679,7 +679,6 @@ def __init__(
       bias_init=nnx.with_partitioning(
         nnx.initializers.zeros,
         (
-          None,
           "embed",
         ),
       ),
@@ -696,7 +695,6 @@ def __init__(
       bias_init=nnx.with_partitioning(
         nnx.initializers.zeros,
         (
-          None,
           "embed",
         ),
       ),
@@ -713,7 +711,6 @@ def __init__(
       bias_init=nnx.with_partitioning(
         nnx.initializers.zeros,
         (
-          None,
           "embed",
         ),
       ),
@@ -723,10 +720,16 @@ def __init__(
       rngs=rngs,
       in_features=self.inner_dim,
       out_features=self.inner_dim,
-      kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "heads", "embed")),
+      kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed")),
       dtype=dtype,
       param_dtype=weights_dtype,
       precision=precision,
+      bias_init=nnx.with_partitioning(
+        nnx.initializers.zeros,
+        (
+          "heads",
+        ),
+      ),
     )

     self.norm_q = None
@@ -740,7 +743,6 @@ def __init__(
       scale_init=nnx.with_partitioning(
         nnx.initializers.ones,
         (
-          None,
           "norm",
         ),
       ),
@@ -754,7 +756,6 @@ def __init__(
       scale_init=nnx.with_partitioning(
         nnx.initializers.ones,
         (
-          None,
           "norm",
         ),
       ),
@@ -780,6 +781,7 @@ def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tup
   def __call__(
       self, hidden_states: jax.Array, encoder_hidden_states: jax.Array = None, rotary_emb: Optional[jax.Array] = None
   ) -> jax.Array:
+    #breakpoint()
     hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
     encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", "tensor"))
     dtype = hidden_states.dtype
@@ -799,10 +801,12 @@ def __call__(
     value_proj = _unflatten_heads(value_proj, self.heads)
     # output of _unflatten_heads Batch, heads, seq_len, head_dim
     query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
+    #query_proj = query_proj

     attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
-
+    attn_output = jax.lax.with_sharding_constraint(attn_output, PartitionSpec("data", "fsdp", "tensor"))
     attn_output = attn_output.astype(dtype=dtype)
+

     hidden_states = self.proj_attn(attn_output)
     return hidden_states
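Note: the `None` entries deleted above were the partition-spec slots for the leading num_blocks axis, which only existed while block weights were stacked by nnx.vmap/nnx.scan. A minimal sketch of the per-block layout this commit moves to, with invented dimensions (not code from this commit):

import jax
from flax import nnx

embed, inner_dim = 64, 64

# Before: kernels for all blocks were stacked into [num_blocks, embed, heads],
# so specs carried a leading None (replicated) entry for the block axis,
# e.g. (None, "embed", "heads").
# After: each block owns a plain [embed, heads]-shaped kernel, so the spec
# names exactly one mesh axis per kernel dimension.
layer = nnx.Linear(
    in_features=embed,
    out_features=inner_dim,
    kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads")),
    rngs=nnx.Rngs(0),
)
print(layer.kernel.value.shape)  # (64, 64), no leading block axis
# The ("embed", "heads") spec travels with the param as metadata, which
# nnx.get_partition_spec / nnx.get_named_sharding consume at sharding time.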

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 36 additions & 23 deletions
@@ -171,6 +171,17 @@ def __init__(
       dtype=dtype,
       param_dtype=weights_dtype,
       precision=precision,
+      # kernel_init=nnx.with_partitioning(
+      #   nnx.initializers.xavier_uniform(),
+      #   ("blockwise", None, None),
+      # ),
+      # bias_init=nnx.with_partitioning(
+      #   nnx.initializers.zeros,
+      #   (
+      #     "blockwise",
+      #     None,
+      #   ),
+      # ),
     )

   def __call__(self, x: jax.Array) -> jax.Array:
@@ -218,10 +229,18 @@ def __init__(
       kernel_init=nnx.with_partitioning(
         nnx.initializers.xavier_uniform(),
         (
+          # "blockwise",
           "mlp",
           "embed",
         ),
       ),
+      # bias_init=nnx.with_partitioning(
+      #   nnx.initializers.zeros,
+      #   (
+      #     "blockwise",
+      #     "mlp",
+      #   ),
+      # ),
     )

   def __call__(self, hidden_states: jax.Array) -> jax.Array:
@@ -389,10 +408,9 @@ def __init__(
     )

     # 3. Transformer blocks
-    @nnx.split_rngs(splits=num_layers)
-    @nnx.vmap(in_axes=0, out_axes=0)
-    def init_block(rngs):
-      return WanTransformerBlock(
+    blocks = []
+    for _ in range(num_layers):
+      block = WanTransformerBlock(
         rngs=rngs,
         dim=inner_dim,
         ffn_dim=ffn_dim,
@@ -408,10 +426,15 @@ def init_block(rngs):
         precision=precision,
         attention=attention,
       )
+      blocks.append(block)
+    self.blocks = blocks

-    self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
+    # 2. Use a predicate to create a "state-free" version.
+    # The lambda function `lambda _: False` simply tells nnx.state_if
+    # to filter out ALL state components (params, variables, etc.).
+    # self.block_template = nnx.state_if(lambda _: False, template_block_with_state)

-    self.blocks = init_block(rngs)
+    self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)

     self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False)
     self.proj_out = nnx.Linear(
@@ -426,7 +449,7 @@ def init_block(rngs):
     key = rngs.params()
     self.scale_shift_table = nnx.Param(
       jax.random.normal(key, (1, 2, inner_dim)) / inner_dim**0.5,
-      kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")),
+      kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, "embed")),
     )

   def __call__(
@@ -456,22 +479,12 @@ def __call__(

     if encoder_hidden_states_image is not None:
       raise NotImplementedError("img2vid is not yet implemented.")
-
-    def scan_fn(carry, block):
-      hidden_states, encoder_hidden_states, timestep_proj, rotary_emb = carry
-      hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
-      return (hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
-
-    initial_carry = (hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
-    rematted_block_forward = self.gradient_checkpoint.apply(scan_fn)
-    final_carry = nnx.scan(
-      rematted_block_forward,
-      length=self.num_layers,
-      in_axes=(nnx.Carry, 0),
-      out_axes=nnx.Carry,
-    )(initial_carry, self.blocks)
-
-    hidden_states = final_carry[0]
+
+    for block in self.blocks:
+      def block_forward(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb):
+        return block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+      rematted_block_forward = self.gradient_checkpoint.apply(block_forward)
+      hidden_states = rematted_block_forward(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)

     shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)

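Note: for reference, a self-contained sketch of the two iteration styles this commit trades between, with a toy `Block` standing in for `WanTransformerBlock` (names and sizes invented). Under nnx.scan, all layers share one traced step over weights stacked along a leading axis; the plain Python loop instead unrolls each block at trace time, which is what lets every block keep unstacked, individually partitioned parameters:

import jax.numpy as jnp
from flax import nnx

dim, num_layers = 16, 4

class Block(nnx.Module):
  def __init__(self, rngs: nnx.Rngs, dim: int):
    self.linear = nnx.Linear(dim, dim, rngs=rngs)

  def __call__(self, x):
    return x + self.linear(x)

# Removed style: one stacked Block pytree built with vmap, stepped by scan.
@nnx.split_rngs(splits=num_layers)
@nnx.vmap(in_axes=0, out_axes=0)
def init_block(rngs):
  return Block(rngs, dim)

stacked_blocks = init_block(nnx.Rngs(0))

def scan_fn(carry, block):
  return block(carry)

y_scan = nnx.scan(scan_fn, length=num_layers, in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)(
    jnp.ones((2, dim)), stacked_blocks
)

# Added style: independent blocks in a Python list, unrolled in a loop.
blocks = [Block(nnx.Rngs(i), dim) for i in range(num_layers)]
y_loop = jnp.ones((2, dim))
for block in blocks:
  y_loop = block(y_loop)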
src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 2 additions & 17 deletions
@@ -171,9 +171,7 @@ def load_wan_transformer(
   return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers)


-def load_base_wan_transformer(
-    pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
-):
+def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40):
   device = jax.local_devices(backend=device)[0]
   subfolder = "transformer"
   filename = "diffusion_pytorch_model.safetensors.index.json"
@@ -231,22 +229,9 @@ def load_base_wan_transformer(
       renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
       pt_tuple_key = tuple(renamed_pt_key.split("."))

-      if "blocks" in pt_tuple_key:
-        new_key = ("blocks",) + pt_tuple_key[2:]
-        block_index = int(pt_tuple_key[1])
-        pt_tuple_key = new_key
-      flax_key, flax_tensor = rename_key_and_reshape_tensor(
-          pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL
-      )
+      flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
       flax_key = rename_for_nnx(flax_key)
       flax_key = _tuple_str_to_int(flax_key)
-
-      if "blocks" in flax_key:
-        if flax_key in flax_state_dict:
-          new_tensor = flax_state_dict[flax_key]
-        else:
-          new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape)
-        flax_tensor = new_tensor.at[block_index].set(flax_tensor)
       flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
     validate_flax_state_dict(eval_shapes, flax_state_dict)
     flax_state_dict = unflatten_dict(flax_state_dict)
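Note: the deleted branches above existed to fold per-block PyTorch tensors into single stacked arrays for the scanned layout; with the loop layout each block's index simply stays in the parameter key (and `_tuple_str_to_int` later turns the index string into an int). A schematic sketch with invented keys and shapes:

import jax.numpy as jnp

num_layers = 3
# Invented per-block tensors standing in for the safetensors contents.
pt_tensors = {f"blocks.{i}.ffn.weight": jnp.full((4, 4), i, jnp.float32) for i in range(num_layers)}

# Old layout (deleted code): one stacked entry per parameter, shaped
# [num_layers, ...], filled block by block with .at[block_index].set(...).
stacked = {}
for key, tensor in pt_tensors.items():
  block_index = int(key.split(".")[1])
  flax_key = ("blocks", "ffn", "weight")
  buf = stacked.get(flax_key, jnp.zeros((num_layers,) + tensor.shape, tensor.dtype))
  stacked[flax_key] = buf.at[block_index].set(tensor)

# New layout (this commit): the block index stays in the key, one entry each.
per_block = {tuple(k.split(".")): v for k, v in pt_tensors.items()}

print(stacked[("blocks", "ffn", "weight")].shape)         # (3, 4, 4)
print(per_block[("blocks", "1", "ffn", "weight")].shape)  # (4, 4)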

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 4 additions & 4 deletions
@@ -79,8 +79,7 @@ def __init__(self, config):
     if config.train_text_encoder:
       raise ValueError("this script currently doesn't support training text_encoders")

-    #self.global_batch_size = self.config.per_device_batch_size * jax.device_count()
-    self.global_batch_size = config.global_batch_size if config.global_batch_size > 0 else config.per_device_batch_size * jax.device_count()
+    self.global_batch_size = config.per_device_batch_size * jax.device_count()

   def post_training_steps(self, pipeline, params, train_states, msg=""):
     pass
@@ -97,7 +96,8 @@ def calculate_tflops(self, pipeline):
     return 0

   def get_data_shardings(self, mesh):
-    data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding[0]))
+    p_spec = P(*self.config.data_sharding)
+    data_sharding = jax.sharding.NamedSharding(mesh, p_spec)
     data_sharding = {
       "latents" : data_sharding,
       "encoder_hidden_states" : data_sharding
@@ -143,7 +143,6 @@ def prepare_sample(features):
   def start_training(self):

     pipeline = self.load_checkpoint()
-    # del pipeline.vae

     # Generate a sample before training to compare against generated sample after training.
     #pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
@@ -178,6 +177,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
     state = jax.lax.with_sharding_constraint(state, state_spec)
     state_shardings = nnx.get_named_sharding(state, mesh)
     data_shardings = self.get_data_shardings(mesh)
+    #breakpoint()

     writer = max_utils.initialize_summary_writer(self.config)
     writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True)
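Note: the `get_data_shardings` change reads `config.data_sharding` as a flat list of mesh-axis names rather than a nested one. A minimal sketch of what the rewritten call builds, assuming a single-host mesh and the axis names used elsewhere in this commit:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Stand-in for config.data_sharding after this commit (assumed flat list).
data_sharding_cfg = ("data", "fsdp", "tensor")

# Assumed mesh: all local devices on "data", trivial "fsdp"/"tensor" axes.
devices = np.array(jax.devices()).reshape(-1, 1, 1)
mesh = Mesh(devices, axis_names=("data", "fsdp", "tensor"))

p_spec = P(*data_sharding_cfg)               # PartitionSpec('data', 'fsdp', 'tensor')
data_sharding = NamedSharding(mesh, p_spec)  # binds the spec to a concrete mesh

# The trainer hands this sharding to both "latents" and "encoder_hidden_states";
# here a dummy batch is placed with it.
batch = jax.device_put(np.zeros((len(jax.devices()), 8, 8), np.float32), data_sharding)
print(batch.sharding)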
