
Commit 217108d

revert

1 parent c3cafb7 · commit 217108d

1 file changed: 5 additions & 5 deletions

File changed: src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py
@@ -366,7 +366,6 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
   connector_kwargs = {
       "dtype": jnp.float32,
       "weights_dtype": config.weights_dtype if hasattr(config, "weights_dtype") else jnp.float32,
-      "attention_kernel": config.attention if hasattr(config, "attention") else "flash",
   }
   if getattr(config, "model_name", "") == "ltx2.3":
     connector_kwargs.update(
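
With the "attention_kernel" entry removed, the connector no longer receives an explicit kernel choice and falls back to whatever default its constructor defines. A minimal sketch of that fallback, using a hypothetical DummyConnector stand-in (the real connector class, its signature, and its default kernel live elsewhere in maxdiffusion and are assumptions here):

import jax.numpy as jnp

class DummyConnector:
  # "dot_product" as the default kernel is an assumption for illustration;
  # the real connector and its default are defined in maxdiffusion.
  def __init__(self, dtype, weights_dtype, attention_kernel="dot_product"):
    self.dtype = dtype
    self.weights_dtype = weights_dtype
    self.attention_kernel = attention_kernel

connector_kwargs = {
    "dtype": jnp.float32,
    "weights_dtype": jnp.float32,
    # "attention_kernel" no longer passed after the revert,
    # so the constructor default above applies.
}
connector = DummyConnector(**connector_kwargs)
print(connector.attention_kernel)  # -> "dot_product" (constructor default)
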
@@ -885,10 +884,11 @@ def _get_gemma_prompt_embeds(
     text_encoder_hidden_states = text_encoder_outputs.hidden_states
     del text_encoder_outputs  # Free memory

-    states_np = [state.cpu().to(torch.float32).numpy() for state in text_encoder_hidden_states]
-    stacked_np = np.stack(states_np, axis=-1)
-    flattened_np = stacked_np.reshape(batch_size, text_input_ids.shape[1], -1)
-    prompt_embeds = jnp.array(flattened_np, dtype=jnp.bfloat16)
+    prompt_embeds_list = []
+    for state in text_encoder_hidden_states:
+      state_np = state.cpu().to(torch.float32).numpy()
+      prompt_embeds_list.append(jnp.array(state_np, dtype=jnp.bfloat16))
+    prompt_embeds = prompt_embeds_list
     del text_encoder_hidden_states  # Free PyTorch tensor memory

     prompt_attention_mask = jnp.array(prompt_attention_mask.cpu().to(torch.float32).numpy(), dtype=jnp.bool_)
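
For reference, a runnable sketch of the shape difference this hunk reverts, with made-up sizes (batch, seq_len, hidden, num_layers are illustrative; in the pipeline the inputs are Gemma's per-layer hidden states as torch tensors). The old path collapsed all layers into one (batch, seq_len, hidden * num_layers) array; the reverted path keeps a Python list of per-layer (batch, seq_len, hidden) arrays:

import numpy as np
import jax.numpy as jnp

batch, seq_len, hidden, num_layers = 2, 8, 4, 3
# Stand-ins for the per-layer encoder hidden states.
hidden_states = [
    np.random.rand(batch, seq_len, hidden).astype(np.float32)
    for _ in range(num_layers)
]

# Old (removed) path: stack layers on a trailing axis, then flatten them
# into the feature dimension, yielding a single array.
stacked = np.stack(hidden_states, axis=-1)          # (2, 8, 4, 3)
flattened = stacked.reshape(batch, seq_len, -1)     # (2, 8, 12)
old_prompt_embeds = jnp.array(flattened, dtype=jnp.bfloat16)
print(old_prompt_embeds.shape)  # (2, 8, 12)

# New (restored) path: convert each layer separately and keep a list.
new_prompt_embeds = [jnp.array(s, dtype=jnp.bfloat16) for s in hidden_states]
print(len(new_prompt_embeds), new_prompt_embeds[0].shape)  # 3 (2, 8, 4)
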
