Skip to content

Commit 50d2db0

Browse files
committed
remove debug+ gated attn=true
1 parent 25480de commit 50d2db0

2 files changed

Lines changed: 1 addition & 7 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_3_utils.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -186,18 +186,12 @@ def load_transformer_weights_2_3(
186186
renamed_pt_key = rename_key(pt_key)
187187
renamed_pt_key = rename_for_ltx2_3_transformer(renamed_pt_key)
188188

189-
if "to_gate_logits" in pt_key:
190-
print(f"pt_key: {pt_key} -> renamed_pt_key: {renamed_pt_key}")
191-
192189
pt_tuple_key = tuple(renamed_pt_key.split("."))
193190

194191
flax_key, flax_tensor = get_key_and_value(
195192
pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers
196193
)
197194

198-
if "to_gate_logits" in pt_key:
199-
print(f"flax_key: {flax_key}")
200-
201195
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
202196

203197
validate_flax_state_dict(eval_shapes, flax_state_dict)

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
159159
ltx2_config["names_which_can_be_offloaded"] = config.names_which_can_be_offloaded
160160

161161
if getattr(config, "model_name", "") == "ltx2.3":
162-
ltx2_config["gated_attn"] = False
162+
ltx2_config["gated_attn"] = True
163163
ltx2_config["cross_attn_mod"] = True
164164
ltx2_config["perturbed_attn"] = True
165165
ltx2_config["use_prompt_embeddings"] = True

0 commit comments

Comments
 (0)