
Commit cdc587f

gagika authored and Google-ML-Automation committed
Add support for Gemma 4 architectures in MaxText.
This PR brings up the Gemma 4 architectures (31B Dense, 26B MoE) for both text and vision modalities. Key changes:

* Adds the core Gemma 4 text and vision architectures.
* Introduces the specialized layers and building blocks required for Gemma 4.
* Adds the necessary configurations and pipeline logic to support checkpoint conversion.

PiperOrigin-RevId: 893570076
1 parent 612162a · commit cdc587f

35 files changed: 3204 additions & 127 deletions

README.md

Lines changed: 3 additions & 1 deletion
```diff
@@ -41,6 +41,7 @@ See our guide on running MaxText in decoupled mode, without any GCP dependencies
 
 ## 🔥 Latest news 🔥
 
+* \[April 2, 2026\] Gemma 4 multi-modal models (26B MoE, 31B dense) are now supported! Try them out with our [gemma4-26b](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/gemma4-26b.yml) and [gemma4-31b](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/gemma4-31b.yml) configs. For more details, see [Run_Gemma4.md](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/gemma4/Run_Gemma4.md).
 * \[March 6, 2026\] New features from DeepSeek-AI are now supported: Conditional Memory via Scalable Lookup ([Engram](https://arxiv.org/abs/2601.07372)) and Manifold-Constrained Hyper-Connections ([mHC](https://arxiv.org/abs/2512.24880)). Try them out with our [deepseek-custom](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/deepseek-custom.yml) starter config.
 * \[March 5, 2026\] New `tpu-post-train` [target in PyPI](https://pypi.org/project/maxtext). Please also use this installation option for running vllm_decode. See the [MaxText installation instructions](https://maxtext.readthedocs.io/en/latest/install_maxtext.html) for more info.
 * \[March 5, 2026\] [Qwen3-Next](https://github.com/AI-Hypercomputer/maxtext/blob/7656eb8d1c9eb0dd91e617a6fdf6ad805221221a/tests/end_to_end/tpu/qwen/next/run_qwen3_next.md) is now supported.
```
```diff
@@ -73,7 +74,7 @@ MaxText provides a library of models and demonstrates how to perform pre-trainin
 
 MaxText leverages [JAX AI libraries](https://docs.jaxstack.ai/en/latest/getting_started.html) and presents a cohesive and comprehensive demonstration of training at scale by using [Flax](https://flax.readthedocs.io/en/latest/) (neural networks), [Tunix](https://github.com/google/tunix) (post-training), [Orbax](https://orbax.readthedocs.io/en/latest/) (checkpointing), [Optax](https://optax.readthedocs.io/en/latest/) (optimization), and [Grain](https://google-grain.readthedocs.io/en/latest/) (dataloading).
 
-In addition to pure text-based LLMs, we also support multi-modal training with Gemma 3 and Llama 4 VLMs.
+In addition to pure text-based LLMs, we also support multi-modal training with Gemma 3, Gemma 4, and Llama 4 VLMs.
 
 ### Pre-training
 
```
```diff
@@ -103,6 +104,7 @@ MaxText aims to provide you with the best OSS models, whether as a reference imp
 **Supported JAX models in MaxText**
 
 * Google
+  * Gemma 4 (26B MoE, 31B Dense)
   * Gemma 3 (4B, 12B, 27B)
   * Gemma 2 (2B, 9B, 27B)
   * Gemma 1 (2B, 7B)
```
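The two linked YAML files are the entry points for trying the new models. As a hedged illustration (not part of this commit), one way to peek at a config from a local checkout is plain PyYAML; the keys inside the file are not shown in this diff, so the printout below is only indicative:

```python
# Hedged sketch, not from this PR: inspect the new Gemma 4 model config from a
# local checkout of the repository. The YAML's actual keys are not part of this
# diff, so treat the printed names as unknown until you run it.
import yaml

with open("src/maxtext/configs/models/gemma4-26b.yml") as f:
  cfg = yaml.safe_load(f)

print(sorted(cfg.keys()))
```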

src/maxtext/checkpoint_conversion/standalone_scripts/llama_or_mistral_ckpt.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -1653,6 +1653,8 @@ def shard_checkpoint(jax_weights, device_count, mem_info):
   max_logging.log("Note: Axis 0 sharding is the default and will not be logged individually.")
   # Pre-define sharding specs
   mesh = jax.sharding.Mesh(jax.devices(), "checkpoint_sharding_axis")
+  # No sharding (replicated specifically for 0D scalars)
+  s0 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
   # Sharding along axis 0
   s1 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("checkpoint_sharding_axis"))
   # Sharding along axis 1
@@ -1673,7 +1675,10 @@ def checkpoint_device_put(arr):
     # materialize lazy tensor
     arr = np.array(arr)
 
-    if arr.shape[0] % device_count == 0:
+    if len(arr.shape) == 0:
+      max_logging.log("0D scalar detected, replicating")
+      return jax.device_put(arr, device=s0)
+    elif arr.shape[0] % device_count == 0:
       # Sharding axis 0: Omit log for brevity per the summary log above.
       return jax.device_put(arr, device=s1)
     elif len(arr.shape) > 1 and arr.shape[1] % device_count == 0:
```
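The new branch exists because a 0D scalar has no axis 0: the old code's `arr.shape[0]` would raise an `IndexError` before any sharding decision could be made, so scalars are now replicated across the mesh instead. A minimal, self-contained sketch of the same fallback logic, with `device_put_with_fallback` as a hypothetical stand-in for the real `checkpoint_device_put` (which also handles axis-1 sharding and memory logging):

```python
# Sketch of the 0D fallback; device_put_with_fallback is a hypothetical
# stand-in for MaxText's checkpoint_device_put.
import jax
import numpy as np

mesh = jax.sharding.Mesh(jax.devices(), "checkpoint_sharding_axis")
s0 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())  # replicated
s1 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("checkpoint_sharding_axis"))

def device_put_with_fallback(arr, device_count):
  arr = np.array(arr)  # materialize lazy tensors
  if arr.ndim == 0:
    # A scalar has no axis 0, so arr.shape[0] would raise IndexError; replicate it.
    return jax.device_put(arr, device=s0)
  if arr.shape[0] % device_count == 0:
    return jax.device_put(arr, device=s1)  # shard along axis 0
  return jax.device_put(arr, device=s0)  # no evenly divisible axis: replicate

scalar = np.float32(30.0)  # e.g. a softcapping constant stored in a checkpoint
print(device_put_with_fallback(scalar, jax.device_count()).sharding)
```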

src/maxtext/checkpoint_conversion/utils/hf_model_configs.py

Lines changed: 156 additions & 9 deletions
```diff
@@ -24,6 +24,137 @@
 else:
   from transformers.configuration_utils import PretrainedConfig as PTConfig
 
+
+gemma4_26b_dict = {
+    "architectures": ["Gemma4ForConditionalGeneration"],
+    "audio_config": None,
+    "audio_token_id": 258881,
+    "boa_token_id": 256000,
+    "boi_token_id": 255999,
+    "dtype": "bfloat16",
+    "eoa_token_id": 258883,
+    "eoa_token_index": 258883,
+    "eoi_token_id": 258882,
+    "eos_token_id": [1, 106],
+    "image_token_id": 258880,
+    "initializer_range": 0.02,
+    "model_type": "gemma4",
+    "text_config": {
+        "attention_bias": False,
+        "attention_dropout": 0.0,
+        "attention_k_eq_v": True,
+        "bos_token_id": 2,
+        "dtype": "bfloat16",
+        "enable_moe_block": True,
+        "eos_token_id": 1,
+        "expert_intermediate_size": 704,
+        "final_logit_softcapping": 30.0,
+        "global_head_dim": 512,
+        "head_dim": 256,
+        "hidden_activation": "gelu_pytorch_tanh",
+        "hidden_size": 2816,
+        "hidden_size_per_layer_input": 0,
+        "initializer_range": 0.02,
+        "intermediate_size": 2112,
+        "layer_types": [
+            "sliding_attention",
+            "sliding_attention",
+            "sliding_attention",
+            "sliding_attention",
+            "sliding_attention",
+            "full_attention",
+        ]
+        * 5,
+        "max_position_embeddings": 262144,
+        "model_type": "gemma4_text",
+        "num_attention_heads": 16,
+        "num_experts": 128,
+        "num_global_key_value_heads": 2,
+        "num_hidden_layers": 30,
+        "num_key_value_heads": 8,
+        "num_kv_shared_layers": 0,
+        "pad_token_id": 0,
+        "rms_norm_eps": 1e-06,
+        "rope_parameters": {
+            "full_attention": {"partial_rotary_factor": 0.25, "rope_theta": 1_000_000.0, "rope_type": "proportional"},
+            "sliding_attention": {"rope_theta": 10_000.0, "rope_type": "default"},
+        },
+        "sliding_window": 1024,
+        "tie_word_embeddings": True,
+        "top_k_experts": 8,
+        "use_bidirectional_attention": "vision",
+        "use_cache": True,
+        "use_double_wide_mlp": False,
+        "vocab_size": 262144,
+        "vocab_size_per_layer_input": 262144,
+    },
+    "tie_word_embeddings": True,
+    "transformers_version": "5.5.0.dev0",
+    "video_token_id": 258884,
+    "vision_config": {
+        "attention_bias": False,
+        "attention_dropout": 0.0,
+        "default_output_length": 280,
+        "dtype": "bfloat16",
+        "global_head_dim": 72,
+        "head_dim": 72,
+        "hidden_activation": "gelu_pytorch_tanh",
+        "hidden_size": 1152,
+        "intermediate_size": 4304,
+        "max_position_embeddings": 131072,
+        "model_type": "gemma4_vision",
+        "num_attention_heads": 16,
+        "num_hidden_layers": 27,
+        "num_key_value_heads": 16,
+        "patch_size": 16,
+        "pooling_kernel_size": 3,
+        "position_embedding_size": 10240,
+        "rms_norm_eps": 1e-06,
+        "rope_parameters": {"rope_theta": 100.0, "rope_type": "default"},
+        "standardize": True,
+        "use_clipped_linears": False,
+    },
+    "vision_soft_tokens_per_image": 280,
+}
+
+
+gemma4_31b_dict = gemma4_26b_dict.copy()
+gemma4_31b_dict["text_config"] = gemma4_26b_dict["text_config"].copy()
+gemma4_31b_dict["text_config"].update(
+    {
+        "enable_moe_block": False,
+        "expert_intermediate_size": None,
+        "hidden_size": 5376,
+        "intermediate_size": 21504,
+        "layer_types": [
+            "sliding_attention",
+            "sliding_attention",
+            "sliding_attention",
+            "sliding_attention",
+            "sliding_attention",
+            "full_attention",
+        ]
+        * 10,
+        "num_attention_heads": 32,
+        "num_experts": None,
+        "num_global_key_value_heads": 4,
+        "num_hidden_layers": 60,
+        "num_key_value_heads": 16,
+        "top_k_experts": None,
+    }
+)
+
+
+try:
+  # Will execute successfully if Transformers is updated with Gemma 4 support
+  gemma4_26b_config = transformers.Gemma4Config(**gemma4_26b_dict)
+  gemma4_31b_config = transformers.Gemma4Config(**gemma4_31b_dict)
+except AttributeError:
+  # Graceful fallback to a raw dict-based PTConfig if native Gemma 4 support is missing
+  gemma4_26b_config = PTConfig(**gemma4_26b_dict)
+  gemma4_31b_config = PTConfig(**gemma4_31b_dict)
+
+
 gemma3_4b_config = transformers.Gemma3Config(
   architectures=["Gemma3ForConditionalGeneration"],
   boi_token_index=255999,
```
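Two details in the block above are worth a second look. First, the sliding/global attention schedule is a repeated 5:1 pattern: five sliding-window layers plus one full-attention layer per repeat gives 6 × 5 = 30 layers for the 26B MoE and 6 × 10 = 60 for the 31B dense model, matching each config's `num_hidden_layers`. Second, `dict.copy()` is shallow, which is why the 31B variant explicitly re-copies `text_config` before calling `update`; a small self-contained illustration of the pitfall being avoided:

```python
# Why the explicit text_config.copy() matters: dict.copy() is shallow, so the
# nested text_config dict would otherwise be shared between the two variants.
base = {"text_config": {"hidden_size": 2816}}

naive = base.copy()  # shallow: naive["text_config"] IS base["text_config"]
naive["text_config"]["hidden_size"] = 5376
assert base["text_config"]["hidden_size"] == 5376  # the 26B dict got mutated!

base = {"text_config": {"hidden_size": 2816}}
safe = base.copy()
safe["text_config"] = base["text_config"].copy()  # re-copy the nested dict
safe["text_config"]["hidden_size"] = 5376
assert base["text_config"]["hidden_size"] == 2816  # the 26B dict is intact
```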
```diff
@@ -584,9 +715,10 @@
         "mscale": 0.707,
         "mscale_all_dim": 0.707,
         "original_max_position_embeddings": 4096,
+        "rope_theta": 10_000,
         "type": "yarn",
     },
-    "rope_theta": 10000,
+    "rope_theta": 10_000,
     "routed_scaling_factor": 1.0,
     "scoring_func": "softmax",
     "seq_aux": True,
```
```diff
@@ -645,9 +777,10 @@
         "mscale": 1.0,
         "mscale_all_dim": 1.0,
         "original_max_position_embeddings": 4096,
+        "rope_theta": 10_000,
         "type": "yarn",
     },
-    "rope_theta": 10000,
+    "rope_theta": 10_000,
     "routed_scaling_factor": 2.5,
     "scoring_func": "sigmoid",
     "tie_word_embeddings": False,
```
```diff
@@ -697,15 +830,16 @@
     "qk_rope_head_dim": 64,
     "rms_norm_eps": 1e-06,
     "rope_scaling": {
-        "beta_fast": 32,
-        "beta_slow": 1,
-        "factor": 40,
+        "beta_fast": 32.0,
+        "beta_slow": 1.0,
+        "factor": 40.0,
         "mscale": 1.0,
         "mscale_all_dim": 1.0,
         "original_max_position_embeddings": 4096,
+        "rope_theta": 10_000,
         "type": "yarn",
     },
-    "rope_theta": 10000,
+    "rope_theta": 10_000,
     "routed_scaling_factor": 2.5,
     "scoring_func": "sigmoid",
     "tie_word_embeddings": False,
```
```diff
@@ -717,8 +851,17 @@
     "v_head_dim": 128,
     "vocab_size": 129280,
 }
+
+
 # TODO(shuningjin): replace with DeepseekV32Config when available in transformers library
-deepseek32_671b_config = PTConfig(**deepseek32_671b_dict)
+class DeepseekV32Config(PTConfig):
+
+  def __init__(self, **kwargs):
+    self.max_position_embeddings = kwargs.get("max_position_embeddings", 163840)
+    super().__init__(**kwargs)
+
+
+deepseek32_671b_config = DeepseekV32Config(**deepseek32_671b_dict)
 
 # from https://huggingface.co/openai/gpt-oss-20b/blob/main/config.json
 # remove mxfp4 quantization_config, since we are using bf16
```
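The subclass above exists only to pin a default: when the config dict omits `max_position_embeddings`, the attribute falls back to 163840 before `PretrainedConfig.__init__` applies the remaining kwargs. A quick check of that behavior, assuming standard `PretrainedConfig` semantics (leftover kwargs are set as attributes):

```python
# Hedged sketch verifying the default-pinning behavior of the subclass above.
from transformers.configuration_utils import PretrainedConfig as PTConfig

class DeepseekV32Config(PTConfig):

  def __init__(self, **kwargs):
    # Fall back to a 163840-token context when the dict omits the key.
    self.max_position_embeddings = kwargs.get("max_position_embeddings", 163840)
    super().__init__(**kwargs)

cfg = DeepseekV32Config(vocab_size=129280)
assert cfg.max_position_embeddings == 163840
assert cfg.vocab_size == 129280
```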
```diff
@@ -775,10 +918,11 @@
         "beta_slow": 1.0,
         "factor": 32.0,
         "original_max_position_embeddings": 4096,
+        "rope_theta": 150_000,
         "rope_type": "yarn",
         "truncate": False,
     },
-    "rope_theta": 150000,
+    "rope_theta": 150_000,
     "router_aux_loss_coef": 0.9,
     "sliding_window": 128,
     "swiglu_limit": 7.0,
```
```diff
@@ -856,10 +1000,11 @@
         "beta_slow": 1.0,
         "factor": 32.0,
         "original_max_position_embeddings": 4096,
+        "rope_theta": 150_000,
         "rope_type": "yarn",
         "truncate": False,
     },
-    "rope_theta": 150000,
+    "rope_theta": 150_000,
     "router_aux_loss_coef": 0.9,
     "sliding_window": 128,
     "swiglu_limit": 7.0,
```
```diff
@@ -1006,6 +1151,8 @@
     "gemma3-4b": gemma3_4b_config,
     "gemma3-12b": gemma3_12b_config,
     "gemma3-27b": gemma3_27b_config,
+    "gemma4-26b": gemma4_26b_config,
+    "gemma4-31b": gemma4_31b_config,
     "qwen2.5-1.5b": qwen25_1_5b_config,
     "qwen2.5-7b": qwen25_7b_config,
     "qwen2.5-14b": qwen25_14b_config,
```
