@@ -459,49 +459,43 @@ def forward(
459459 audio_encoder_attention_mask : Optional [torch .Tensor ] = None ,
460460 a2v_cross_attention_mask : Optional [torch .Tensor ] = None ,
461461 v2a_cross_attention_mask : Optional [torch .Tensor ] = None ,
462- use_video_self_attn : bool = True ,
463- use_audio_self_attn : bool = True ,
464- use_a2v_cross_attn : bool = True ,
465- use_v2a_cross_attn : bool = True ,
466462 ) -> torch .Tensor :
467463 batch_size = hidden_states .size (0 )
468464
469465 # 1. Video and Audio Self-Attention
470- if use_video_self_attn :
471- norm_hidden_states = self .norm1 (hidden_states )
466+ norm_hidden_states = self .norm1 (hidden_states )
472467
473- num_ada_params = self .scale_shift_table .shape [0 ]
474- ada_values = self .scale_shift_table [None , None ].to (temb .device ) + temb .reshape (
475- batch_size , temb .size (1 ), num_ada_params , - 1
476- )
477- shift_msa , scale_msa , gate_msa , shift_mlp , scale_mlp , gate_mlp = ada_values .unbind (dim = 2 )
478- norm_hidden_states = norm_hidden_states * (1 + scale_msa ) + shift_msa
468+ num_ada_params = self .scale_shift_table .shape [0 ]
469+ ada_values = self .scale_shift_table [None , None ].to (temb .device ) + temb .reshape (
470+ batch_size , temb .size (1 ), num_ada_params , - 1
471+ )
472+ shift_msa , scale_msa , gate_msa , shift_mlp , scale_mlp , gate_mlp = ada_values .unbind (dim = 2 )
473+ norm_hidden_states = norm_hidden_states * (1 + scale_msa ) + shift_msa
479474
480- attn_hidden_states = self .attn1 (
481- hidden_states = norm_hidden_states ,
482- encoder_hidden_states = None ,
483- query_rotary_emb = video_rotary_emb ,
484- )
485- hidden_states = hidden_states + attn_hidden_states * gate_msa
475+ attn_hidden_states = self .attn1 (
476+ hidden_states = norm_hidden_states ,
477+ encoder_hidden_states = None ,
478+ query_rotary_emb = video_rotary_emb ,
479+ )
480+ hidden_states = hidden_states + attn_hidden_states * gate_msa
486481
487- if use_audio_self_attn :
488- norm_audio_hidden_states = self .audio_norm1 (audio_hidden_states )
482+ norm_audio_hidden_states = self .audio_norm1 (audio_hidden_states )
489483
490- num_audio_ada_params = self .audio_scale_shift_table .shape [0 ]
491- audio_ada_values = self .audio_scale_shift_table [None , None ].to (temb_audio .device ) + temb_audio .reshape (
492- batch_size , temb_audio .size (1 ), num_audio_ada_params , - 1
493- )
494- audio_shift_msa , audio_scale_msa , audio_gate_msa , audio_shift_mlp , audio_scale_mlp , audio_gate_mlp = (
495- audio_ada_values .unbind (dim = 2 )
496- )
497- norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa ) + audio_shift_msa
484+ num_audio_ada_params = self .audio_scale_shift_table .shape [0 ]
485+ audio_ada_values = self .audio_scale_shift_table [None , None ].to (temb_audio .device ) + temb_audio .reshape (
486+ batch_size , temb_audio .size (1 ), num_audio_ada_params , - 1
487+ )
488+ audio_shift_msa , audio_scale_msa , audio_gate_msa , audio_shift_mlp , audio_scale_mlp , audio_gate_mlp = (
489+ audio_ada_values .unbind (dim = 2 )
490+ )
491+ norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa ) + audio_shift_msa
498492
499- attn_audio_hidden_states = self .audio_attn1 (
500- hidden_states = norm_audio_hidden_states ,
501- encoder_hidden_states = None ,
502- query_rotary_emb = audio_rotary_emb ,
503- )
504- audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa
493+ attn_audio_hidden_states = self .audio_attn1 (
494+ hidden_states = norm_audio_hidden_states ,
495+ encoder_hidden_states = None ,
496+ query_rotary_emb = audio_rotary_emb ,
497+ )
498+ audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa
505499
506500 # 2. Video and Audio Cross-Attention with the text embeddings
507501 norm_hidden_states = self .norm2 (hidden_states )
@@ -523,80 +517,77 @@ def forward(
523517 audio_hidden_states = audio_hidden_states + attn_audio_hidden_states
524518
525519 # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
526- if use_a2v_cross_attn or use_v2a_cross_attn :
527- norm_hidden_states = self .audio_to_video_norm (hidden_states )
528- norm_audio_hidden_states = self .video_to_audio_norm (audio_hidden_states )
529-
530- # Combine global and per-layer cross attention modulation parameters
531- # Video
532- video_per_layer_ca_scale_shift = self .video_a2v_cross_attn_scale_shift_table [:4 , :]
533- video_per_layer_ca_gate = self .video_a2v_cross_attn_scale_shift_table [4 :, :]
534-
535- video_ca_scale_shift_table = (
536- video_per_layer_ca_scale_shift [:, :, ...].to (temb_ca_scale_shift .dtype )
537- + temb_ca_scale_shift .reshape (batch_size , temb_ca_scale_shift .shape [1 ], 4 , - 1 )
538- ).unbind (dim = 2 )
539- video_ca_gate = (
540- video_per_layer_ca_gate [:, :, ...].to (temb_ca_gate .dtype )
541- + temb_ca_gate .reshape (batch_size , temb_ca_gate .shape [1 ], 1 , - 1 )
542- ).unbind (dim = 2 )
543-
544- video_a2v_ca_scale , video_a2v_ca_shift , video_v2a_ca_scale , video_v2a_ca_shift = video_ca_scale_shift_table
545- a2v_gate = video_ca_gate [0 ].squeeze (2 )
546-
547- # Audio
548- audio_per_layer_ca_scale_shift = self .audio_a2v_cross_attn_scale_shift_table [:4 , :]
549- audio_per_layer_ca_gate = self .audio_a2v_cross_attn_scale_shift_table [4 :, :]
550-
551- audio_ca_scale_shift_table = (
552- audio_per_layer_ca_scale_shift [:, :, ...].to (temb_ca_audio_scale_shift .dtype )
553- + temb_ca_audio_scale_shift .reshape (batch_size , temb_ca_audio_scale_shift .shape [1 ], 4 , - 1 )
554- ).unbind (dim = 2 )
555- audio_ca_gate = (
556- audio_per_layer_ca_gate [:, :, ...].to (temb_ca_audio_gate .dtype )
557- + temb_ca_audio_gate .reshape (batch_size , temb_ca_audio_gate .shape [1 ], 1 , - 1 )
558- ).unbind (dim = 2 )
559-
560- audio_a2v_ca_scale , audio_a2v_ca_shift , audio_v2a_ca_scale , audio_v2a_ca_shift = audio_ca_scale_shift_table
561- v2a_gate = audio_ca_gate [0 ].squeeze (2 )
562-
563- if use_a2v_cross_attn :
564- # Audio-to-Video Cross Attention: Q: Video; K,V: Audio
565- mod_norm_hidden_states = norm_hidden_states * (
566- 1 + video_a2v_ca_scale .squeeze (2 )
567- ) + video_a2v_ca_shift .squeeze (2 )
568- mod_norm_audio_hidden_states = norm_audio_hidden_states * (
569- 1 + audio_a2v_ca_scale .squeeze (2 )
570- ) + audio_a2v_ca_shift .squeeze (2 )
571-
572- a2v_attn_hidden_states = self .audio_to_video_attn (
573- mod_norm_hidden_states ,
574- encoder_hidden_states = mod_norm_audio_hidden_states ,
575- query_rotary_emb = ca_video_rotary_emb ,
576- key_rotary_emb = ca_audio_rotary_emb ,
577- attention_mask = a2v_cross_attention_mask ,
578- )
520+ norm_hidden_states = self .audio_to_video_norm (hidden_states )
521+ norm_audio_hidden_states = self .video_to_audio_norm (audio_hidden_states )
522+
523+ # Combine global and per-layer cross attention modulation parameters
524+ # Video
525+ video_per_layer_ca_scale_shift = self .video_a2v_cross_attn_scale_shift_table [:4 , :]
526+ video_per_layer_ca_gate = self .video_a2v_cross_attn_scale_shift_table [4 :, :]
527+
528+ video_ca_scale_shift_table = (
529+ video_per_layer_ca_scale_shift [:, :, ...].to (temb_ca_scale_shift .dtype )
530+ + temb_ca_scale_shift .reshape (batch_size , temb_ca_scale_shift .shape [1 ], 4 , - 1 )
531+ ).unbind (dim = 2 )
532+ video_ca_gate = (
533+ video_per_layer_ca_gate [:, :, ...].to (temb_ca_gate .dtype )
534+ + temb_ca_gate .reshape (batch_size , temb_ca_gate .shape [1 ], 1 , - 1 )
535+ ).unbind (dim = 2 )
536+
537+ video_a2v_ca_scale , video_a2v_ca_shift , video_v2a_ca_scale , video_v2a_ca_shift = video_ca_scale_shift_table
538+ a2v_gate = video_ca_gate [0 ].squeeze (2 )
539+
540+ # Audio
541+ audio_per_layer_ca_scale_shift = self .audio_a2v_cross_attn_scale_shift_table [:4 , :]
542+ audio_per_layer_ca_gate = self .audio_a2v_cross_attn_scale_shift_table [4 :, :]
543+
544+ audio_ca_scale_shift_table = (
545+ audio_per_layer_ca_scale_shift [:, :, ...].to (temb_ca_audio_scale_shift .dtype )
546+ + temb_ca_audio_scale_shift .reshape (batch_size , temb_ca_audio_scale_shift .shape [1 ], 4 , - 1 )
547+ ).unbind (dim = 2 )
548+ audio_ca_gate = (
549+ audio_per_layer_ca_gate [:, :, ...].to (temb_ca_audio_gate .dtype )
550+ + temb_ca_audio_gate .reshape (batch_size , temb_ca_audio_gate .shape [1 ], 1 , - 1 )
551+ ).unbind (dim = 2 )
552+
553+ audio_a2v_ca_scale , audio_a2v_ca_shift , audio_v2a_ca_scale , audio_v2a_ca_shift = audio_ca_scale_shift_table
554+ v2a_gate = audio_ca_gate [0 ].squeeze (2 )
555+
556+ # Audio-to-Video Cross Attention: Q: Video; K,V: Audio
557+ mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale .squeeze (2 )) + video_a2v_ca_shift .squeeze (
558+ 2
559+ )
560+ mod_norm_audio_hidden_states = norm_audio_hidden_states * (
561+ 1 + audio_a2v_ca_scale .squeeze (2 )
562+ ) + audio_a2v_ca_shift .squeeze (2 )
563+
564+ a2v_attn_hidden_states = self .audio_to_video_attn (
565+ mod_norm_hidden_states ,
566+ encoder_hidden_states = mod_norm_audio_hidden_states ,
567+ query_rotary_emb = ca_video_rotary_emb ,
568+ key_rotary_emb = ca_audio_rotary_emb ,
569+ attention_mask = a2v_cross_attention_mask ,
570+ )
579571
580- hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
581-
582- if use_v2a_cross_attn :
583- # Video-to-Audio Cross Attention: Q: Audio; K,V: Video
584- mod_norm_hidden_states = norm_hidden_states * (
585- 1 + video_v2a_ca_scale .squeeze (2 )
586- ) + video_v2a_ca_shift .squeeze (2 )
587- mod_norm_audio_hidden_states = norm_audio_hidden_states * (
588- 1 + audio_v2a_ca_scale .squeeze (2 )
589- ) + audio_v2a_ca_shift .squeeze (2 )
590-
591- v2a_attn_hidden_states = self .video_to_audio_attn (
592- mod_norm_audio_hidden_states ,
593- encoder_hidden_states = mod_norm_hidden_states ,
594- query_rotary_emb = ca_audio_rotary_emb ,
595- key_rotary_emb = ca_video_rotary_emb ,
596- attention_mask = v2a_cross_attention_mask ,
597- )
572+ hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
573+
574+ # Video-to-Audio Cross Attention: Q: Audio; K,V: Video
575+ mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale .squeeze (2 )) + video_v2a_ca_shift .squeeze (
576+ 2
577+ )
578+ mod_norm_audio_hidden_states = norm_audio_hidden_states * (
579+ 1 + audio_v2a_ca_scale .squeeze (2 )
580+ ) + audio_v2a_ca_shift .squeeze (2 )
581+
582+ v2a_attn_hidden_states = self .video_to_audio_attn (
583+ mod_norm_audio_hidden_states ,
584+ encoder_hidden_states = mod_norm_hidden_states ,
585+ query_rotary_emb = ca_audio_rotary_emb ,
586+ key_rotary_emb = ca_video_rotary_emb ,
587+ attention_mask = v2a_cross_attention_mask ,
588+ )
598589
599- audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states
590+ audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states
600591
601592 # 4. Feedforward
602593 norm_hidden_states = self .norm3 (hidden_states ) * (1 + scale_mlp ) + shift_mlp
0 commit comments