Skip to content

Commit 66c85fe

Browse files
committed
add shardings to projection and patch embedding.
1 parent deb686d commit 66c85fe

3 files changed

Lines changed: 20 additions & 3 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ logical_axis_rules: [
136136
['norm', 'tensor'],
137137
['conv_batch', ['data','fsdp']],
138138
['out_channels', 'tensor'],
139-
['conv_in', 'fsdp'],
139+
['conv_out', 'fsdp'],
140140
]
141141
data_sharding: [['data', 'fsdp', 'tensor']]
142142

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,13 @@ def __init__(
171171
dtype=dtype,
172172
param_dtype=weights_dtype,
173173
precision=precision,
174+
kernel_init=nnx.with_partitioning(
175+
nnx.initializers.xavier_uniform(),
176+
(
177+
"mlp",
178+
"embed",
179+
),
180+
),
174181
)
175182

176183
def __call__(self, x: jax.Array) -> jax.Array:
@@ -374,6 +381,16 @@ def __init__(
374381
dtype=dtype,
375382
param_dtype=weights_dtype,
376383
precision=precision,
384+
kernel_init=nnx.with_partitioning(
385+
nnx.initializers.xavier_uniform(),
386+
(
387+
None,
388+
None,
389+
None,
390+
None,
391+
"conv_out"
392+
),
393+
),
377394
)
378395

379396
# 2. Condition embeddings

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def __init__(self, config):
8080
raise ValueError("this script currently doesn't support training text_encoders")
8181

8282
#self.global_batch_size = self.config.per_device_batch_size * jax.device_count()
83-
self.global_batch_size = config.global_batch_size if config.global_batch_size > 0 else config.per_device_batch_size * jax.device_count()
83+
self.global_batch_size = config.per_device_batch_size * jax.device_count()
8484

8585
def post_training_steps(self, pipeline, params, train_states, msg=""):
8686
pass
@@ -97,7 +97,7 @@ def calculate_tflops(self, pipeline):
9797
return 0
9898

9999
def get_data_shardings(self, mesh):
100-
data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding[0]))
100+
data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding))
101101
data_sharding = {
102102
"latents" : data_sharding,
103103
"encoder_hidden_states" : data_sharding

0 commit comments

Comments (0)