Skip to content

Commit 72fdd17

Browse files
committed
Enable gated attention only for the connectors: a2v and v2a cross-attention
1 parent 1fbcc16 commit 72fdd17

2 files changed

Lines changed: 18 additions & 15 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def __init__(
145145
rope_type=rope_type,
146146
flash_block_sizes=flash_block_sizes,
147147
flash_min_seq_length=flash_min_seq_length,
148-
gated_attn=gated_attn,
148+
gated_attn=False,
149149
)
150150

151151
self.audio_norm1 = nnx.RMSNorm(
@@ -172,7 +172,7 @@ def __init__(
172172
rope_type=rope_type,
173173
flash_block_sizes=flash_block_sizes,
174174
flash_min_seq_length=flash_min_seq_length,
175-
gated_attn=gated_attn,
175+
gated_attn=False,
176176
)
177177

178178
# 2. Prompt Cross-Attention
@@ -200,7 +200,7 @@ def __init__(
200200
attention_kernel=self.attention_kernel,
201201
rope_type=rope_type,
202202
flash_block_sizes=flash_block_sizes,
203-
gated_attn=gated_attn,
203+
gated_attn=False,
204204
)
205205

206206
self.audio_norm2 = nnx.RMSNorm(
@@ -228,7 +228,7 @@ def __init__(
228228
rope_type=rope_type,
229229
flash_block_sizes=flash_block_sizes,
230230
flash_min_seq_length=flash_min_seq_length,
231-
gated_attn=gated_attn,
231+
gated_attn=False,
232232
)
233233

234234
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
@@ -257,7 +257,7 @@ def __init__(
257257
rope_type=rope_type,
258258
flash_block_sizes=flash_block_sizes,
259259
flash_min_seq_length=0,
260-
gated_attn=gated_attn,
260+
gated_attn=self.cross_attn_mod,
261261
)
262262

263263
self.video_to_audio_norm = nnx.RMSNorm(
@@ -285,7 +285,7 @@ def __init__(
285285
rope_type=rope_type,
286286
flash_block_sizes=flash_block_sizes,
287287
flash_min_seq_length=flash_min_seq_length,
288-
gated_attn=gated_attn,
288+
gated_attn=self.cross_attn_mod,
289289
)
290290

291291
# 4. Feed Forward
@@ -1145,14 +1145,17 @@ def __call__(
11451145
)
11461146
audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.reshape(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1])
11471147

1148-
# 4. Prepare prompt embeddings
11491148
if self.use_prompt_embeddings:
1150-
encoder_hidden_states = self.caption_projection(encoder_hidden_states, timestep)
1151-
encoder_hidden_states = encoder_hidden_states.reshape(batch_size, -1, hidden_states.shape[-1])
1149+
if self.cross_attn_mod:
1150+
encoder_hidden_states = self.caption_projection(encoder_hidden_states, timestep)
1151+
audio_encoder_hidden_states = self.audio_caption_projection(
1152+
audio_encoder_hidden_states, audio_timestep if audio_timestep is not None else timestep
1153+
)
1154+
else:
1155+
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
1156+
audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
11521157

1153-
audio_encoder_hidden_states = self.audio_caption_projection(
1154-
audio_encoder_hidden_states, audio_timestep if audio_timestep is not None else timestep
1155-
)
1158+
encoder_hidden_states = encoder_hidden_states.reshape(batch_size, -1, hidden_states.shape[-1])
11561159
audio_encoder_hidden_states = audio_encoder_hidden_states.reshape(batch_size, -1, audio_hidden_states.shape[-1])
11571160

11581161
# Construct perturbation_mask_per_layer for STG

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,9 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
132132
"caption_channels": 3840,
133133
"audio_caption_channels": 2048,
134134
"use_prompt_embeddings": False,
135-
"gated_attn": True,
135+
"gated_attn": False,
136136
"cross_attn_mod": True,
137-
"audio_gated_attn": True,
137+
"audio_gated_attn": False,
138138
"audio_cross_attn_mod": True,
139139
}
140140
else:
@@ -159,7 +159,7 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
159159
ltx2_config["names_which_can_be_offloaded"] = config.names_which_can_be_offloaded
160160

161161
if getattr(config, "model_name", "") == "ltx2.3":
162-
ltx2_config["gated_attn"] = True
162+
ltx2_config["gated_attn"] = False
163163
ltx2_config["cross_attn_mod"] = True
164164
ltx2_config["perturbed_attn"] = True
165165
ltx2_config["use_prompt_embeddings"] = False

0 commit comments

Comments (0)