@@ -1,16 +1,16 @@
-# Copyright 2023–2025 Google LLC
+# Copyright 2023–2026 Google LLC
 #
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
 #
-# https://www.apache.org/licenses/LICENSE-2.0
+# https://www.apache.org/licenses/LICENSE-2.0
 #
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

 """Parameter mappings and transformation hooks for checkpoint conversion.

@@ -587,11 +587,11 @@ def scale_query_layer(input_tensor, target_shape):
   return mapping


-def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
-  """Returns mapping from MaxText to HuggingFace Qwen3 weight paths.
+def QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
+  """Returns mapping from MaxText to HuggingFace Qwen weight paths.

   This function generates a dictionary that maps parameter names from a MaxText
-  Qwen3 checkpoint to their corresponding names in the Hugging Face format.
+  Qwen checkpoint to their corresponding names in the Hugging Face format.
   It handles both dense and Mixture-of-Experts (MoE) model variants.

   Args:
@@ -631,6 +631,15 @@ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
       "params-decoder-layers-self_attention-value-kernel": [
           f"model.layers.{i}.self_attn.v_proj.weight" for i in range(n_layers)
       ],
+      "params-decoder-layers-self_attention-query-bias": [
+          f"model.layers.{i}.self_attn.q_proj.bias" for i in range(n_layers)
+      ],
+      "params-decoder-layers-self_attention-key-bias": [
+          f"model.layers.{i}.self_attn.k_proj.bias" for i in range(n_layers)
+      ],
+      "params-decoder-layers-self_attention-value-bias": [
+          f"model.layers.{i}.self_attn.v_proj.bias" for i in range(n_layers)
+      ],
       "params-decoder-layers-self_attention-out-kernel": [
           f"model.layers.{i}.self_attn.o_proj.weight" for i in range(n_layers)
       ],
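For orientation: in the scanned-layers case, each MaxText path keys a list of per-layer HF paths. A minimal sketch of what one of the new bias entries above expands to, assuming a toy two-layer config (illustrative only, not from the source):

n_layers = 2
sketch = {
    "params-decoder-layers-self_attention-query-bias": [
        f"model.layers.{i}.self_attn.q_proj.bias" for i in range(n_layers)
    ],
}
print(sketch["params-decoder-layers-self_attention-query-bias"])
# ['model.layers.0.self_attn.q_proj.bias', 'model.layers.1.self_attn.q_proj.bias']

The bias entries matter for Qwen2.5-style checkpoints, which carry QKV attention biases that Qwen3 dense models dropped; that is presumably why the function is renamed from QWEN3_* to the generic QWEN_*.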
@@ -688,6 +697,9 @@ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
         f"params-decoder-layers_{i}-self_attention-key-kernel": f"model.layers.{i}.self_attn.k_proj.weight",
         f"params-decoder-layers_{i}-self_attention-value-kernel": f"model.layers.{i}.self_attn.v_proj.weight",
         f"params-decoder-layers_{i}-self_attention-out-kernel": f"model.layers.{i}.self_attn.o_proj.weight",
+        f"params-decoder-layers_{i}-self_attention-query-bias": f"model.layers.{i}.self_attn.q_proj.bias",
+        f"params-decoder-layers_{i}-self_attention-key-bias": f"model.layers.{i}.self_attn.k_proj.bias",
+        f"params-decoder-layers_{i}-self_attention-value-bias": f"model.layers.{i}.self_attn.v_proj.bias",
         f"params-decoder-layers_{i}-self_attention-query_norm-scale": f"model.layers.{i}.self_attn.q_norm.weight",
         f"params-decoder-layers_{i}-self_attention-key_norm-scale": f"model.layers.{i}.self_attn.k_norm.weight",
         f"params-decoder-layers_{i}-post_self_attention_layer_norm-scale": f"model.layers.{i}.post_attention_layernorm.weight",
@@ -721,11 +733,11 @@ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
   return mapping


-def QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
-  """Creates parameter transformation functions for Qwen3.
+def QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
+  """Creates parameter transformation functions for Qwen.

   This function provides a dictionary of transformation functions (hooks) for
-  converting Qwen3 model parameters between MaxText and Hugging Face formats.
+  converting Qwen model parameters between MaxText and Hugging Face formats.
   It handles embedding padding and kernel reshaping.

   Args:
@@ -766,6 +778,12 @@ def reshape_kernel(input_tensor, target_shape):
     else:
       return input_tensor.T.reshape(target_shape)

+  def reshape_bias(input_tensor, target_shape=None):
+    """Reshapes biases between MaxText 2D (heads, dim) and HF 1D (hidden)."""
+    # saving_to_hf: MaxText [heads, head_dim] -> HF [hidden_dim] (flatten)
+    # loading_to_maxtext: HF [hidden_dim] -> MaxText [heads, head_dim]
+    return input_tensor.reshape(target_shape)
+
   mapping = {
       "params-token_embedder-embedding": pad_embedding_layer,
       "params-decoder-logits_dense-kernel": reshape_kernel,
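A quick sanity check of the reshape_bias round trip, as a minimal numpy sketch; the shapes (4 heads, head_dim 8, hidden_dim 32) and variable names are illustrative, not from the source:

import numpy as np

bias_mt = np.arange(32.0).reshape(4, 8)   # MaxText layout: [heads, head_dim]
bias_hf = bias_mt.reshape((32,))          # saving_to_hf: flatten to [hidden_dim]
bias_back = bias_hf.reshape((4, 8))       # loading: restore [heads, head_dim]
assert (bias_back == bias_mt).all()

Because both directions are plain row-major reshapes, the same hook serves both; no transpose is involved, unlike reshape_kernel.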
@@ -780,6 +798,11 @@ def reshape_kernel(input_tensor, target_shape):
       "mlp-wi_1-kernel",
       "mlp-wo-kernel",
   ]
+  bias_hooks = [
+      "self_attention-query-bias",
+      "self_attention-key-bias",
+      "self_attention-value-bias",
+  ]
   moe_kernel_hooks = [
       "moe_block-gate-kernel",
       "moe_block-wi_0-kernel",
@@ -793,13 +816,17 @@ def reshape_kernel(input_tensor, target_shape):
   if scan_layers:
     for key in kernel_hooks:
       mapping[f"params-decoder-layers-{key}"] = reshape_kernel
+    for key in bias_hooks:
+      mapping[f"params-decoder-layers-{key}"] = reshape_bias
     if num_experts > 1:
       for key in moe_kernel_hooks:
         mapping[f"params-decoder-layers-{key}"] = reshape_kernel
   else:
     for i in range(n_layers):
       for key in kernel_hooks:
         mapping[f"params-decoder-layers_{i}-{key}"] = reshape_kernel
+      for key in bias_hooks:
+        mapping[f"params-decoder-layers_{i}-{key}"] = reshape_bias
       if num_experts > 1:
         for key in moe_kernel_hooks:
           mapping[f"params-decoder-layers_{i}-{key}"] = reshape_kernel
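To see how a hook table like this is consumed, a hedged sketch of a driver; apply_hook and the toy table below are hypothetical, not part of the source:

import numpy as np

def apply_hook(hooks, name, tensor, target_shape):
  # Apply the transformation registered for this parameter, if any.
  hook = hooks.get(name)
  return hook(tensor, target_shape) if hook is not None else tensor

# Toy table mimicking one unscanned entry registered in the loop above.
hooks = {
    "params-decoder-layers_0-self_attention-query-bias":
        lambda t, shape: t.reshape(shape),  # stands in for reshape_bias
}
flat = apply_hook(
    hooks, "params-decoder-layers_0-self_attention-query-bias",
    np.zeros((4, 8)), (32,))
assert flat.shape == (32,)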
@@ -1376,7 +1403,7 @@ def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_laye
   # Text mapping with "thinker." prefix, reusing QWEN3-MOE mapping function
   num_experts_text = config["thinker_config"]["text_config"].get("num_experts", 0)
   n_layers_text = config["thinker_config"]["text_config"]["num_hidden_layers"]
-  text_mapping = QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(
+  text_mapping = QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(
       config={"num_hidden_layers": n_layers_text, "num_experts": num_experts_text},
       maxtext_config=maxtext_config,
       scan_layers=scan_layers,
@@ -1544,7 +1571,7 @@ def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_laye
   # Text hooks, reusing QWEN3-MOE hook function
   num_experts_text = config["thinker_config"]["text_config"].get("num_experts", 0)
   n_layers_text = config["thinker_config"]["text_config"]["num_hidden_layers"]
-  text_hooks = QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN(
+  text_hooks = QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(
       config={"num_hidden_layers": n_layers_text, "num_experts": num_experts_text},
       maxtext_config=maxtext_config,
       scan_layers=scan_layers,
@@ -2332,24 +2359,26 @@ def pad_hf_embedding_layer(input_tensor, target_shape):
     "gemma3-4b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING,
     "gemma3-12b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING,
     "gemma3-27b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-0.6b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-1.7b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-1.7b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-4b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-4b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-4b-thinking-2507": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-8b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-8b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-14b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-14b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-32b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen2.5-7b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen2.5-14b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-0.6b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-1.7b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-1.7b-base": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-4b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-4b-base": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-4b-thinking-2507": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-8b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-8b-base": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-14b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-14b-base": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-32b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
     "llama3.1-8b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
     "llama3.1-70b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
     "llama3.1-405b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-30b-a3b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-30b-a3b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-235b-a22b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-coder-480b-a35b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-30b-a3b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-30b-a3b-base": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-235b-a22b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-coder-480b-a35b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
     "deepseek3-671b": DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING,
     "gpt-oss-20b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
     "gpt-oss-120b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
@@ -2370,24 +2399,26 @@ def pad_hf_embedding_layer(input_tensor, target_shape):
     "gemma3-4b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "gemma3-12b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "gemma3-27b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-0.6b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-1.7b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-1.7b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-4b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-4b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-4b-thinking-2507": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-8b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-8b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-14b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-14b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-32b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen2.5-7b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen2.5-14b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-0.6b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-1.7b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-1.7b-base": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-4b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-4b-base": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-4b-thinking-2507": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-8b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-8b-base": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-14b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-14b-base": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-32b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "llama3.1-8b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "llama3.1-70b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "llama3.1-405b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-30b-a3b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-30b-a3b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-235b-a22b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-coder-480b-a35b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-30b-a3b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-30b-a3b-base": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-235b-a22b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-coder-480b-a35b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "deepseek3-671b": DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "gpt-oss-20b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
     "gpt-oss-120b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
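Finally, a sketch of how the two registries pair up at conversion time; the dict names PARAM_MAPPINGS and HOOK_FNS are placeholders, since the hunks show only the entries, not the enclosing assignments:

model_name = "qwen2.5-7b"
mapping_fn = PARAM_MAPPINGS[model_name]   # -> QWEN_MAXTEXT_TO_HF_PARAM_MAPPING
hook_fn = HOOK_FNS[model_name]            # -> QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN
# mapping = mapping_fn(config, maxtext_config, scan_layers=False)
# hooks = hook_fn(config, maxtext_config, scan_layers=False, saving_to_hf=True)

The two tables must stay in sync: every model key added to the mapping registry (here qwen2.5-7b and qwen2.5-14b) needs a matching hook entry, which this change provides.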