Skip to content

Commit 699fd91

Browse files
committed
fix
1 parent 50d2db0 commit 699fd91

1 file changed

Lines changed: 2 additions & 5 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,15 +1132,12 @@ def __call__(
11321132
batch_size, -1, audio_cross_attn_scale_shift.shape[-1]
11331133
)
11341134
audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.reshape(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1])
1135-
1136-
if self.use_prompt_embeddings:
1135+
if self.use_prompt_embeddings and self.caption_projection is not None:
11371136
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
11381137
audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
11391138

11401139
encoder_hidden_states = encoder_hidden_states.reshape(batch_size, -1, hidden_states.shape[-1])
1141-
audio_encoder_hidden_states = audio_encoder_hidden_states.reshape(batch_size, -1, audio_hidden_states.shape[-1])
1142-
1143-
# Construct perturbation_mask_per_layer for STG
1140+
audio_encoder_hidden_states = audio_encoder_hidden_states.reshape(batch_size, -1, audio_hidden_states.shape[-1]) # Construct perturbation_mask_per_layer for STG
11441141
if perturbation_mask is None:
11451142
perturbation_mask_per_layer = jnp.ones((self.num_layers, batch_size, 1, 1), dtype=self.dtype)
11461143
else:

0 commit comments

Comments
 (0)