Skip to content

Commit 06e8de8

Browse files
committed
using ltx2adalayernormsingle
1 parent f631da3 commit 06e8de8

2 files changed

Lines changed: 13 additions & 11 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_3_utils.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,6 @@ def rename_for_ltx2_3_transformer(key):
6363
# key = key.replace("audio_prompt_scale_shift_table", "audio_scale_shift_table")
6464
# key = key.replace("prompt_scale_shift_table", "scale_shift_table")
6565

66-
if "prompt_adaln" in key:
67-
key = key.replace("prompt_adaln", "caption_projection")
68-
if "audio_prompt_adaln" in key:
69-
key = key.replace("audio_prompt_adaln", "audio_caption_projection")
7066
if "video_text_proj_in" in key:
7167
key = key.replace("video_text_proj_in", "feature_extractor.video_linear")
7268
if "audio_text_proj_in" in key:

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -737,17 +737,23 @@ def __init__(
737737

738738
# 2. Prompt embeddings
739739
if self.use_prompt_embeddings:
740-
self.caption_projection = NNXPixArtAlphaTextProjection(
740+
self.caption_projection = None
741+
self.audio_caption_projection = None
742+
self.cross_attn_mod = True # Force True for LTX-2.3 prompt modulation
743+
744+
self.prompt_adaln = LTX2AdaLayerNormSingle(
741745
rngs=rngs,
742-
in_features=self.caption_channels,
743-
hidden_size=inner_dim,
746+
embedding_dim=inner_dim,
747+
num_mod_params=2,
748+
use_additional_conditions=False,
744749
dtype=self.dtype,
745750
weights_dtype=self.weights_dtype,
746751
)
747-
self.audio_caption_projection = NNXPixArtAlphaTextProjection(
752+
self.audio_prompt_adaln = LTX2AdaLayerNormSingle(
748753
rngs=rngs,
749-
in_features=self.audio_caption_channels,
750-
hidden_size=audio_inner_dim,
754+
embedding_dim=audio_inner_dim,
755+
num_mod_params=2,
756+
use_additional_conditions=False,
751757
dtype=self.dtype,
752758
weights_dtype=self.weights_dtype,
753759
)
@@ -1077,7 +1083,7 @@ def __call__(
10771083
temb_audio = temb_audio.reshape(batch_size, -1, temb_audio.shape[-1])
10781084
audio_embedded_timestep = audio_embedded_timestep.reshape(batch_size, -1, audio_embedded_timestep.shape[-1])
10791085

1080-
if self.cross_attn_mod and sigma is not None:
1086+
if self.use_prompt_embeddings and sigma is not None:
10811087
audio_sigma = audio_sigma if audio_sigma is not None else sigma
10821088
temb_prompt, _ = self.prompt_adaln(
10831089
sigma.flatten(),

0 commit comments

Comments
 (0)