
Commit fb12602

wip - add dropout change sharding
1 parent 955bd86 commit fb12602

8 files changed: 121 additions & 75 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 2 additions & 1 deletion
@@ -56,8 +56,9 @@ jit_initializers: True
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
-attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
+attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
 flash_min_seq_length: 4096
+dropout: 0.1
 
 flash_block_sizes: {}
 # Use on v6e
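
For context, a minimal sketch of how a config value such as `dropout: 0.1` typically reaches a Flax NNX module; the `ToyBlock` class and the plain dict standing in for the parsed YAML are illustrative only, not part of this commit.

```python
import jax.numpy as jnp
from flax import nnx

# Stand-in for the values parsed from base_wan_14b.yml.
config = {"dropout": 0.1}

class ToyBlock(nnx.Module):
  """Toy module wiring a config-driven dropout rate into nnx.Dropout."""

  def __init__(self, features: int, dropout: float, rngs: nnx.Rngs):
    self.linear = nnx.Linear(features, features, rngs=rngs)
    self.drop_out = nnx.Dropout(dropout)

  def __call__(self, x, deterministic: bool = True, rngs: nnx.Rngs = None):
    # deterministic=True (evaluation/inference) turns dropout into a no-op.
    return self.drop_out(self.linear(x), deterministic=deterministic, rngs=rngs)

block = ToyBlock(16, dropout=config["dropout"], rngs=nnx.Rngs(0))
x = jnp.ones((2, 16))
y_train = block(x, deterministic=False, rngs=nnx.Rngs(dropout=1))  # dropout active
y_eval = block(x)                                                   # dropout disabled
```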

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 49 additions & 25 deletions
@@ -78,9 +78,18 @@ def make_tf_iterator(
   train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
   return train_iter
 
+
 # TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py
 def _make_tfrecord_iterator(
-    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description_fn, prepare_sample_fn, dataset_path, is_training: bool
+    config,
+    dataloading_host_index,
+    dataloading_host_count,
+    mesh,
+    global_batch_size,
+    feature_description_fn,
+    prepare_sample_fn,
+    dataset_path,
+    is_training: bool,
 ):
   # set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset.
   # pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord.
@@ -93,10 +102,10 @@ def _make_tfrecord_iterator(
   # Determine whether to use the "cached" dataset, which requires externally
   # provided parsing functions, or the default one with its internal parsing logic.
   make_cached_tfrecord_iterator = (
-    config.cache_latents_text_encoder_outputs
-    and is_dataset_dir_valid
-    and "load_tfrecord_cached" in config.get_keys()
-    and config.load_tfrecord_cached
+      config.cache_latents_text_encoder_outputs
+      and is_dataset_dir_valid
+      and "load_tfrecord_cached" in config.get_keys()
+      and config.load_tfrecord_cached
   )
 
   feature_description = {
@@ -121,42 +130,47 @@ def prepare_sample(features):
   if not is_training:
     num_eval_samples = 0
     for _ in ds:
-        num_eval_samples += 1
+      num_eval_samples += 1
 
     remainder = num_eval_samples % global_batch_size
     if remainder != 0:
-        num_to_pad = global_batch_size - remainder
-        # Create a dataset of padding samples from the beginning
-        padding_ds = ds.take(num_to_pad)
-        # Add the padding samples to the end
-        ds = ds.concatenate(padding_ds)
-        max_logging.log(f"Padded evaluation dataset with {num_to_pad} samples.")
+      num_to_pad = global_batch_size - remainder
+      # Create a dataset of padding samples from the beginning
+      padding_ds = ds.take(num_to_pad)
+      # Add the padding samples to the end
+      ds = ds.concatenate(padding_ds)
+      max_logging.log(f"Padded evaluation dataset with {num_to_pad} samples.")
 
   used_prepare_sample = prepare_sample_fn if make_cached_tfrecord_iterator else prepare_sample
   ds = (
-    ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
-    .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
-    .map(used_prepare_sample, num_parallel_calls=AUTOTUNE)
+      ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
+      .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
+      .map(used_prepare_sample, num_parallel_calls=AUTOTUNE)
   )
   if is_training:
     ds = (
-      ds.shuffle(global_batch_size * 10)
-      .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
-      .repeat(-1)
-      .prefetch(AUTOTUNE)
+        ds.shuffle(global_batch_size * 10)
+        .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
+        .repeat(-1)
+        .prefetch(AUTOTUNE)
     )
   # For Evaluation
   else:
-    ds = (
-      ds.batch(global_batch_size // dataloading_host_count, drop_remainder=False)
-      .prefetch(AUTOTUNE)
-    )
+    ds = ds.batch(global_batch_size // dataloading_host_count, drop_remainder=False).prefetch(AUTOTUNE)
 
   iter = multihost_dataloading.MultiHostDataLoadIterator(ds, mesh)
   return iter
 
+
 def make_tfrecord_iterator(
-    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, is_training
+    config,
+    dataloading_host_index,
+    dataloading_host_count,
+    mesh,
+    global_batch_size,
+    feature_description,
+    prepare_sample_fn,
+    is_training,
 ):
   """Iterator for TFRecord format. For Laion dataset,
   check out preparation script
@@ -165,4 +179,14 @@ def make_tfrecord_iterator(
   # Currently only support evaluation on tfrecord. To avoid influencing previous reference, judge whether is training dataset.
   # TODO: refactor to support evaluation on all dataset format.
   dataset_path = config.train_data_dir if is_training else config.eval_data_dir
-  return _make_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, dataset_path, is_training)
+  return _make_tfrecord_iterator(
+      config,
+      dataloading_host_index,
+      dataloading_host_count,
+      mesh,
+      global_batch_size,
+      feature_description,
+      prepare_sample_fn,
+      dataset_path,
+      is_training,
+  )
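
As a sanity check on the evaluation-padding logic above, here is a small self-contained tf.data sketch (toy dataset and batch size, not taken from the repo) showing how concatenating the first `global_batch_size - remainder` samples makes the eval set divide evenly into full batches:

```python
import tensorflow as tf

global_batch_size = 4
ds = tf.data.Dataset.range(10)  # 10 eval samples, not a multiple of 4

num_eval_samples = sum(1 for _ in ds)              # 10
remainder = num_eval_samples % global_batch_size   # 2
if remainder != 0:
  num_to_pad = global_batch_size - remainder       # 2
  padding_ds = ds.take(num_to_pad)                 # reuse the first samples as padding
  ds = ds.concatenate(padding_ds)                  # now 12 samples -> 3 full batches

batches = list(ds.batch(global_batch_size, drop_remainder=False).as_numpy_iterator())
assert all(len(b) == global_batch_size for b in batches)
```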

src/maxdiffusion/input_pipeline/input_pipeline_interface.py

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ def make_data_iterator(
         global_batch_size,
         feature_description,
         prepare_sample_fn,
-        is_training
+        is_training,
     )
   else:
     assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)"

src/maxdiffusion/models/attention_flax.py

Lines changed: 17 additions & 6 deletions
@@ -734,7 +734,7 @@ def __init__(
     # None axes corresponds to the stacked weights across all blocks
     # because of the use of nnx.vmap and nnx.scan.
     # Dims are [num_blocks, embed, heads]
-    kernel_axes = (None, "embed", "heads")
+    kernel_axes = ("embed", None, "heads")
     qkv_init_kernel = nnx.with_partitioning(nnx.initializers.lecun_normal(), kernel_axes)
 
     self.query = nnx.Linear(
@@ -748,8 +748,8 @@ def __init__(
         bias_init=nnx.with_partitioning(
            nnx.initializers.zeros,
            (
-               None,
                "embed",
+               "heads",
            ),
        ),
     )
@@ -765,8 +765,8 @@ def __init__(
        bias_init=nnx.with_partitioning(
            nnx.initializers.zeros,
            (
-               None,
                "embed",
+               "heads",
            ),
        ),
     )
@@ -782,8 +782,8 @@ def __init__(
        bias_init=nnx.with_partitioning(
            nnx.initializers.zeros,
            (
-               None,
                "embed",
+               "heads"
            ),
        ),
     )
@@ -792,12 +792,21 @@ def __init__(
        rngs=rngs,
        in_features=self.inner_dim,
        out_features=self.inner_dim,
-       kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "heads", "embed")),
+       kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads", None)),
        dtype=dtype,
        param_dtype=weights_dtype,
        precision=precision,
+       bias_init=nnx.with_partitioning(
+           nnx.initializers.zeros,
+           (
+               "embed",
+               None
+           ),
+       ),
     )
 
+    self.drop_out = nnx.Dropout(dropout)
+
     self.norm_q = None
     self.norm_k = None
     if qk_norm is not None:
@@ -847,7 +856,8 @@ def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tup
     return xq_out, xk_out
 
   def __call__(
-      self, hidden_states: jax.Array, encoder_hidden_states: jax.Array = None, rotary_emb: Optional[jax.Array] = None
+      self, hidden_states: jax.Array, encoder_hidden_states: jax.Array = None, rotary_emb: Optional[jax.Array] = None,
+      deterministic: bool = True, rngs: nnx.Rngs = None,
   ) -> jax.Array:
     hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
     encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", "tensor"))
@@ -877,6 +887,7 @@ def __call__(
     attn_output = attn_output.astype(dtype=dtype)
     attn_output = checkpoint_name(attn_output, "attn_output")
     hidden_states = self.proj_attn(attn_output)
+    hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
     return hidden_states
 
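
To see what the reordered partitioning tuples above actually do, here is a minimal NNX sketch (toy feature sizes, logical axis names only) of how `nnx.with_partitioning` attaches sharding metadata to a layer's parameters; mapping those logical names ("embed", "heads", ...) onto physical mesh axes happens separately via the mesh and logical-axis rules.

```python
from flax import nnx

# Hypothetical 2D kernel with logical axes [embed, heads]; in the real model an
# extra leading num_blocks dimension is added by the nnx.vmap/nnx.scan stacking.
layer = nnx.Linear(
    in_features=8,
    out_features=16,
    kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads")),
    bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("heads",)),
    rngs=nnx.Rngs(0),
)

# The annotations live on the params and can be turned into PartitionSpecs.
state = nnx.state(layer)
print(nnx.get_partition_spec(state))  # kernel: ('embed', 'heads'), bias: ('heads',)
```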

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 29 additions & 16 deletions
@@ -175,12 +175,11 @@ def __init__(
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
             (
+                "embed",
                 None,
                 "mlp",
-                "embed",
             ),
         ),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")),
     )
 
   def __call__(self, x: jax.Array) -> jax.Array:
@@ -217,6 +216,8 @@ def __init__(
     else:
       raise NotImplementedError(f"{activation_fn} is not implemented.")
 
+    self.drop_out = nnx.Dropout(dropout)
+
     self.proj_out = nnx.Linear(
         rngs=rngs,
         in_features=inner_dim,
@@ -228,15 +229,16 @@
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
             (
-                None,
                 "embed",
                 "mlp",
+                None,
             ),
         ),
     )
 
-  def __call__(self, hidden_states: jax.Array) -> jax.Array:
+  def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array:
     hidden_states = self.act_fn(hidden_states)
+    hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
     return self.proj_out(hidden_states)
 
 
@@ -260,6 +262,7 @@ def __init__(
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
       attention: str = "dot_product",
+      dropout: float = 0.0,
   ):
 
     # 1. Self-attention
@@ -278,6 +281,7 @@ def __init__(
         weights_dtype=weights_dtype,
        precision=precision,
        attention_kernel=attention,
+       dropout=dropout
     )
 
     # 1. Cross-attention
@@ -295,6 +299,7 @@ def __init__(
        weights_dtype=weights_dtype,
        precision=precision,
        attention_kernel=attention,
+       dropout=dropout
     )
     assert cross_attn_norm is True
     self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -308,13 +313,16 @@ def __init__(
        dtype=dtype,
        weights_dtype=weights_dtype,
        precision=precision,
+       dropout=dropout
     )
     self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)
 
     key = rngs.params()
-    self.adaln_scale_shift_table = nnx.Param(jax.random.normal(key, (1, 6, dim)) / dim**0.5)
+    self.adaln_scale_shift_table = nnx.Param(
+        jax.random.normal(key, (1, 6, dim)) / dim**0.5,
+        sharding=("embed",))
 
-  def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, temb: jax.Array, rotary_emb: jax.Array):
+  def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, temb: jax.Array, rotary_emb: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None,):
     shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
         (self.adaln_scale_shift_table + temb), 6, axis=1
     )
@@ -324,18 +332,18 @@ def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, t
     # 1. Self-attention
     norm_hidden_states = (self.norm1(hidden_states) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype)
     attn_output = self.attn1(
-        hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, rotary_emb=rotary_emb
+        hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, rotary_emb=rotary_emb, deterministic=deterministic, rngs=rngs
     )
     hidden_states = (hidden_states + attn_output * gate_msa).astype(hidden_states.dtype)
 
     # 2. Cross-attention
     norm_hidden_states = self.norm2(hidden_states)
-    attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+    attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs)
     hidden_states = hidden_states + attn_output
 
     # 3. Feed-forward
     norm_hidden_states = (self.norm3(hidden_states) * (1 + c_scale_msa) + c_shift_msa).astype(hidden_states.dtype)
-    ff_output = self.ffn(norm_hidden_states)
+    ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
     hidden_states = (hidden_states + ff_output * c_gate_msa).astype(hidden_states.dtype)
     return hidden_states
 
@@ -356,6 +364,7 @@ def __init__(
      freq_dim: int = 256,
      ffn_dim: int = 13824,
      num_layers: int = 40,
+     dropout: float = 0.0,
      cross_attn_norm: bool = True,
      qk_norm: Optional[str] = "rms_norm_across_heads",
      eps: float = 1e-6,
@@ -424,6 +433,7 @@ def init_block(rngs):
          weights_dtype=weights_dtype,
          precision=precision,
          attention=attention,
+         dropout=dropout,
       )
 
     self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
@@ -454,6 +464,8 @@
      encoder_hidden_states_image: Optional[jax.Array] = None,
      return_dict: bool = True,
      attention_kwargs: Optional[Dict[str, Any]] = None,
+     deterministic: bool = True,
+     rngs: nnx.Rngs = None,
   ) -> Union[jax.Array, Dict[str, jax.Array]]:
     batch_size, _, num_frames, height, width = hidden_states.shape
     p_t, p_h, p_w = self.config.patch_size
@@ -476,20 +488,21 @@
      raise NotImplementedError("img2vid is not yet implemented.")
 
     def scan_fn(carry, block):
-      hidden_states, encoder_hidden_states, timestep_proj, rotary_emb = carry
-      hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
-      return (hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+      hidden_states_carry, rngs_carry = carry
+      hidden_states = block(hidden_states_carry, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs_carry)
+      new_carry = (hidden_states, rngs_carry)
+      return new_carry, None
 
-    initial_carry = (hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
     rematted_block_forward = self.gradient_checkpoint.apply(scan_fn)
-    final_carry = nnx.scan(
+    initial_carry = (hidden_states, rngs)
+    final_carry, _ = nnx.scan(
        rematted_block_forward,
        length=self.num_layers,
        in_axes=(nnx.Carry, 0),
-       out_axes=nnx.Carry,
+       out_axes=(nnx.Carry, 0),
     )(initial_carry, self.blocks)
 
-    hidden_states = final_carry[0]
+    hidden_states, _ = final_carry
 
     shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)
 
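
For the scan change above, a minimal runnable sketch of the same carry pattern with `nnx.scan` over stacked blocks; the blocks here are plain Linear layers, and the carry holds a raw PRNG key rather than the `nnx.Rngs` object the commit threads through, purely to keep the example small.

```python
import jax
import jax.numpy as jnp
from flax import nnx

num_layers, dim = 4, 8

@nnx.split_rngs(splits=num_layers)
@nnx.vmap(in_axes=(0,), out_axes=0)
def create_blocks(rngs: nnx.Rngs):
  # Parameters come out stacked along a leading [num_layers, ...] axis.
  return nnx.Linear(dim, dim, rngs=rngs)

blocks = create_blocks(nnx.Rngs(0))

def scan_fn(carry, block):
  hidden_states, rng_key = carry
  rng_key, step_key = jax.random.split(rng_key)  # e.g. a fresh dropout key per layer
  hidden_states = block(hidden_states)
  return (hidden_states, rng_key), None          # (new carry, per-step output)

x = jnp.ones((2, dim))
(final_hidden, _), _ = nnx.scan(
    scan_fn,
    length=num_layers,
    in_axes=(nnx.Carry, 0),
    out_axes=(nnx.Carry, 0),
)((x, jax.random.key(0)), blocks)
```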

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 1 addition & 0 deletions
@@ -82,6 +82,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
   wan_config["flash_block_sizes"] = get_flash_block_sizes(config)
   wan_config["remat_policy"] = config.remat_policy
   wan_config["flash_min_seq_length"] = config.flash_min_seq_length
+  wan_config["dropout"] = config.dropout
 
   # 2. eval_shape - will not use flops or create weights on device
   # thus not using HBM memory.
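
A short illustrative sketch (hypothetical `ToyTransformer`, not the real WAN model) of why this plumbing works: every key added to `wan_config` becomes a constructor keyword argument, so the new `"dropout"` entry relies on the matching `dropout` parameter added to the transformer's `__init__` above.

```python
from flax import nnx

class ToyTransformer(nnx.Module):
  """Hypothetical stand-in for the WAN transformer constructor."""

  def __init__(self, *, num_layers: int, dropout: float = 0.0, rngs: nnx.Rngs):
    self.num_layers = num_layers
    self.drop_out = nnx.Dropout(dropout)

wan_config = {"num_layers": 2, "dropout": 0.1}
model = ToyTransformer(rngs=nnx.Rngs(0), **wan_config)
```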
