Skip to content

Commit 6a7e50c

Browse files
committed
change for use_prompt_embedding param
1 parent 699fd91 commit 6a7e50c

2 files changed

Lines changed: 18 additions & 5 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -737,10 +737,25 @@ def __init__(
737737

738738
# 2. Prompt embeddings
739739
if self.use_prompt_embeddings:
740+
self.caption_projection = NNXPixArtAlphaTextProjection(
741+
rngs=rngs,
742+
in_features=self.caption_channels,
743+
hidden_size=inner_dim,
744+
dtype=self.dtype,
745+
weights_dtype=self.weights_dtype,
746+
)
747+
self.audio_caption_projection = NNXPixArtAlphaTextProjection(
748+
rngs=rngs,
749+
in_features=self.caption_channels,
750+
hidden_size=audio_inner_dim,
751+
dtype=self.dtype,
752+
weights_dtype=self.weights_dtype,
753+
)
754+
else:
740755
self.caption_projection = None
741756
self.audio_caption_projection = None
742-
self.cross_attn_mod = True # Force True for LTX-2.3 prompt modulation
743757

758+
if self.cross_attn_mod:
744759
self.prompt_adaln = LTX2AdaLayerNormSingle(
745760
rngs=rngs,
746761
embedding_dim=inner_dim,
@@ -757,9 +772,6 @@ def __init__(
757772
dtype=self.dtype,
758773
weights_dtype=self.weights_dtype,
759774
)
760-
else:
761-
self.caption_projection = None
762-
self.audio_caption_projection = None
763775
# 3. Timestep Modulation Params and Embedding
764776
num_mod_params = 9 if self.cross_attn_mod else 6
765777
self.time_embed = LTX2AdaLayerNormSingle(

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,12 +157,13 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
157157
ltx2_config["remat_policy"] = config.remat_policy
158158
ltx2_config["names_which_can_be_saved"] = config.names_which_can_be_saved
159159
ltx2_config["names_which_can_be_offloaded"] = config.names_which_can_be_offloaded
160+
ltx2_config["use_prompt_embeddings"] = True
160161

161162
if getattr(config, "model_name", "") == "ltx2.3":
162163
ltx2_config["gated_attn"] = True
163164
ltx2_config["cross_attn_mod"] = True
164165
ltx2_config["perturbed_attn"] = True
165-
ltx2_config["use_prompt_embeddings"] = True
166+
ltx2_config["use_prompt_embeddings"] = False
166167

167168
# 2. eval_shape
168169
p_model_factory = partial(create_model, ltx2_config=ltx2_config)

0 commit comments

Comments
 (0)