@@ -737,17 +737,23 @@ def __init__(
737737
738738 # 2. Prompt embeddings
739739 if self .use_prompt_embeddings :
740- self .caption_projection = NNXPixArtAlphaTextProjection (
740+ self .caption_projection = None
741+ self .audio_caption_projection = None
742+ self .cross_attn_mod = True # Force True for LTX-2.3 prompt modulation
743+
744+ self .prompt_adaln = LTX2AdaLayerNormSingle (
741745 rngs = rngs ,
742- in_features = self .caption_channels ,
743- hidden_size = inner_dim ,
746+ embedding_dim = inner_dim ,
747+ num_mod_params = 2 ,
748+ use_additional_conditions = False ,
744749 dtype = self .dtype ,
745750 weights_dtype = self .weights_dtype ,
746751 )
747- self .audio_caption_projection = NNXPixArtAlphaTextProjection (
752+ self .audio_prompt_adaln = LTX2AdaLayerNormSingle (
748753 rngs = rngs ,
749- in_features = self .audio_caption_channels ,
750- hidden_size = audio_inner_dim ,
754+ embedding_dim = audio_inner_dim ,
755+ num_mod_params = 2 ,
756+ use_additional_conditions = False ,
751757 dtype = self .dtype ,
752758 weights_dtype = self .weights_dtype ,
753759 )
@@ -1077,7 +1083,7 @@ def __call__(
10771083 temb_audio = temb_audio .reshape (batch_size , - 1 , temb_audio .shape [- 1 ])
10781084 audio_embedded_timestep = audio_embedded_timestep .reshape (batch_size , - 1 , audio_embedded_timestep .shape [- 1 ])
10791085
1080- if self .cross_attn_mod and sigma is not None :
1086+ if self .use_prompt_embeddings and sigma is not None :
10811087 audio_sigma = audio_sigma if audio_sigma is not None else sigma
10821088 temb_prompt , _ = self .prompt_adaln (
10831089 sigma .flatten (),
0 commit comments