Skip to content

Commit c39f1b8

Browse files
committed
remove args.
1 parent 57ead0b commit c39f1b8

1 file changed

Lines changed: 97 additions & 106 deletions

File tree

src/diffusers/models/transformers/transformer_ltx2.py

Lines changed: 97 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -459,49 +459,43 @@ def forward(
459459
audio_encoder_attention_mask: Optional[torch.Tensor] = None,
460460
a2v_cross_attention_mask: Optional[torch.Tensor] = None,
461461
v2a_cross_attention_mask: Optional[torch.Tensor] = None,
462-
use_video_self_attn: bool = True,
463-
use_audio_self_attn: bool = True,
464-
use_a2v_cross_attn: bool = True,
465-
use_v2a_cross_attn: bool = True,
466462
) -> torch.Tensor:
467463
batch_size = hidden_states.size(0)
468464

469465
# 1. Video and Audio Self-Attention
470-
if use_video_self_attn:
471-
norm_hidden_states = self.norm1(hidden_states)
466+
norm_hidden_states = self.norm1(hidden_states)
472467

473-
num_ada_params = self.scale_shift_table.shape[0]
474-
ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape(
475-
batch_size, temb.size(1), num_ada_params, -1
476-
)
477-
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
478-
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
468+
num_ada_params = self.scale_shift_table.shape[0]
469+
ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape(
470+
batch_size, temb.size(1), num_ada_params, -1
471+
)
472+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
473+
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
479474

480-
attn_hidden_states = self.attn1(
481-
hidden_states=norm_hidden_states,
482-
encoder_hidden_states=None,
483-
query_rotary_emb=video_rotary_emb,
484-
)
485-
hidden_states = hidden_states + attn_hidden_states * gate_msa
475+
attn_hidden_states = self.attn1(
476+
hidden_states=norm_hidden_states,
477+
encoder_hidden_states=None,
478+
query_rotary_emb=video_rotary_emb,
479+
)
480+
hidden_states = hidden_states + attn_hidden_states * gate_msa
486481

487-
if use_audio_self_attn:
488-
norm_audio_hidden_states = self.audio_norm1(audio_hidden_states)
482+
norm_audio_hidden_states = self.audio_norm1(audio_hidden_states)
489483

490-
num_audio_ada_params = self.audio_scale_shift_table.shape[0]
491-
audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape(
492-
batch_size, temb_audio.size(1), num_audio_ada_params, -1
493-
)
494-
audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = (
495-
audio_ada_values.unbind(dim=2)
496-
)
497-
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa
484+
num_audio_ada_params = self.audio_scale_shift_table.shape[0]
485+
audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape(
486+
batch_size, temb_audio.size(1), num_audio_ada_params, -1
487+
)
488+
audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = (
489+
audio_ada_values.unbind(dim=2)
490+
)
491+
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa
498492

499-
attn_audio_hidden_states = self.audio_attn1(
500-
hidden_states=norm_audio_hidden_states,
501-
encoder_hidden_states=None,
502-
query_rotary_emb=audio_rotary_emb,
503-
)
504-
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa
493+
attn_audio_hidden_states = self.audio_attn1(
494+
hidden_states=norm_audio_hidden_states,
495+
encoder_hidden_states=None,
496+
query_rotary_emb=audio_rotary_emb,
497+
)
498+
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa
505499

