@@ -366,6 +366,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
     connector_kwargs = {
         "dtype": jnp.float32,
         "weights_dtype": config.weights_dtype if hasattr(config, "weights_dtype") else jnp.float32,
+        "attention_kernel": config.attention if hasattr(config, "attention") else "flash",
     }
     if getattr(config, "model_name", "") == "ltx2.3":
         connector_kwargs.update(
@@ -884,13 +885,10 @@ def _get_gemma_prompt_embeds(
     text_encoder_hidden_states = text_encoder_outputs.hidden_states
     del text_encoder_outputs  # Free memory
 
-    prompt_embeds_list = []
-    # Iterate instead of stacking eagerly to avoid 5.7+ GB HBM allocations outside JIT
-    for state in text_encoder_hidden_states:
-        state_np = state.cpu().to(torch.float32).numpy()
-        prompt_embeds_list.append(jnp.array(state_np, dtype=jnp.bfloat16))
-
-    prompt_embeds = prompt_embeds_list
+    states_np = [state.cpu().to(torch.float32).numpy() for state in text_encoder_hidden_states]
+    stacked_np = np.stack(states_np, axis=-1)
+    flattened_np = stacked_np.reshape(batch_size, text_input_ids.shape[1], -1)
+    prompt_embeds = jnp.array(flattened_np, dtype=jnp.bfloat16)
     del text_encoder_hidden_states  # Free PyTorch tensor memory
 
     prompt_attention_mask = jnp.array(prompt_attention_mask.cpu().to(torch.float32).numpy(), dtype=jnp.bool_)
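A small shape check of the new stacking path, with random NumPy arrays standing in for the Gemma hidden states; `num_layers`, `seq_len`, and `hidden` below are illustrative values, not taken from the model.

```python
# Sketch of the new stacking path with dummy tensors in place of
# `state.cpu().to(torch.float32).numpy()` per hidden state.
import numpy as np
import jax.numpy as jnp

batch_size, seq_len, hidden = 2, 8, 16
num_layers = 4  # number of hidden-state tensors returned by the text encoder

states_np = [
    np.random.rand(batch_size, seq_len, hidden).astype(np.float32)
    for _ in range(num_layers)
]

stacked_np = np.stack(states_np, axis=-1)                   # (2, 8, 16, 4)
flattened_np = stacked_np.reshape(batch_size, seq_len, -1)  # (2, 8, 64)
prompt_embeds = jnp.array(flattened_np, dtype=jnp.bfloat16)

print(stacked_np.shape, prompt_embeds.shape)  # (2, 8, 16, 4) (2, 8, 64)
```

Compared with the removed per-layer list, this yields a single contiguous `(batch, seq_len, hidden * num_layers)` array, at the cost of materializing the stacked tensor eagerly on the host before the transfer to bfloat16 on device.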