Skip to content

Commit c3cafb7

Browse files
committed
Text Encoder Layer Stacking
1 parent 5778ffc commit c3cafb7

2 files changed

Lines changed: 9 additions & 11 deletions

File tree

src/maxdiffusion/configs/ltx2_3_video.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ sampler: "from_checkpoint"
 # Generation parameters
 global_batch_size_to_train_on: 1
 num_inference_steps: 40
-guidance_scale: 3.0
-audio_guidance_scale: 7.0
+guidance_scale: 4.0
+audio_guidance_scale: 4.0
 stg_scale: 1.0
 audio_stg_scale: 1.0
-modality_scale: 3.0
-audio_modality_scale: 3.0
+modality_scale: 1.0
+audio_modality_scale: 1.0
 spatio_temporal_guidance_blocks: [28]
 fps: 24
 pipeline_type: multi-scale

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
     connector_kwargs = {
         "dtype": jnp.float32,
         "weights_dtype": config.weights_dtype if hasattr(config, "weights_dtype") else jnp.float32,
+        "attention_kernel": config.attention if hasattr(config, "attention") else "flash",
     }
     if getattr(config, "model_name", "") == "ltx2.3":
         connector_kwargs.update(
@@ -884,13 +885,10 @@ def _get_gemma_prompt_embeds(
     text_encoder_hidden_states = text_encoder_outputs.hidden_states
     del text_encoder_outputs  # Free memory

-    prompt_embeds_list = []
-    # Iterate instead of stacking eagerly to avoid 5.7+ GB HBM allocations outside JIT
-    for state in text_encoder_hidden_states:
-        state_np = state.cpu().to(torch.float32).numpy()
-        prompt_embeds_list.append(jnp.array(state_np, dtype=jnp.bfloat16))
-
-    prompt_embeds = prompt_embeds_list
+    states_np = [state.cpu().to(torch.float32).numpy() for state in text_encoder_hidden_states]
+    stacked_np = np.stack(states_np, axis=-1)
+    flattened_np = stacked_np.reshape(batch_size, text_input_ids.shape[1], -1)
+    prompt_embeds = jnp.array(flattened_np, dtype=jnp.bfloat16)
     del text_encoder_hidden_states  # Free PyTorch tensor memory

     prompt_attention_mask = jnp.array(prompt_attention_mask.cpu().to(torch.float32).numpy(), dtype=jnp.bool_)

0 commit comments

Comments (0)