
Commit 447a940

Merge pull request #3554 from AI-Hypercomputer:agagik-gemma4-moe
PiperOrigin-RevId: 893708237
2 parents: 0ee5a04 + bfc1e43

5 files changed: 32 additions & 30 deletions


src/maxtext/checkpoint_conversion/utils/param_mapping.py

Lines changed: 16 additions & 12 deletions
@@ -2362,16 +2362,16 @@ def GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False
           f"{text_base}.layers.{i}.router.proj.weight" if num_experts > 1 else None for i in hf_indices
       ],
       f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0": [
-          f"{text_base}.layers.{i}.moe.gate_up_proj" if num_experts > 1 else None for i in hf_indices
+          f"{text_base}.layers.{i}.experts.gate_up_proj" if num_experts > 1 else None for i in hf_indices
       ],
       f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1": [
-          f"{text_base}.layers.{i}.moe.gate_up_proj" if num_experts > 1 else None for i in hf_indices
+          f"{text_base}.layers.{i}.experts.gate_up_proj" if num_experts > 1 else None for i in hf_indices
       ],
       f"{prefix}-mlp-moe_block-MoeBlock_0-wo": [
-          f"{text_base}.layers.{i}.moe.down_proj" if num_experts > 1 else None for i in hf_indices
+          f"{text_base}.layers.{i}.experts.down_proj" if num_experts > 1 else None for i in hf_indices
       ],
       f"{prefix}-mlp-moe_block-MoeBlock_0-per_expert_scale": [
-          f"{text_base}.layers.{i}.moe.per_expert_scale" if num_experts > 1 else None for i in hf_indices
+          f"{text_base}.layers.{i}.router.per_expert_scale" if num_experts > 1 else None for i in hf_indices
       ],
       f"{prefix}-mlp-moe_block-shared_experts-wi_0-kernel": [
           f"{text_base}.layers.{i}.mlp.gate_proj.weight" if num_experts > 1 else None for i in hf_indices
@@ -2440,10 +2440,14 @@ def GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False
       f"{prefix}-mlp-moe_block-MoeBlock_0-gate-kernel": f"{hf_prefix}.router.proj.weight"
       if num_experts > 1
       else None,
-      f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0": f"{hf_prefix}.moe.gate_up_proj" if num_experts > 1 else None,
-      f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1": f"{hf_prefix}.moe.gate_up_proj" if num_experts > 1 else None,
-      f"{prefix}-mlp-moe_block-MoeBlock_0-wo": f"{hf_prefix}.moe.down_proj" if num_experts > 1 else None,
-      f"{prefix}-mlp-moe_block-MoeBlock_0-per_expert_scale": f"{hf_prefix}.moe.per_expert_scale"
+      f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0": f"{hf_prefix}.experts.gate_up_proj"
+      if num_experts > 1
+      else None,
+      f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1": f"{hf_prefix}.experts.gate_up_proj"
+      if num_experts > 1
+      else None,
+      f"{prefix}-mlp-moe_block-MoeBlock_0-wo": f"{hf_prefix}.experts.down_proj" if num_experts > 1 else None,
+      f"{prefix}-mlp-moe_block-MoeBlock_0-per_expert_scale": f"{hf_prefix}.router.per_expert_scale"
       if num_experts > 1
       else None,
       f"{prefix}-mlp-moe_block-shared_experts-wi_0-kernel": f"{hf_prefix}.mlp.gate_proj.weight"
@@ -2502,10 +2506,10 @@ def GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False
       f"{prefix}-mlp-moe_block-MoeBlock_0-gate-kernel": f"{hf_prefix}.router.proj.weight"
       if num_experts > 1
       else None,
-      f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0": f"{hf_prefix}.moe.gate_up_proj" if num_experts > 1 else None,
-      f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1": f"{hf_prefix}.moe.gate_up_proj" if num_experts > 1 else None,
-      f"{prefix}-mlp-moe_block-MoeBlock_0-wo": f"{hf_prefix}.moe.down_proj" if num_experts > 1 else None,
-      f"{prefix}-mlp-moe_block-MoeBlock_0-per_expert_scale": f"{hf_prefix}.moe.per_expert_scale"
+      f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0": f"{hf_prefix}.experts.gate_up_proj" if num_experts > 1 else None,
+      f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1": f"{hf_prefix}.experts.gate_up_proj" if num_experts > 1 else None,
+      f"{prefix}-mlp-moe_block-MoeBlock_0-wo": f"{hf_prefix}.experts.down_proj" if num_experts > 1 else None,
+      f"{prefix}-mlp-moe_block-MoeBlock_0-per_expert_scale": f"{hf_prefix}.router.per_expert_scale"
       if num_experts > 1
       else None,
       f"{prefix}-mlp-moe_block-shared_experts-wi_0-kernel": f"{hf_prefix}.mlp.gate_proj.weight"

tests/end_to_end/tpu/gemma4/26b/convert_gemma4.sh

Lines changed: 5 additions & 5 deletions
@@ -7,18 +7,19 @@ MODEL_NAME='gemma4-26b'
 export MODEL_VARIATION='26b'
 TOKENIZER_PATH='google/gemma-4-26b-a4b-it'
 # To convert the multimodal model, make sure the use_multimodal is set to be true
-USE_MULTIMODAL=true
+USE_MULTIMODAL=false
 USE_SCAN_LAYERS=false
 
 
 # Installing torch for deps in forward_pass_logit_checker.py
 python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
 
-# After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET
+# After downloading checkpoints, copy them to GCS bucket at $MODEL_BUCKET
 export MODEL_BUCKET='gs://maxtext-gemma/gemma4'
+export HF_MODEL='path/to/your/hf/gemma-4-26b-a4b-it'
 
 # To get converted ckpt:
-python3 -m maxtext.checkpoint_conversion.to_maxtext "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \
+python3 -m maxtext.checkpoint_conversion.to_maxtext "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \
   model_name=${MODEL_NAME} \
   hf_access_token=${HF_TOKEN} \
   --hf_model_path=${HF_MODEL} \
@@ -28,7 +29,6 @@ python3 -m maxtext.checkpoint_conversion.to_maxtext "${MAXTEXT_CONFIGS_DIR:-${MA
 
 
 export MAXTEXT_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/converted/${idx}/0/items
-export HF_MODEL='path/to/your/hf/gemma-4-26b-a4b-it'
 
 
 if [ ${USE_MULTIMODAL} == true ]; then
@@ -62,7 +62,7 @@ if [ ${USE_MULTIMODAL} == true ]; then
     --max_kl_div=0.03 \
     --golden_logits_path=${GOLDEN_LOGITS_PATH}
 else
-  python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \
+  python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \
     tokenizer_path=${TOKENIZER_PATH} \
     load_parameters_path=${MAXTEXT_CKPT_PATH} \
     model_name=${MODEL_NAME} \
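
Note: each of these conversion scripts validates the converted checkpoint by running tests.utils.forward_pass_logit_checker with --max_kl_div=0.03, i.e. the converted model's per-token output distribution may diverge from the saved golden logits by at most 0.03 of KL divergence. A minimal sketch of what such a check computes (assumed behavior, not the checker's actual implementation):

# Minimal sketch of a max-KL logit check (assumed behavior, not the actual
# forward_pass_logit_checker implementation).
import numpy as np

def max_kl_divergence(golden_logits, test_logits):
  """Max per-token KL(golden || test) over [num_tokens, vocab] logit arrays."""
  def log_softmax(x):
    x = x - x.max(axis=-1, keepdims=True)  # stabilize before exponentiating
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))
  log_p = log_softmax(golden_logits)
  log_q = log_softmax(test_logits)
  kl = (np.exp(log_p) * (log_p - log_q)).sum(axis=-1)  # KL per token
  return float(kl.max())

# Stand-ins for the saved golden logits and the converted model's logits:
golden = np.random.randn(4, 32)
converted = golden + 0.01 * np.random.randn(4, 32)
print(max_kl_divergence(golden, converted) <= 0.03)  # mirrors --max_kl_div=0.03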

tests/end_to_end/tpu/gemma4/26b/convert_gemma4_pt.sh

Lines changed: 5 additions & 5 deletions
@@ -7,18 +7,18 @@ MODEL_NAME='gemma4-26b'
 export MODEL_VARIATION='26b'
 TOKENIZER_PATH='google/gemma-4-26b-a4b'
 # To convert the multimodal model, make sure the use_multimodal is set to be true
-USE_MULTIMODAL=true
+USE_MULTIMODAL=false
 USE_SCAN_LAYERS=false
 
 # Installing torch for deps in forward_pass_logit_checker.py
 python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
 
-# After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET
+# After downloading checkpoints, copy them to GCS bucket at $MODEL_BUCKET
 export MODEL_BUCKET='gs://maxtext-gemma/gemma4'
-export HF_MODEL='path/to/your/gemma4-26b-a4b'
+export HF_MODEL='path/to/your/hf/gemma-4-26b-a4b'
 
 # To get converted ckpt:
-python3 -m maxtext.checkpoint_conversion.to_maxtext "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \
+python3 -m maxtext.checkpoint_conversion.to_maxtext "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \
   model_name=${MODEL_NAME} \
   hf_access_token=${HF_TOKEN} \
   --hf_model_path=${HF_MODEL} \
@@ -61,7 +61,7 @@ if [ ${USE_MULTIMODAL} == true ]; then
     --max_kl_div=0.03 \
     --golden_logits_path=${GOLDEN_LOGITS_PATH}
 else
-  python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \
+  python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \
     tokenizer_path=${TOKENIZER_PATH} \
     load_parameters_path=${MAXTEXT_CKPT_PATH} \
     model_name=${MODEL_NAME} \

tests/end_to_end/tpu/gemma4/31b/convert_gemma4.sh

Lines changed: 3 additions & 4 deletions
@@ -14,13 +14,12 @@ USE_SCAN_LAYERS=false
 # Installing torch for deps in forward_pass_logit_checker.py
 python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
 
-# After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET
+# After downloading checkpoints, copy them to GCS bucket at $MODEL_BUCKET
 export MODEL_BUCKET='gs://maxtext-gemma/gemma4'
-
 export HF_MODEL='path/to/your/hf/gemma-4-31b-it'
 
 # To get converted ckpt:
-python3 -m maxtext.checkpoint_conversion.to_maxtext "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \
+python3 -m maxtext.checkpoint_conversion.to_maxtext "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \
   model_name=${MODEL_NAME} \
   hf_access_token=${HF_TOKEN} \
   --hf_model_path=${HF_MODEL} \
@@ -63,7 +62,7 @@ if [ ${USE_MULTIMODAL} == true ]; then
     --max_kl_div=0.03 \
     --golden_logits_path=${GOLDEN_LOGITS_PATH}
 else
-  python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \
+  python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \
     tokenizer_path=${TOKENIZER_PATH} \
     load_parameters_path=${MAXTEXT_CKPT_PATH} \
     model_name=${MODEL_NAME} \

tests/end_to_end/tpu/gemma4/31b/convert_gemma4_pt.sh

Lines changed: 3 additions & 4 deletions
@@ -14,13 +14,12 @@ USE_SCAN_LAYERS=false
 # Installing torch for deps in forward_pass_logit_checker.py
 python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
 
-# After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET \
+# After downloading checkpoints, copy them to GCS bucket at $MODEL_BUCKET
 export MODEL_BUCKET='gs://maxtext-gemma/gemma4'
-
 export HF_MODEL='path/to/your/hf/gemma-4-31b'
 
 # To get converted ckpt:
-python3 -m maxtext.checkpoint_conversion.to_maxtext "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \
+python3 -m maxtext.checkpoint_conversion.to_maxtext "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \
   model_name=${MODEL_NAME} \
   hf_access_token=${HF_TOKEN} \
   --hf_model_path=${HF_MODEL} \
@@ -63,7 +62,7 @@ if [ ${USE_MULTIMODAL} == true ]; then
     --max_kl_div=0.03 \
     --golden_logits_path=${GOLDEN_LOGITS_PATH}
 else
-  python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \
+  python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \
     tokenizer_path=${TOKENIZER_PATH} \
     load_parameters_path=${MAXTEXT_CKPT_PATH} \
     model_name=${MODEL_NAME} \
