
Commit 2052c22

Merge pull request #3466 from AI-Hypercomputer:jimmytsai/fix-param-mapping-for-qwen2

PiperOrigin-RevId: 889232312
2 parents: de51021 + f884db6

This change generalizes the Qwen3 checkpoint-conversion helpers into shared QWEN_* functions, adds mappings and a reshape hook for the q/k/v attention-projection biases that the Qwen2/2.5 architecture carries, and registers qwen2.5-7b and qwen2.5-14b in the model tables.

1 file changed: src/maxtext/checkpoint_conversion/utils/param_mapping.py (79 additions, 48 deletions)
```diff
@@ -1,16 +1,16 @@
-# Copyright 2023–2025 Google LLC
+# Copyright 2023–2026 Google LLC
 #
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
 #
-# https://www.apache.org/licenses/LICENSE-2.0
+# https://www.apache.org/licenses/LICENSE-2.0
 #
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

 """Parameter mappings and transformation hooks for checkpoint conversion.
```
```diff
@@ -587,11 +587,11 @@ def scale_query_layer(input_tensor, target_shape):
   return mapping


-def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
-  """Returns mapping from MaxText to HuggingFace Qwen3 weight paths.
+def QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
+  """Returns mapping from MaxText to HuggingFace Qwen weight paths.

   This function generates a dictionary that maps parameter names from a MaxText
-  Qwen3 checkpoint to their corresponding names in the Hugging Face format.
+  Qwen checkpoint to their corresponding names in the Hugging Face format.
   It handles both dense and Mixture-of-Experts (MoE) model variants.

   Args:
```
```diff
@@ -631,6 +631,15 @@ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
       "params-decoder-layers-self_attention-value-kernel": [
           f"model.layers.{i}.self_attn.v_proj.weight" for i in range(n_layers)
       ],
+      "params-decoder-layers-self_attention-query-bias": [
+          f"model.layers.{i}.self_attn.q_proj.bias" for i in range(n_layers)
+      ],
+      "params-decoder-layers-self_attention-key-bias": [
+          f"model.layers.{i}.self_attn.k_proj.bias" for i in range(n_layers)
+      ],
+      "params-decoder-layers-self_attention-value-bias": [
+          f"model.layers.{i}.self_attn.v_proj.bias" for i in range(n_layers)
+      ],
       "params-decoder-layers-self_attention-out-kernel": [
           f"model.layers.{i}.self_attn.o_proj.weight" for i in range(n_layers)
       ],
```
```diff
@@ -688,6 +697,9 @@ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
       f"params-decoder-layers_{i}-self_attention-key-kernel": f"model.layers.{i}.self_attn.k_proj.weight",
       f"params-decoder-layers_{i}-self_attention-value-kernel": f"model.layers.{i}.self_attn.v_proj.weight",
       f"params-decoder-layers_{i}-self_attention-out-kernel": f"model.layers.{i}.self_attn.o_proj.weight",
+      f"params-decoder-layers_{i}-self_attention-query-bias": f"model.layers.{i}.self_attn.q_proj.bias",
+      f"params-decoder-layers_{i}-self_attention-key-bias": f"model.layers.{i}.self_attn.k_proj.bias",
+      f"params-decoder-layers_{i}-self_attention-value-bias": f"model.layers.{i}.self_attn.v_proj.bias",
       f"params-decoder-layers_{i}-self_attention-query_norm-scale": f"model.layers.{i}.self_attn.q_norm.weight",
       f"params-decoder-layers_{i}-self_attention-key_norm-scale": f"model.layers.{i}.self_attn.k_norm.weight",
       f"params-decoder-layers_{i}-post_self_attention_layer_norm-scale": f"model.layers.{i}.post_attention_layernorm.weight",
```
```diff
@@ -721,11 +733,11 @@ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
   return mapping


-def QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
-  """Creates parameter transformation functions for Qwen3.
+def QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
+  """Creates parameter transformation functions for Qwen.

   This function provides a dictionary of transformation functions (hooks) for
-  converting Qwen3 model parameters between MaxText and Hugging Face formats.
+  converting Qwen model parameters between MaxText and Hugging Face formats.
   It handles embedding padding and kernel reshaping.

   Args:
```
```diff
@@ -766,6 +778,12 @@ def reshape_kernel(input_tensor, target_shape):
     else:
       return input_tensor.T.reshape(target_shape)

+  def reshape_bias(input_tensor, target_shape=None):
+    """Reshapes biases between MaxText 2D (heads, dim) and HF 1D (hidden)."""
+    # saving_to_hf: MaxText [heads, head_dim] -> HF [hidden_dim] (flatten)
+    # loading_to_maxtext: HF [hidden_dim] -> MaxText [heads, head_dim]
+    return input_tensor.reshape(target_shape)
+
   mapping = {
       "params-token_embedder-embedding": pad_embedding_layer,
       "params-decoder-logits_dense-kernel": reshape_kernel,
```
```diff
@@ -780,6 +798,11 @@ def reshape_kernel(input_tensor, target_shape):
       "mlp-wi_1-kernel",
       "mlp-wo-kernel",
   ]
+  bias_hooks = [
+      "self_attention-query-bias",
+      "self_attention-key-bias",
+      "self_attention-value-bias",
+  ]
   moe_kernel_hooks = [
       "moe_block-gate-kernel",
       "moe_block-wi_0-kernel",
```
```diff
@@ -793,13 +816,17 @@ def reshape_kernel(input_tensor, target_shape):
   if scan_layers:
     for key in kernel_hooks:
       mapping[f"params-decoder-layers-{key}"] = reshape_kernel
+    for key in bias_hooks:
+      mapping[f"params-decoder-layers-{key}"] = reshape_bias
     if num_experts > 1:
       for key in moe_kernel_hooks:
         mapping[f"params-decoder-layers-{key}"] = reshape_kernel
   else:
     for i in range(n_layers):
       for key in kernel_hooks:
         mapping[f"params-decoder-layers_{i}-{key}"] = reshape_kernel
+      for key in bias_hooks:
+        mapping[f"params-decoder-layers_{i}-{key}"] = reshape_bias
       if num_experts > 1:
         for key in moe_kernel_hooks:
           mapping[f"params-decoder-layers_{i}-{key}"] = reshape_kernel
```
```diff
@@ -1376,7 +1403,7 @@ def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_laye
   # Text mapping with "thinker." prefix, reusing QWEN3-MOE mapping function
   num_experts_text = config["thinker_config"]["text_config"].get("num_experts", 0)
   n_layers_text = config["thinker_config"]["text_config"]["num_hidden_layers"]
-  text_mapping = QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(
+  text_mapping = QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(
       config={"num_hidden_layers": n_layers_text, "num_experts": num_experts_text},
       maxtext_config=maxtext_config,
       scan_layers=scan_layers,
```
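The Omni wrapper reuses the renamed text-model mapping and, per the comment in this hunk, prefixes the Hugging Face side with `thinker.`. A minimal, self-contained sketch of that prefixing idea; the one-entry mapping is invented for the example and is not the wrapper's actual code:

```python
# Invented single-entry mapping standing in for the reused Qwen text mapping.
text_mapping = {
    "params-decoder-logits_dense-kernel": "lm_head.weight",
}

# Prefix every HF path with "thinker.", handling both str and list values.
prefixed = {
    k: [f"thinker.{p}" for p in v] if isinstance(v, list) else f"thinker.{v}"
    for k, v in text_mapping.items()
}
print(prefixed)  # {'params-decoder-logits_dense-kernel': 'thinker.lm_head.weight'}
```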
```diff
@@ -1544,7 +1571,7 @@ def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_laye
   # Text hooks, reusing QWEN3-MOE hook function
   num_experts_text = config["thinker_config"]["text_config"].get("num_experts", 0)
   n_layers_text = config["thinker_config"]["text_config"]["num_hidden_layers"]
-  text_hooks = QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN(
+  text_hooks = QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(
       config={"num_hidden_layers": n_layers_text, "num_experts": num_experts_text},
       maxtext_config=maxtext_config,
       scan_layers=scan_layers,
```
```diff
@@ -2332,24 +2359,26 @@ def pad_hf_embedding_layer(input_tensor, target_shape):
     "gemma3-4b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING,
     "gemma3-12b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING,
     "gemma3-27b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-0.6b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-1.7b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-1.7b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-4b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-4b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-4b-thinking-2507": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-8b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-8b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-14b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-14b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-32b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen2.5-7b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen2.5-14b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-0.6b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-1.7b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-1.7b-base": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-4b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-4b-base": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-4b-thinking-2507": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-8b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-8b-base": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-14b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-14b-base": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-32b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
     "llama3.1-8b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
     "llama3.1-70b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
     "llama3.1-405b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-30b-a3b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-30b-a3b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-235b-a22b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-coder-480b-a35b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-30b-a3b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-30b-a3b-base": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-235b-a22b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-coder-480b-a35b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
     "deepseek3-671b": DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING,
     "gpt-oss-20b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
     "gpt-oss-120b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
```
```diff
@@ -2370,24 +2399,26 @@ def pad_hf_embedding_layer(input_tensor, target_shape):
     "gemma3-4b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "gemma3-12b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "gemma3-27b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-0.6b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-1.7b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-1.7b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-4b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-4b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-4b-thinking-2507": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-8b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-8b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-14b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-14b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-32b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen2.5-7b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen2.5-14b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-0.6b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-1.7b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-1.7b-base": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-4b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-4b-base": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-4b-thinking-2507": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-8b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-8b-base": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-14b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-14b-base": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-32b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "llama3.1-8b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "llama3.1-70b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "llama3.1-405b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-30b-a3b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-30b-a3b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-235b-a22b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-coder-480b-a35b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-30b-a3b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-30b-a3b-base": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-235b-a22b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-coder-480b-a35b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "deepseek3-671b": DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "gpt-oss-20b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
     "gpt-oss-120b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
```
