Skip to content

Commit 2b057eb

Browse files
Merge pull request #3530 from AI-Hypercomputer:yujiedeng/to_hugging_face_convertion
PiperOrigin-RevId: 899242180
2 parents 35e07f9 + 47defb1 commit 2b057eb

3 files changed

Lines changed: 23 additions & 4 deletions

File tree

src/maxtext/checkpoint_conversion/to_huggingface.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,11 +140,22 @@ def _validate_or_update_architecture(hf_config, max_config, override: bool):
140140
# Mapping from Hugging Face config attribute -> MaxText config attribute
141141
# Note: We use derived MaxText attributes (e.g. emb_dim) which account for scale factors.
142142
attributes_to_check = [
143-
("num_attention_heads", "num_query_heads"),
144-
("num_key_value_heads", "num_kv_heads"),
145143
("hidden_size", "emb_dim"),
146144
("intermediate_size", "mlp_dim"),
145+
("kv_lora_rank", "kv_lora_rank"),
146+
("moe_intermediate_size", "moe_mlp_dim"),
147+
("n_routed_experts", "num_experts"),
148+
("n_shared_experts", "shared_experts"),
149+
("num_attention_heads", "num_query_heads"),
150+
("num_experts", "num_experts"),
151+
("num_experts_per_tok", "num_experts_per_tok"),
147152
("num_hidden_layers", "num_decoder_layers"),
153+
("num_key_value_heads", "num_kv_heads"),
154+
("num_local_experts", "num_experts"),
155+
("q_lora_rank", "q_lora_rank"),
156+
("qk_nope_head_dim", "qk_nope_head_dim"),
157+
("qk_rope_head_dim", "qk_rope_head_dim"),
158+
("v_head_dim", "v_head_dim"),
148159
("vocab_size", "vocab_size"),
149160
]
150161

src/maxtext/layers/attention_mla.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,10 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
851851
rngs=self.rngs,
852852
)
853853

854+
@property
855+
def out_head_dim(self) -> int:
856+
return self.v_head_dim
857+
854858
def mla_query_projection(
855859
self, inputs_q: Array, inputs_positions: Array, model_mode
856860
) -> tuple[jax.Array, Optional[jax.Array]]:

src/maxtext/layers/attentions.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -702,17 +702,21 @@ def qkv_projection(self, inputs: Array, proj_name: str, out_sharding: NamedShard
702702
query, key, value = qkv_proj[:, :, 0, ...], qkv_proj[:, :, 1, ...], qkv_proj[:, :, 2, ...]
703703
return query, key, value
704704

705+
@property
706+
def out_head_dim(self) -> int:
707+
return self.head_dim
708+
705709
def init_out_w(self, output_dim: int) -> nnx.Module:
706710
"""out projection"""
707-
in_features = (self.num_query_heads, self.head_dim)
711+
in_features = (self.num_query_heads, self.out_head_dim)
708712
out_features = output_dim
709713
out_kernel_axis = (
710714
(None, None, None) if self.config.ici_context_autoregressive_parallelism > 1 else ("heads", "kv", "embed")
711715
)
712716
axis = (-2, -1)
713717

714718
if self.is_qwen3_next:
715-
in_features = self.num_query_heads * self.head_dim
719+
in_features = self.num_query_heads * self.out_head_dim
716720
out_kernel_axis = ("mlp", "embed")
717721
axis = (-1,)
718722

0 commit comments

Comments
 (0)