Skip to content

Commit 1fbcc16

Browse files
committed
changes in modality mask
1 parent 703ce71 commit 1fbcc16

1 file changed

Lines changed: 1 addition & 0 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1807,6 +1807,7 @@ def transformer_forward_pass(
18071807
else:
18081808
audio_sigma = jnp.expand_dims(audio_sigma, 0).repeat(latents.shape[0])
18091809

1810+
b = latents.shape[0]
18101811
n = b // 4
18111812
stg_mask = jnp.concatenate(
18121813
[jnp.ones((b - n, 1, 1), dtype=latents.dtype), jnp.zeros((n, 1, 1), dtype=latents.dtype)], axis=0

0 commit comments

Comments
 (0)