
Commit 47defb1

feat: add DeepSeek-V3 support for MaxText to Hugging Face conversion
- Update architecture validation in checkpoint conversion to include MLA and MoE parameters.
- Implement output projection initialization for MLA layers.
1 parent 3971206 commit 47defb1

3 files changed

Lines changed: 23 additions & 4 deletions


src/maxtext/checkpoint_conversion/to_huggingface.py

Lines changed: 13 additions & 2 deletions
@@ -140,11 +140,22 @@ def _validate_or_update_architecture(hf_config, max_config, override: bool):
   # Mapping from Hugging Face config attribute -> MaxText config attribute
   # Note: We use derived MaxText attributes (e.g. emb_dim) which account for scale factors.
   attributes_to_check = [
-      ("num_attention_heads", "num_query_heads"),
-      ("num_key_value_heads", "num_kv_heads"),
       ("hidden_size", "emb_dim"),
       ("intermediate_size", "mlp_dim"),
+      ("kv_lora_rank", "kv_lora_rank"),
+      ("moe_intermediate_size", "moe_mlp_dim"),
+      ("n_routed_experts", "num_experts"),
+      ("n_shared_experts", "shared_experts"),
+      ("num_attention_heads", "num_query_heads"),
+      ("num_experts", "num_experts"),
+      ("num_experts_per_tok", "num_experts_per_tok"),
       ("num_hidden_layers", "num_decoder_layers"),
+      ("num_key_value_heads", "num_kv_heads"),
+      ("num_local_experts", "num_experts"),
+      ("q_lora_rank", "q_lora_rank"),
+      ("qk_nope_head_dim", "qk_nope_head_dim"),
+      ("qk_rope_head_dim", "qk_rope_head_dim"),
+      ("v_head_dim", "v_head_dim"),
       ("vocab_size", "vocab_size"),
   ]
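For context, a minimal sketch of how a mapping table like this is typically consumed. The loop below is illustrative only (the actual body of _validate_or_update_architecture is not part of this diff), and it assumes hf_config and max_config expose the listed attributes; getattr with a default guards the MLA/MoE fields that plain dense configs lack:

# Illustrative sketch only; not the real loop in to_huggingface.py.
def _sketch_validate_or_update(hf_config, max_config, attributes_to_check, override: bool):
  for hf_attr, mt_attr in attributes_to_check:
    hf_value = getattr(hf_config, hf_attr, None)  # may be absent on non-MLA/MoE configs
    mt_value = getattr(max_config, mt_attr, None)
    if mt_value is None or hf_value == mt_value:
      continue  # matching, or not applicable to this model family
    if override:
      setattr(hf_config, hf_attr, mt_value)  # trust the MaxText checkpoint config
    else:
      raise ValueError(f"Mismatch: HF {hf_attr}={hf_value} vs MaxText {mt_attr}={mt_value}")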

src/maxtext/layers/attention_mla.py

Lines changed: 4 additions & 0 deletions
@@ -786,6 +786,10 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
         rngs=self.rngs,
     )

+  @property
+  def out_head_dim(self) -> int:
+    return self.v_head_dim
+
   def mla_query_projection(
       self, inputs_q: Array, inputs_positions: Array, model_mode
   ) -> tuple[jax.Array, Optional[jax.Array]]:
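Why v_head_dim here: the attention output is a weighted sum of value vectors, so the per-head width entering the output projection is v_head_dim, which in MLA differs from the query/key head width (qk_nope_head_dim + qk_rope_head_dim). A toy shape check; the numbers below mirror DeepSeek-V3's published config and are illustrative only:

# Illustrative shape check; values follow DeepSeek-V3's published config.
num_query_heads, emb_dim = 128, 7168
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128

qk_head_dim = qk_nope_head_dim + qk_rope_head_dim  # 192: query/key width per head
out_in_features = (num_query_heads, v_head_dim)    # (128, 128), not (128, 192)
out_kernel_shape = (*out_in_features, emb_dim)     # (128, 128, 7168)
print(qk_head_dim, out_in_features, out_kernel_shape)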

src/maxtext/layers/attentions.py

Lines changed: 6 additions & 2 deletions
@@ -696,17 +696,21 @@ def qkv_projection(self, inputs: Array, proj_name: str, out_sharding: NamedShard
     query, key, value = qkv_proj[:, :, 0, ...], qkv_proj[:, :, 1, ...], qkv_proj[:, :, 2, ...]
     return query, key, value

+  @property
+  def out_head_dim(self) -> int:
+    return self.head_dim
+
   def init_out_w(self, output_dim: int) -> nnx.Module:
     """out projection"""
-    in_features = (self.num_query_heads, self.head_dim)
+    in_features = (self.num_query_heads, self.out_head_dim)
     out_features = output_dim
     out_kernel_axis = (
         (None, None, None) if self.config.ici_context_autoregressive_parallelism > 1 else ("heads", "kv", "embed")
     )
     axis = (-2, -1)

     if self.is_qwen3_next:
-      in_features = self.num_query_heads * self.head_dim
+      in_features = self.num_query_heads * self.out_head_dim
       out_kernel_axis = ("mlp", "embed")
       axis = (-1,)
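Taken together, the two layer hunks implement a simple property-override pattern: the base attention keeps out_head_dim equal to head_dim, and the MLA class overrides it with v_head_dim, so init_out_w sizes the output kernel correctly for both without branching. A condensed sketch of the pattern, with class names and fields simplified (the real modules are nnx.Module subclasses with many more parameters):

# Condensed sketch of the override pattern; not the real MaxText classes.
class Attention:
  def __init__(self, num_query_heads: int, head_dim: int):
    self.num_query_heads = num_query_heads
    self.head_dim = head_dim

  @property
  def out_head_dim(self) -> int:
    return self.head_dim  # default: output width equals the shared head width

  def out_proj_in_features(self) -> tuple:
    return (self.num_query_heads, self.out_head_dim)

class MLAttention(Attention):
  def __init__(self, num_query_heads: int, head_dim: int, v_head_dim: int):
    super().__init__(num_query_heads, head_dim)
    self.v_head_dim = v_head_dim

  @property
  def out_head_dim(self) -> int:
    return self.v_head_dim  # MLA outputs are value vectors of width v_head_dim

print(MLAttention(num_query_heads=128, head_dim=192, v_head_dim=128).out_proj_in_features())
# -> (128, 128)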
