Skip to content

Commit 703ce71

Browse files
committed
changes in modality mask
1 parent ccd336f commit 703ce71

1 file changed

Lines changed: 6 additions & 8 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1807,14 +1807,12 @@ def transformer_forward_pass(
18071807
else:
18081808
audio_sigma = jnp.expand_dims(audio_sigma, 0).repeat(latents.shape[0])
18091809

1810-
b = latents.shape[0]
1811-
if is_cfg_stg_mode:
1812-
n = b // 4
1813-
modality_mask = jnp.concatenate(
1814-
[jnp.ones((3 * n, 1, 1), dtype=latents.dtype), jnp.zeros((n, 1, 1), dtype=latents.dtype)], axis=0
1815-
)
1816-
else:
1817-
modality_mask = jnp.ones((b, 1, 1), dtype=latents.dtype)
1810+
n = b // 4
1811+
stg_mask = jnp.concatenate(
1812+
[jnp.ones((b - n, 1, 1), dtype=latents.dtype), jnp.zeros((n, 1, 1), dtype=latents.dtype)], axis=0
1813+
)
1814+
ones_mask = jnp.ones((b, 1, 1), dtype=latents.dtype)
1815+
modality_mask = jnp.where(is_cfg_stg_mode, stg_mask, ones_mask)
18181816

18191817
noise_pred, noise_pred_audio = transformer(
18201818
hidden_states=latents,

0 commit comments

Comments
 (0)