Commit afb5173

LTX2.3 changes (except vocoder)
1 parent 217108d

4 files changed

Lines changed: 70 additions & 24 deletions

src/maxdiffusion/configs/ltx2_3_video.yml

Lines changed: 11 additions & 8 deletions
@@ -24,15 +24,18 @@ save_config_to_gcs: False
 max_sequence_length: 1024
 sampler: "from_checkpoint"
 
-# Generation parameters
+# Generation parameters (aligned with Diffusers LTX-2.3 docs: use_cross_timestep, modality + audio CFG)
 global_batch_size_to_train_on: 1
-num_inference_steps: 40
-guidance_scale: 4.0
-audio_guidance_scale: 4.0
+num_inference_steps: 30
+guidance_scale: 3.0
+guidance_rescale: 0.7
+audio_guidance_scale: 7.0
+audio_guidance_rescale: 0.7
 stg_scale: 1.0
 audio_stg_scale: 1.0
-modality_scale: 1.0
-audio_modality_scale: 1.0
+modality_scale: 3.0
+audio_modality_scale: 3.0
+use_cross_timestep: true
 spatio_temporal_guidance_blocks: [28]
 fps: 24
 pipeline_type: multi-scale
@@ -42,9 +45,9 @@ height: 512
 width: 768
 decode_timestep: 0.05
 decode_noise_scale: 0.025
+noise_scale: 0.0
 num_frames: 121
 quantization: "int8"
-seed: 10
 #parallelism
 mesh_axes: ['data', 'fsdp', 'context', 'tensor']
 logical_axis_rules: [
@@ -109,5 +112,5 @@ bwd_quantization_calibration_method: "absmax"
 qwix_module_path: ".*"
 jit_initializers: True
 enable_single_replica_ckpt_restoring: False
-seed: 0
+seed: 10
 audio_format: "s16"
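Note: these scales feed an x0-space guidance combination inside the pipeline (see the ltx2_pipeline.py hunks below). A minimal JAX sketch of that combination with the new defaults plugged in; the CFG and modality deltas mirror the diff verbatim, while the exact form of the STG term is assumed by analogy, since only its summed result appears in the hunks:

import jax.numpy as jnp

def combine_x0(x0_text, x0_uncond, x0_perturbed, x0_isolated,
               guidance_scale=3.0, stg_scale=1.0, modality_scale=3.0):
  # Classifier-free guidance: push the denoised estimate away from the unconditional branch.
  cfg_delta = (guidance_scale - 1.0) * (x0_text - x0_uncond)
  # Spatio-temporal guidance: push away from the attention-perturbed branch (form assumed).
  stg_delta = stg_scale * (x0_text - x0_perturbed)
  # Modality guidance: push away from the modality-isolated branch.
  modality_delta = (modality_scale - 1.0) * (x0_text - x0_isolated)
  return x0_text + cfg_delta + stg_delta + modality_delta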

src/maxdiffusion/generate_ltx2.py

Lines changed: 4 additions & 0 deletions
@@ -93,16 +93,20 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
       num_frames=config.num_frames,
       num_inference_steps=config.num_inference_steps,
       guidance_scale=guidance_scale,
+      guidance_rescale=getattr(config, "guidance_rescale", 0.0),
       generator=generator,
       frame_rate=getattr(config, "fps", 24.0),
       decode_timestep=getattr(config, "decode_timestep", 0.0),
       decode_noise_scale=getattr(config, "decode_noise_scale", None),
       max_sequence_length=getattr(config, "max_sequence_length", 1024),
       audio_guidance_scale=getattr(config, "audio_guidance_scale", None),
+      audio_guidance_rescale=getattr(config, "audio_guidance_rescale", None),
       stg_scale=getattr(config, "stg_scale", 0.0),
       audio_stg_scale=getattr(config, "audio_stg_scale", None),
       modality_scale=getattr(config, "modality_scale", 1.0),
       audio_modality_scale=getattr(config, "audio_modality_scale", None),
+      use_cross_timestep=getattr(config, "use_cross_timestep", None),
+      noise_scale=getattr(config, "noise_scale", 1.0),
       dtype=jnp.bfloat16 if getattr(config, "activations_dtype", "bfloat16") == "bfloat16" else jnp.float32,
   )
   return out
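Note: every new pipeline kwarg is read via getattr with a default, so older YAML configs that predate the LTX-2.3 keys keep working unchanged. A minimal illustration of the pattern (SimpleNamespace stands in for the parsed config object, which is an assumption here):

from types import SimpleNamespace

# A config parsed from a YAML that predates the new keys.
old_config = SimpleNamespace(num_inference_steps=40, guidance_scale=4.0)

# Missing keys fall back to safe defaults instead of raising AttributeError.
guidance_rescale = getattr(old_config, "guidance_rescale", 0.0)       # 0.0: rescale disabled
use_cross_timestep = getattr(old_config, "use_cross_timestep", None)  # None: pipeline decides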

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 13 additions & 2 deletions
@@ -112,11 +112,13 @@ def __init__(
       v2a_attention_kernel: str = "dot_product",
       flash_block_sizes: BlockSizes = None,
       flash_min_seq_length: int = 4096,
+      perturbed_attn: bool = False,
   ):
     self.dim = dim
     self.norm_eps = norm_eps
     self.norm_elementwise_affine = norm_elementwise_affine
     self.attention_kernel = attention_kernel
+    self.perturbed_attn = perturbed_attn
 
     # 1. Self-Attention (video and audio)
     self.norm1 = nnx.RMSNorm(
@@ -370,11 +372,11 @@ def __call__(
       audio_rotary_emb: Optional[Tuple[jax.Array, jax.Array]] = None,
       ca_video_rotary_emb: Optional[Tuple[jax.Array, jax.Array]] = None,
       ca_audio_rotary_emb: Optional[Tuple[jax.Array, jax.Array]] = None,
-      attention_mask: Optional[jax.Array] = None,
       encoder_attention_mask: Optional[jax.Array] = None,
       audio_encoder_attention_mask: Optional[jax.Array] = None,
       a2v_cross_attention_mask: Optional[jax.Array] = None,
       v2a_cross_attention_mask: Optional[jax.Array] = None,
+      perturbation_mask: Optional[jax.Array] = None,
   ) -> Tuple[jax.Array, jax.Array]:
     batch_size = hidden_states.shape[0]
 
@@ -419,6 +421,7 @@
         hidden_states=norm_hidden_states,
         encoder_hidden_states=None,
         rotary_emb=video_rotary_emb,
+        perturbation_mask=perturbation_mask if self.perturbed_attn else None,
     )
     hidden_states = hidden_states + attn_hidden_states * gate_msa
 
@@ -449,6 +452,7 @@
         hidden_states=norm_audio_hidden_states,
         encoder_hidden_states=None,
         rotary_emb=audio_rotary_emb,
+        perturbation_mask=perturbation_mask if self.perturbed_attn else None,
     )
     audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa
 
@@ -648,6 +652,7 @@ def __init__(
       gated_attn: bool = False,
       cross_attn_mod: bool = False,
       use_prompt_embeddings: bool = True,
+      perturbed_attn: bool = False,
       spatio_temporal_guidance_blocks: Tuple[int, ...] = (),
       **kwargs,
   ):
@@ -700,6 +705,7 @@
     self.attention_kernel = attention_kernel
     self.gated_attn = gated_attn
     self.cross_attn_mod = cross_attn_mod
+    self.perturbed_attn = perturbed_attn
     self.a2v_attention_kernel = a2v_attention_kernel
     self.v2a_attention_kernel = v2a_attention_kernel
     self.flash_min_seq_length = flash_min_seq_length
@@ -943,6 +949,7 @@ def init_block(rngs):
           v2a_attention_kernel=self.v2a_attention_kernel,
           flash_block_sizes=flash_block_sizes,
           flash_min_seq_length=self.flash_min_seq_length,
+          perturbed_attn=self.perturbed_attn,
       )
 
       if self.scan_layers:
@@ -980,6 +987,7 @@ def init_block(rngs):
           v2a_attention_kernel=self.v2a_attention_kernel,
           flash_block_sizes=flash_block_sizes,
           flash_min_seq_length=self.flash_min_seq_length,
+          perturbed_attn=self.perturbed_attn,
       )
       blocks.append(block)
     self.transformer_blocks = nnx.List(blocks)
@@ -1181,7 +1189,7 @@ def scan_fn(carry, block_and_mask):
         ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
         a2v_cross_attention_mask=encoder_attention_mask,
         v2a_cross_attention_mask=audio_encoder_attention_mask,
-        attention_mask=mask,
+        perturbation_mask=mask,
         modality_mask=modality_mask,
     )
     return (
@@ -1225,6 +1233,9 @@ def scan_fn(carry, block_and_mask):
         ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
         encoder_attention_mask=encoder_attention_mask,
         audio_encoder_attention_mask=audio_encoder_attention_mask,
+        a2v_cross_attention_mask=encoder_attention_mask,
+        v2a_cross_attention_mask=audio_encoder_attention_mask,
+        perturbation_mask=mask,
     )
 
     # 6. Output layers
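Note: the diff only threads perturbation_mask into the self-attention calls when perturbed_attn is set; what the attention module does with the mask is not shown here. In spatio-temporal guidance the usual perturbation is identity attention (a value pass-through) for the flagged batch entries in the selected blocks (spatio_temporal_guidance_blocks: [28]). A hedged JAX sketch of that convention; the shapes and the identity-attention choice are assumptions, not this repo's code:

import jax
import jax.numpy as jnp

def self_attention_with_perturbation(q, k, v, perturbation_mask=None):
  # q, k, v: (batch, heads, seq, head_dim); perturbation_mask: (batch,) booleans.
  scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
  out = jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), v)
  if perturbation_mask is not None:
    # Flagged entries skip attention mixing entirely (identity attention, as in STG).
    flag = perturbation_mask.astype(bool)[:, None, None, None]
    out = jnp.where(flag, v, out)
  return out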

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 42 additions & 14 deletions
@@ -1257,6 +1257,7 @@ def __call__(
       stg_scale: float = 0.0,
       modality_scale: float = 1.0,
       audio_guidance_scale: Optional[float] = None,
+      audio_guidance_rescale: Optional[float] = None,
       audio_stg_scale: Optional[float] = None,
       audio_modality_scale: Optional[float] = None,
       noise_scale: float = 1.0,
@@ -1274,13 +1275,20 @@
       dtype: Optional[jnp.dtype] = None,
       output_type: str = "pil",
       return_dict: bool = True,
-      use_cross_timestep: bool = False,
+      use_cross_timestep: Optional[bool] = None,
   ):
     # 1. Check inputs
     self.check_inputs(
         prompt, height, width, prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
     )
 
+    if use_cross_timestep is None:
+      use_cross_timestep = getattr(self.config, "model_name", "") == "ltx2.3"
+
+    audio_guidance_rescale = (
+        audio_guidance_rescale if audio_guidance_rescale is not None else guidance_rescale
+    )
+
     # 2. Encode inputs (Text)
     prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
         prompt,
@@ -1343,18 +1351,22 @@
         latents=audio_latents,
     )
 
-    # 5. Prepare Timesteps
+    # 5. Prepare Timesteps (match diffusers LTX2: shift uses scheduler config bounds, not latent token count)
     sigmas = jnp.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
 
-    video_sequence_length = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
-    video_sequence_length *= (height // self.vae_spatial_compression_ratio) * (width // self.vae_spatial_compression_ratio)
+    sched_cfg = self.scheduler.config
+
+    def _sched_cfg_get(key: str, default):
+      if hasattr(sched_cfg, "get"):
+        return sched_cfg.get(key, default)
+      return getattr(sched_cfg, key, default)
 
     mu = calculate_shift(
-        video_sequence_length,
-        self.scheduler.config.get("base_image_seq_len", 1024),
-        self.scheduler.config.get("max_image_seq_len", 4096),
-        self.scheduler.config.get("base_shift", 0.95),
-        self.scheduler.config.get("max_shift", 2.05),
+        _sched_cfg_get("max_image_seq_len", 4096),
+        _sched_cfg_get("base_image_seq_len", 1024),
+        _sched_cfg_get("max_image_seq_len", 4096),
+        _sched_cfg_get("base_shift", 0.95),
+        _sched_cfg_get("max_shift", 2.05),
     )
 
     scheduler_state = retrieve_timesteps(
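Note: passing _sched_cfg_get("max_image_seq_len", 4096) as the first argument pins the shift at max_shift instead of letting it vary with the latent token count. A worked check using the standard Flux-style calculate_shift (signature assumed from diffusers: a linear interpolation of mu in sequence length; the repo's helper may differ in detail):

def calculate_shift(seq_len, base_seq_len=1024, max_seq_len=4096, base_shift=0.95, max_shift=2.05):
  # mu varies linearly from base_shift (at base_seq_len) to max_shift (at max_seq_len).
  m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
  return seq_len * m + (base_shift - m * base_seq_len)

# New call: seq_len == max_image_seq_len, so mu is pinned at max_shift.
assert abs(calculate_shift(4096) - 2.05) < 1e-9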
@@ -1373,7 +1385,7 @@
     prompt_attention_mask_jax = prompt_attention_mask
 
     do_cfg = guidance_scale > 1.0
-    do_stg = getattr(self.config, "stg_scale", 0.0) > 0.0
+    do_stg = stg_scale > 0.0
 
     if do_cfg and do_stg:
       negative_prompt_embeds_jax = negative_prompt_embeds
@@ -1561,7 +1573,12 @@ def convert_to_vel(lat, x0):
           audio_modality_delta = (audio_modality_scale - 1 if audio_modality_scale is not None else modality_scale - 1) * (x0_audio_text - x0_audio_isolated)
 
           x0_audio_combined = x0_audio_text + cfg_audio_delta + stg_audio_delta + audio_modality_delta
-
+
+          if audio_guidance_rescale > 0:
+            x0_audio_combined = rescale_noise_cfg(
+                x0_audio_combined, x0_audio_text, guidance_rescale=audio_guidance_rescale
+            )
+
           noise_pred_audio = convert_to_vel(audio_latents_step, x0_audio_combined)
 
         elif do_cfg:
@@ -1586,7 +1603,12 @@ def convert_to_vel(lat, x0):
 
           cfg_audio_delta = (audio_guidance_scale - 1 if audio_guidance_scale is not None else guidance_scale - 1) * (x0_audio_text - x0_audio_uncond)
           x0_audio_combined = x0_audio_text + cfg_audio_delta
-
+
+          if audio_guidance_rescale > 0:
+            x0_audio_combined = rescale_noise_cfg(
+                x0_audio_combined, x0_audio_text, guidance_rescale=audio_guidance_rescale
+            )
+
           noise_pred_audio = convert_to_vel(audio_latents_step, x0_audio_combined)
 
         elif do_stg:
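Note: both CFG branches now pass the combined x0 prediction through rescale_noise_cfg whenever audio_guidance_rescale > 0. A JAX sketch of the standard formulation (per "Common Diffusion Noise Schedules and Sample Steps are Flawed", as implemented in diffusers; the repo's helper may differ in detail):

import jax.numpy as jnp

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
  # Match the guided prediction's per-sample std to the text-conditioned one,
  # then blend the rescaled and raw predictions by guidance_rescale.
  axes = tuple(range(1, noise_pred_text.ndim))
  std_text = jnp.std(noise_pred_text, axis=axes, keepdims=True)
  std_cfg = jnp.std(noise_cfg, axis=axes, keepdims=True)
  rescaled = noise_cfg * (std_text / std_cfg)
  return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg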
@@ -1791,8 +1813,14 @@ def transformer_forward_pass(
       else:
         audio_sigma = jnp.expand_dims(audio_sigma, 0).repeat(latents.shape[0])
 
-      N = latents.shape[0] // 4
-      modality_mask = jnp.concatenate([jnp.ones((3 * N, 1, 1), dtype=latents.dtype), jnp.zeros((N, 1, 1), dtype=latents.dtype)], axis=0)
+      b = latents.shape[0]
+      if b % 4 == 0 and b > 0:
+        n = b // 4
+        modality_mask = jnp.concatenate(
+            [jnp.ones((3 * n, 1, 1), dtype=latents.dtype), jnp.zeros((n, 1, 1), dtype=latents.dtype)], axis=0
+        )
+      else:
+        modality_mask = jnp.ones((b, 1, 1), dtype=latents.dtype)
 
       noise_pred, noise_pred_audio = transformer(
         hidden_states=latents,
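Note: the rewritten mask only assumes the 4-way batch layout (text / uncond / perturbed / isolated stacked along the batch axis; the ordering is inferred from the guidance branch above) when the batch divides evenly by four, and degrades to an all-ones mask otherwise. For the CFG+STG path with a single prompt:

import jax.numpy as jnp

b = 4  # assumed stacking: [text, uncond, perturbed, isolated]
n = b // 4
modality_mask = jnp.concatenate([jnp.ones((3 * n, 1, 1)), jnp.zeros((n, 1, 1))], axis=0)
print(modality_mask[:, 0, 0])  # [1. 1. 1. 0.] -- only the last entry runs modality-isolated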