506500
# 2. Video and Audio Cross-Attention with the text embeddings
507501
norm_hidden_states = self.norm2(hidden_states)
@@ -523,80 +517,77 @@ def forward(
523517
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states
524518

525519
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
526-
if use_a2v_cross_attn or use_v2a_cross_attn:
527-
norm_hidden_states = self.audio_to_video_norm(hidden_states)
528-
norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states)
529-
530-
# Combine global and per-layer cross attention modulation parameters
531-
# Video
532-
video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :]
533-
video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :]
534-
535-
video_ca_scale_shift_table = (
536-
video_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_scale_shift.dtype)
537-
+ temb_ca_scale_shift.reshape(batch_size, temb_ca_scale_shift.shape[1], 4, -1)
538-
).unbind(dim=2)
539-
video_ca_gate = (
540-
video_per_layer_ca_gate[:, :, ...].to(temb_ca_gate.dtype)
541-
+ temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1)
542-
).unbind(dim=2)
543-
544-
video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_scale_shift_table
545-
a2v_gate = video_ca_gate[0].squeeze(2)
546-
547-
# Audio
548-
audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :]
549-
audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :]
550-
551-
audio_ca_scale_shift_table = (
552-
audio_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_audio_scale_shift.dtype)
553-
+ temb_ca_audio_scale_shift.reshape(batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1)
554-
).unbind(dim=2)
555-
audio_ca_gate = (
556-
audio_per_layer_ca_gate[:, :, ...].to(temb_ca_audio_gate.dtype)
557-
+ temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1)
558-
).unbind(dim=2)
559-
560-
audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_scale_shift_table
561-
v2a_gate = audio_ca_gate[0].squeeze(2)
562-
563-
if use_a2v_cross_attn:
564-
# Audio-to-Video Cross Attention: Q: Video; K,V: Audio
565-
mod_norm_hidden_states = norm_hidden_states * (
566-
1 + video_a2v_ca_scale.squeeze(2)
567-
) + video_a2v_ca_shift.squeeze(2)
568-
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
569-
1 + audio_a2v_ca_scale.squeeze(2)
570-
) + audio_a2v_ca_shift.squeeze(2)
571-
572-
a2v_attn_hidden_states = self.audio_to_video_attn(
573-
mod_norm_hidden_states,
574-
encoder_hidden_states=mod_norm_audio_hidden_states,
575-
query_rotary_emb=ca_video_rotary_emb,
576-
key_rotary_emb=ca_audio_rotary_emb,
577-
attention_mask=a2v_cross_attention_mask,
578-
)
520+
norm_hidden_states = self.audio_to_video_norm(hidden_states)
521+
norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states)
522+
523+
# Combine global and per-layer cross attention modulation parameters
524+
# Video
525+
video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :]
526+
video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :]
527+
528+
video_ca_scale_shift_table = (
529+
video_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_scale_shift.dtype)
530+
+ temb_ca_scale_shift.reshape(batch_size, temb_ca_scale_shift.shape[1], 4, -1)
531+
).unbind(dim=2)
532+
video_ca_gate = (
533+
video_per_layer_ca_gate[:, :, ...].to(temb_ca_gate.dtype)
534+
+ temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1)
535+
).unbind(dim=2)
536+
537+
video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_scale_shift_table
538+
a2v_gate = video_ca_gate[0].squeeze(2)
539+
540+
# Audio
541+
audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :]
542+
audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :]
543+
544+
audio_ca_scale_shift_table = (
545+
audio_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_audio_scale_shift.dtype)
546+
+ temb_ca_audio_scale_shift.reshape(batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1)
547+
).unbind(dim=2)
548+
audio_ca_gate = (
549+
audio_per_layer_ca_gate[:, :, ...].to(temb_ca_audio_gate.dtype)
550+
+ temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1)
551+
).unbind(dim=2)
552+
553+
audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_scale_shift_table
554+
v2a_gate = audio_ca_gate[0].squeeze(2)
555+
556+
# Audio-to-Video Cross Attention: Q: Video; K,V: Audio
557+
mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale.squeeze(2)) + video_a2v_ca_shift.squeeze(
558+
2
559+
)
560+
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
561+
1 + audio_a2v_ca_scale.squeeze(2)
562+
) + audio_a2v_ca_shift.squeeze(2)
563+
564+
a2v_attn_hidden_states = self.audio_to_video_attn(
565+
mod_norm_hidden_states,
566+
encoder_hidden_states=mod_norm_audio_hidden_states,
567+
query_rotary_emb=ca_video_rotary_emb,
568+
key_rotary_emb=ca_audio_rotary_emb,
569+
attention_mask=a2v_cross_attention_mask,
570+
)
579571

580-
hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
581-
582-
if use_v2a_cross_attn:
583-
# Video-to-Audio Cross Attention: Q: Audio; K,V: Video
584-
mod_norm_hidden_states = norm_hidden_states * (
585-
1 + video_v2a_ca_scale.squeeze(2)
586-
) + video_v2a_ca_shift.squeeze(2)
587-
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
588-
1 + audio_v2a_ca_scale.squeeze(2)
589-
) + audio_v2a_ca_shift.squeeze(2)
590-
591-
v2a_attn_hidden_states = self.video_to_audio_attn(
592-
mod_norm_audio_hidden_states,
593-
encoder_hidden_states=mod_norm_hidden_states,
594-
query_rotary_emb=ca_audio_rotary_emb,
595-
key_rotary_emb=ca_video_rotary_emb,
596-
attention_mask=v2a_cross_attention_mask,
597-
)
572+
hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
573+
574+
# Video-to-Audio Cross Attention: Q: Audio; K,V: Video
575+
mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale.squeeze(2)) + video_v2a_ca_shift.squeeze(
576+
2
577+
)
578+
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
579+
1 + audio_v2a_ca_scale.squeeze(2)
580+
) + audio_v2a_ca_shift.squeeze(2)
581+
582+
v2a_attn_hidden_states = self.video_to_audio_attn(
583+
mod_norm_audio_hidden_states,
584+
encoder_hidden_states=mod_norm_hidden_states,
585+
query_rotary_emb=ca_audio_rotary_emb,
586+
key_rotary_emb=ca_video_rotary_emb,
587+
attention_mask=v2a_cross_attention_mask,
588+
)
598589

599-
audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states
590+
audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states
600591

601592
# 4. Feedforward
602593
norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp) + shift_mlp

0 commit comments

Comments (0)