File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -737,10 +737,25 @@ def __init__(
737737
738738 # 2. Prompt embeddings
739739 if self .use_prompt_embeddings :
740+ self .caption_projection = NNXPixArtAlphaTextProjection (
741+ rngs = rngs ,
742+ in_features = self .caption_channels ,
743+ hidden_size = inner_dim ,
744+ dtype = self .dtype ,
745+ weights_dtype = self .weights_dtype ,
746+ )
747+ self .audio_caption_projection = NNXPixArtAlphaTextProjection (
748+ rngs = rngs ,
749+ in_features = self .caption_channels ,
750+ hidden_size = audio_inner_dim ,
751+ dtype = self .dtype ,
752+ weights_dtype = self .weights_dtype ,
753+ )
754+ else :
740755 self .caption_projection = None
741756 self .audio_caption_projection = None
742- self .cross_attn_mod = True # Force True for LTX-2.3 prompt modulation
743757
758+ if self .cross_attn_mod :
744759 self .prompt_adaln = LTX2AdaLayerNormSingle (
745760 rngs = rngs ,
746761 embedding_dim = inner_dim ,
@@ -757,9 +772,6 @@ def __init__(
757772 dtype = self .dtype ,
758773 weights_dtype = self .weights_dtype ,
759774 )
760- else :
761- self .caption_projection = None
762- self .audio_caption_projection = None
763775 # 3. Timestep Modulation Params and Embedding
764776 num_mod_params = 9 if self .cross_attn_mod else 6
765777 self .time_embed = LTX2AdaLayerNormSingle (
Original file line number Diff line number Diff line change @@ -157,12 +157,13 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
157157 ltx2_config ["remat_policy" ] = config .remat_policy
158158 ltx2_config ["names_which_can_be_saved" ] = config .names_which_can_be_saved
159159 ltx2_config ["names_which_can_be_offloaded" ] = config .names_which_can_be_offloaded
160+ ltx2_config ["use_prompt_embeddings" ] = True
160161
161162 if getattr (config , "model_name" , "" ) == "ltx2.3" :
162163 ltx2_config ["gated_attn" ] = True
163164 ltx2_config ["cross_attn_mod" ] = True
164165 ltx2_config ["perturbed_attn" ] = True
165- ltx2_config ["use_prompt_embeddings" ] = True
166+ ltx2_config ["use_prompt_embeddings" ] = False
166167
167168 # 2. eval_shape
168169 p_model_factory = partial (create_model , ltx2_config = ltx2_config )
You can’t perform that action at this time.
0 commit comments