@@ -737,38 +737,20 @@ def __init__(
737737
738738 # 2. Prompt embeddings
739739 if self .use_prompt_embeddings :
740- if self .cross_attn_mod :
741- self .caption_projection = NNXCombinedTimestepTextProjEmbeddings (
742- rngs = rngs ,
743- in_features = self .caption_channels ,
744- hidden_size = self .cross_attention_dim ,
745- embedding_dim = self .cross_attention_dim ,
746- dtype = self .dtype ,
747- weights_dtype = self .weights_dtype ,
748- )
749- self .audio_caption_projection = NNXCombinedTimestepTextProjEmbeddings (
750- rngs = rngs ,
751- in_features = self .audio_caption_channels ,
752- hidden_size = self .audio_cross_attention_dim ,
753- embedding_dim = self .audio_cross_attention_dim ,
754- dtype = self .dtype ,
755- weights_dtype = self .weights_dtype ,
756- )
757- else :
758- self .caption_projection = NNXPixArtAlphaTextProjection (
759- rngs = rngs ,
760- in_features = self .caption_channels ,
761- hidden_size = inner_dim ,
762- dtype = self .dtype ,
763- weights_dtype = self .weights_dtype ,
764- )
765- self .audio_caption_projection = NNXPixArtAlphaTextProjection (
766- rngs = rngs ,
767- in_features = self .audio_caption_channels ,
768- hidden_size = audio_inner_dim ,
769- dtype = self .dtype ,
770- weights_dtype = self .weights_dtype ,
771- )
740+ self .caption_projection = NNXPixArtAlphaTextProjection (
741+ rngs = rngs ,
742+ in_features = self .caption_channels ,
743+ hidden_size = inner_dim ,
744+ dtype = self .dtype ,
745+ weights_dtype = self .weights_dtype ,
746+ )
747+ self .audio_caption_projection = NNXPixArtAlphaTextProjection (
748+ rngs = rngs ,
749+ in_features = self .audio_caption_channels ,
750+ hidden_size = audio_inner_dim ,
751+ dtype = self .dtype ,
752+ weights_dtype = self .weights_dtype ,
753+ )
772754 else :
773755 self .caption_projection = None
774756 self .audio_caption_projection = None
@@ -1146,14 +1128,8 @@ def __call__(
11461128 audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate .reshape (batch_size , - 1 , audio_cross_attn_v2a_gate .shape [- 1 ])
11471129
11481130 if self .use_prompt_embeddings :
1149- if self .cross_attn_mod :
1150- encoder_hidden_states = self .caption_projection (encoder_hidden_states , timestep )
1151- audio_encoder_hidden_states = self .audio_caption_projection (
1152- audio_encoder_hidden_states , audio_timestep if audio_timestep is not None else timestep
1153- )
1154- else :
1155- encoder_hidden_states = self .caption_projection (encoder_hidden_states )
1156- audio_encoder_hidden_states = self .audio_caption_projection (audio_encoder_hidden_states )
1131+ encoder_hidden_states = self .caption_projection (encoder_hidden_states )
1132+ audio_encoder_hidden_states = self .audio_caption_projection (audio_encoder_hidden_states )
11571133
11581134 encoder_hidden_states = encoder_hidden_states .reshape (batch_size , - 1 , hidden_states .shape [- 1 ])
11591135 audio_encoder_hidden_states = audio_encoder_hidden_states .reshape (batch_size , - 1 , audio_hidden_states .shape [- 1 ])
0 commit comments