Skip to content

Commit 8024e66

Browse files
committed
fix
1 parent 271400e commit 8024e66

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1794,7 +1794,7 @@ def transformer_forward_pass(
17941794
audio_sigma = jnp.expand_dims(audio_sigma, 0).repeat(latents.shape[0])
17951795

17961796
N = latents.shape[0] // 4
1797-
modality_mask = jnp.concatenate([jnp.ones((3 * N, 1, 1, 1), dtype=latents.dtype), jnp.zeros((N, 1, 1, 1), dtype=latents.dtype)], axis=0)
1797+
modality_mask = jnp.concatenate([jnp.ones((3 * N, 1, 1), dtype=latents.dtype), jnp.zeros((N, 1, 1), dtype=latents.dtype)], axis=0)
17981798

17991799
noise_pred, noise_pred_audio = transformer(
18001800
hidden_states=latents,

0 commit comments

Comments
 (0)