Skip to content

Commit 25480de

Browse files
committed
debug
1 parent f3a6f8a commit 25480de

1 file changed

Lines changed: 9 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_3_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,15 +180,24 @@ def load_transformer_weights_2_3(
180180
if pt_key.startswith("audio_embeddings_connector") or pt_key.startswith("video_embeddings_connector"):
181181
continue
182182

183+
if "to_gate_logits" in pt_key and pt_key.endswith(".weight"):
184+
tensor = tensor.T
185+
183186
renamed_pt_key = rename_key(pt_key)
184187
renamed_pt_key = rename_for_ltx2_3_transformer(renamed_pt_key)
185188

189+
if "to_gate_logits" in pt_key:
190+
print(f"pt_key: {pt_key} -> renamed_pt_key: {renamed_pt_key}")
191+
186192
pt_tuple_key = tuple(renamed_pt_key.split("."))
187193

188194
flax_key, flax_tensor = get_key_and_value(
189195
pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers
190196
)
191197

198+
if "to_gate_logits" in pt_key:
199+
print(f"flax_key: {flax_key}")
200+
192201
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
193202

194203
validate_flax_state_dict(eval_shapes, flax_state_dict)

0 commit comments

Comments
 (0)