Skip to content

Commit 1ab10eb

Browse files
committed
flash cfg+stg introduced
1 parent a7e43fc commit 1ab10eb

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1497,6 +1497,7 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
14971497
frame_rate,
14981498
perturbation_mask=perturbation_mask,
14991499
use_cross_timestep=use_cross_timestep,
1500+
is_cfg_stg_mode=do_cfg and do_stg,
15001501
)
15011502

15021503
do_cfg = guidance_scale > 1.0
@@ -1785,6 +1786,7 @@ def transformer_forward_pass(
17851786
sigma=None,
17861787
audio_sigma=None,
17871788
use_cross_timestep=False,
1789+
is_cfg_stg_mode: bool = False,
17881790
):
17891791
transformer = nnx.merge(graphdef, state)
17901792

@@ -1802,7 +1804,7 @@ def transformer_forward_pass(
18021804
audio_sigma = jnp.expand_dims(audio_sigma, 0).repeat(latents.shape[0])
18031805

18041806
b = latents.shape[0]
1805-
if b % 4 == 0 and b > 0:
1807+
if is_cfg_stg_mode:
18061808
n = b // 4
18071809
modality_mask = jnp.concatenate(
18081810
[jnp.ones((3 * n, 1, 1), dtype=latents.dtype), jnp.zeros((n, 1, 1), dtype=latents.dtype)], axis=0

0 commit comments

Comments
 (0)