@@ -366,7 +366,6 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
366366 connector_kwargs = {
367367 "dtype" : jnp .float32 ,
368368 "weights_dtype" : config .weights_dtype if hasattr (config , "weights_dtype" ) else jnp .float32 ,
369- "attention_kernel" : config .attention if hasattr (config , "attention" ) else "flash" ,
370369 }
371370 if getattr (config , "model_name" , "" ) == "ltx2.3" :
372371 connector_kwargs .update (
@@ -885,10 +884,11 @@ def _get_gemma_prompt_embeds(
885884 text_encoder_hidden_states = text_encoder_outputs .hidden_states
886885 del text_encoder_outputs # Free memory
887886
888- states_np = [state .cpu ().to (torch .float32 ).numpy () for state in text_encoder_hidden_states ]
889- stacked_np = np .stack (states_np , axis = - 1 )
890- flattened_np = stacked_np .reshape (batch_size , text_input_ids .shape [1 ], - 1 )
891- prompt_embeds = jnp .array (flattened_np , dtype = jnp .bfloat16 )
887+ prompt_embeds_list = []
888+ for state in text_encoder_hidden_states :
889+ state_np = state .cpu ().to (torch .float32 ).numpy ()
890+ prompt_embeds_list .append (jnp .array (state_np , dtype = jnp .bfloat16 ))
891+ prompt_embeds = prompt_embeds_list
892892 del text_encoder_hidden_states # Free PyTorch tensor memory
893893
894894 prompt_attention_mask = jnp .array (prompt_attention_mask .cpu ().to (torch .float32 ).numpy (), dtype = jnp .bool_ )
0 commit comments