@@ -145,7 +145,7 @@ def __init__(
             rope_type=rope_type,
             flash_block_sizes=flash_block_sizes,
             flash_min_seq_length=flash_min_seq_length,
-            gated_attn=gated_attn,
+            gated_attn=False,
         )

         self.audio_norm1 = nnx.RMSNorm(
@@ -172,7 +172,7 @@ def __init__(
             rope_type=rope_type,
             flash_block_sizes=flash_block_sizes,
             flash_min_seq_length=flash_min_seq_length,
-            gated_attn=gated_attn,
+            gated_attn=False,
         )

         # 2. Prompt Cross-Attention
@@ -200,7 +200,7 @@ def __init__(
             attention_kernel=self.attention_kernel,
             rope_type=rope_type,
             flash_block_sizes=flash_block_sizes,
-            gated_attn=gated_attn,
+            gated_attn=False,
         )

         self.audio_norm2 = nnx.RMSNorm(
@@ -228,7 +228,7 @@ def __init__(
             rope_type=rope_type,
             flash_block_sizes=flash_block_sizes,
             flash_min_seq_length=flash_min_seq_length,
-            gated_attn=gated_attn,
+            gated_attn=False,
         )

         # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
@@ -257,7 +257,7 @@ def __init__(
             rope_type=rope_type,
             flash_block_sizes=flash_block_sizes,
             flash_min_seq_length=0,
-            gated_attn=gated_attn,
+            gated_attn=self.cross_attn_mod,
         )

         self.video_to_audio_norm = nnx.RMSNorm(
@@ -285,7 +285,7 @@ def __init__(
             rope_type=rope_type,
             flash_block_sizes=flash_block_sizes,
             flash_min_seq_length=flash_min_seq_length,
-            gated_attn=gated_attn,
+            gated_attn=self.cross_attn_mod,
         )

         # 4. Feed Forward
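Taken together, the `__init__` hunks above hard-wire `gated_attn` off for the self-attention and prompt cross-attention blocks and tie it to `self.cross_attn_mod` for the a2v/v2a cross-attention blocks only. As a rough illustration of what that flag controls (an assumption, since the `Attention` class itself is not part of this diff), `gated_attn` typically multiplies the attention output by a learned sigmoid gate:

```python
# Hedged sketch of a gated_attn flag; the real Attention implementation is
# not shown in this diff, so names and shapes here are illustrative only.
import jax
import jax.numpy as jnp

def apply_output_gate(attn_out: jax.Array, gate_logits: jax.Array, gated_attn: bool) -> jax.Array:
    """attn_out, gate_logits: (batch, seq, dim).

    gate_logits would come from a learned projection of the block input
    (assumed); with gated_attn=False the attention output passes through
    unchanged, which is what this change enforces for the self- and
    prompt-attention blocks.
    """
    if not gated_attn:
        return attn_out
    return attn_out * jax.nn.sigmoid(gate_logits)  # per-channel sigmoid gate
```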
@@ -1145,14 +1145,17 @@ def __call__(
         )
         audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.reshape(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1])

-        # 4. Prepare prompt embeddings
         if self.use_prompt_embeddings:
-            encoder_hidden_states = self.caption_projection(encoder_hidden_states, timestep)
-            encoder_hidden_states = encoder_hidden_states.reshape(batch_size, -1, hidden_states.shape[-1])
+            if self.cross_attn_mod:
+                encoder_hidden_states = self.caption_projection(encoder_hidden_states, timestep)
+                audio_encoder_hidden_states = self.audio_caption_projection(
+                    audio_encoder_hidden_states, audio_timestep if audio_timestep is not None else timestep
+                )
+            else:
+                encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+                audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)

-            audio_encoder_hidden_states = self.audio_caption_projection(
-                audio_encoder_hidden_states, audio_timestep if audio_timestep is not None else timestep
-            )
+            encoder_hidden_states = encoder_hidden_states.reshape(batch_size, -1, hidden_states.shape[-1])
             audio_encoder_hidden_states = audio_encoder_hidden_states.reshape(batch_size, -1, audio_hidden_states.shape[-1])

         # Construct perturbation_mask_per_layer for STG
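The `__call__` hunk makes the prompt projections conditional: when `cross_attn_mod` is set, both caption projections receive a timestep embedding (the audio branch falling back to the video `timestep` when `audio_timestep` is `None`); otherwise they are called on the embeddings alone. A minimal sketch of the resulting control flow, with the projection modules passed in as plain callables (a stand-in; the real `caption_projection` modules are defined elsewhere in the file):

```python
# Hedged sketch of the new prompt-embedding path; only the branching
# mirrors the diff, the projection callables themselves are stand-ins.
from typing import Callable, Optional
import jax

def project_prompts(
    caption_projection: Callable,
    audio_caption_projection: Callable,
    encoder_hidden_states: jax.Array,
    audio_encoder_hidden_states: jax.Array,
    timestep: jax.Array,
    audio_timestep: Optional[jax.Array],
    cross_attn_mod: bool,
):
    if cross_attn_mod:
        # Timestep-conditioned projections; the audio branch reuses the
        # video timestep when no separate audio timestep is provided.
        enc = caption_projection(encoder_hidden_states, timestep)
        aud = audio_caption_projection(
            audio_encoder_hidden_states,
            audio_timestep if audio_timestep is not None else timestep,
        )
    else:
        # Unconditioned projections.
        enc = caption_projection(encoder_hidden_states)
        aud = audio_caption_projection(audio_encoder_hidden_states)
    return enc, aud
```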