Skip to content

Commit 53cfaae

Browse files
committed
del unused NNXCombinedTimestepTextProjEmbeddings and use NNXPixArtAlphaTextProjection
1 parent f53112d commit 53cfaae

3 files changed

Lines changed: 20 additions & 94 deletions

File tree

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -473,56 +473,6 @@ def __call__(self, timestep, pooled_projection):
473473
conditioning = timestep_emb + pooled_projections
474474
return conditioning
475475

476-
class NNXCombinedTimestepTextProjEmbeddings(nnx.Module):
477-
def __init__(
478-
self,
479-
rngs: nnx.Rngs,
480-
in_features: int,
481-
hidden_size: int,
482-
embedding_dim: int,
483-
out_features: int = None,
484-
act_fn: str = "gelu_tanh",
485-
dtype: jnp.dtype = jnp.float32,
486-
weights_dtype: jnp.dtype = jnp.float32,
487-
precision: jax.lax.Precision = None,
488-
):
489-
if out_features is None:
490-
out_features = hidden_size
491-
492-
self.linear = nnx.Linear(
493-
rngs=rngs,
494-
in_features=in_features,
495-
out_features=out_features,
496-
use_bias=True,
497-
dtype=jnp.float32,
498-
param_dtype=weights_dtype,
499-
precision=precision,
500-
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("mlp", "embed")),
501-
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
502-
)
503-
504-
self.time_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
505-
506-
class EmbWrapper(nnx.Module):
507-
def __init__(self, rngs: nnx.Rngs, embedding_dim: int, weights_dtype: jnp.dtype):
508-
self.timestep_embedder = NNXTimestepEmbedding(
509-
rngs=rngs,
510-
in_channels=256,
511-
time_embed_dim=embedding_dim,
512-
dtype=jnp.float32,
513-
weights_dtype=weights_dtype,
514-
)
515-
516-
self.emb = EmbWrapper(rngs, embedding_dim, weights_dtype)
517-
518-
def __call__(self, caption, timestep):
519-
hidden_states = self.linear(caption)
520-
521-
timesteps_proj = self.time_proj(timestep)
522-
timesteps_emb = self.emb.timestep_embedder(timesteps_proj)
523-
524-
return hidden_states + timesteps_emb[:, None, :]
525-
526476

527477
class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
528478
embedding_dim: int

src/maxdiffusion/models/ltx2/ltx2_3_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,10 @@ def rename_for_ltx2_3_transformer(key):
6363
# key = key.replace("audio_prompt_scale_shift_table", "audio_scale_shift_table")
6464
# key = key.replace("prompt_scale_shift_table", "scale_shift_table")
6565

66-
# if "prompt_adaln" in key:
67-
# key = key.replace("prompt_adaln", "caption_projection")
68-
# if "audio_prompt_adaln" in key:
69-
# key = key.replace("audio_prompt_adaln", "audio_caption_projection")
66+
if "prompt_adaln" in key:
67+
key = key.replace("prompt_adaln", "caption_projection")
68+
if "audio_prompt_adaln" in key:
69+
key = key.replace("audio_prompt_adaln", "audio_caption_projection")
7070
if "video_text_proj_in" in key:
7171
key = key.replace("video_text_proj_in", "feature_extractor.video_linear")
7272
if "audio_text_proj_in" in key:

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 16 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -737,38 +737,20 @@ def __init__(
737737

738738
# 2. Prompt embeddings
739739
if self.use_prompt_embeddings:
740-
if self.cross_attn_mod:
741-
self.caption_projection = NNXCombinedTimestepTextProjEmbeddings(
742-
rngs=rngs,
743-
in_features=self.caption_channels,
744-
hidden_size=self.cross_attention_dim,
745-
embedding_dim=self.cross_attention_dim,
746-
dtype=self.dtype,
747-
weights_dtype=self.weights_dtype,
748-
)
749-
self.audio_caption_projection = NNXCombinedTimestepTextProjEmbeddings(
750-
rngs=rngs,
751-
in_features=self.audio_caption_channels,
752-
hidden_size=self.audio_cross_attention_dim,
753-
embedding_dim=self.audio_cross_attention_dim,
754-
dtype=self.dtype,
755-
weights_dtype=self.weights_dtype,
756-
)
757-
else:
758-
self.caption_projection = NNXPixArtAlphaTextProjection(
759-
rngs=rngs,
760-
in_features=self.caption_channels,
761-
hidden_size=inner_dim,
762-
dtype=self.dtype,
763-
weights_dtype=self.weights_dtype,
764-
)
765-
self.audio_caption_projection = NNXPixArtAlphaTextProjection(
766-
rngs=rngs,
767-
in_features=self.audio_caption_channels,
768-
hidden_size=audio_inner_dim,
769-
dtype=self.dtype,
770-
weights_dtype=self.weights_dtype,
771-
)
740+
self.caption_projection = NNXPixArtAlphaTextProjection(
741+
rngs=rngs,
742+
in_features=self.caption_channels,
743+
hidden_size=inner_dim,
744+
dtype=self.dtype,
745+
weights_dtype=self.weights_dtype,
746+
)
747+
self.audio_caption_projection = NNXPixArtAlphaTextProjection(
748+
rngs=rngs,
749+
in_features=self.audio_caption_channels,
750+
hidden_size=audio_inner_dim,
751+
dtype=self.dtype,
752+
weights_dtype=self.weights_dtype,
753+
)
772754
else:
773755
self.caption_projection = None
774756
self.audio_caption_projection = None
@@ -1146,14 +1128,8 @@ def __call__(
11461128
audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.reshape(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1])
11471129

11481130
if self.use_prompt_embeddings:
1149-
if self.cross_attn_mod:
1150-
encoder_hidden_states = self.caption_projection(encoder_hidden_states, timestep)
1151-
audio_encoder_hidden_states = self.audio_caption_projection(
1152-
audio_encoder_hidden_states, audio_timestep if audio_timestep is not None else timestep
1153-
)
1154-
else:
1155-
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
1156-
audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
1131+
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
1132+
audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
11571133

11581134
encoder_hidden_states = encoder_hidden_states.reshape(batch_size, -1, hidden_states.shape[-1])
11591135
audio_encoder_hidden_states = audio_encoder_hidden_states.reshape(batch_size, -1, audio_hidden_states.shape[-1])

0 commit comments

Comments (0)