
Commit 4a1da22

Merge pull request #3632 from AI-Hypercomputer:kimi-k2-release
PiperOrigin-RevId: 899195198
2 parents 5e96345 + fa5b5eb commit 4a1da22

2 files changed: 194 additions & 0 deletions

README.md (4 additions & 0 deletions)
@@ -41,6 +41,7 @@ See our guide on running MaxText in decoupled mode, without any GCP dependencies
## 🔥 Latest news 🔥

+* \[April 13, 2026\] Kimi-K2 is now supported, along with the MuonClip optimizer. Try the [kimi-k2-1t](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/kimi-k2-1t.yml) config and check the [user guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/kimi/Run_Kimi.md).
* \[April 10, 2026\] [DeepSeek-V3.2](https://arxiv.org/pdf/2512.02556) is now supported, featuring DeepSeek Sparse Attention for long context. Try it out with the [deepseek3.2-671b](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/deepseek3.2-671b.yml) config. See the [user guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md) for more details.
* \[April 2, 2026\] Gemma 4 multi-modal models (26B MoE, 31B dense) are now supported! Try them out with our [gemma4-26b](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/gemma4-26b.yml) and [gemma4-31b](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/gemma4-31b.yml) configs. For more details, see [Run_Gemma4.md](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/gemma4/Run_Gemma4.md).
* \[March 6, 2026\] New features from DeepSeek-AI are now supported: Conditional Memory via Scalable Lookup ([Engram](https://arxiv.org/abs/2601.07372)) and Manifold-Constrained Hyper-Connections ([mHC](https://arxiv.org/abs/2512.24880)). Try them out with our [deepseek-custom](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/deepseek-custom.yml) starter config.
@@ -120,9 +121,12 @@ MaxText aims to provide you with the best OSS models, whether as a reference imp
* Qwen 3 MoE (30B, 235B)
* Qwen 3 Dense (0.6B, 1.7B, 4B, 8B, 14B, 32B)
* DeepSeek
+  * DeepSeek V3.2 (671B)
  * DeepSeek V3.1 (671B)
  * DeepSeek V3 0324 (671B) & DeepSeek R1 0528 (671B)
  * DeepSeek V2 (16B, 236B)
+* Kimi
+  * Kimi K2
* Meta
  * Llama 4 Scout (109B) & Maverick (400B)
  * Llama 3.3 70B, 3.1 (8B, 70B, 405B), 3.0 (8B, 70B, 405B)
tests/end_to_end/tpu/kimi/Run_Kimi.md (new file, 190 additions & 0 deletions)

@@ -0,0 +1,190 @@
<!--
# Copyright 2023-2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->
# Kimi

Kimi is a family of high-performance, open-weights sparse MoE models by Moonshot AI, designed for agentic intelligence. The currently supported model is **Kimi K2 (1T)**.

* **[Kimi K2](https://arxiv.org/pdf/2507.20534)** features a massive 1.04 trillion total parameters, with 32 billion activated per token. The architecture is similar to DeepSeek-V3: it uses **Multi-Head Latent Attention (MLA)** and an ultra-sparse MoE with **384 experts**, optimized for long-context and agentic tasks.
* **MuonClip Optimizer**: Kimi K2 was trained using the token-efficient [Muon](https://kellerjordan.github.io/posts/muon) optimizer combined with a novel **QK-clip** technique to ensure training stability and eliminate loss spikes during large-scale pre-training (see the sketch after this list).
* **Agentic Excellence**: K2 is specifically post-trained using a large-scale agentic data-synthesis pipeline and Reinforcement Learning (RL), achieving state-of-the-art performance on benchmarks like Tau2-Bench and SWE-Bench.
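For intuition, here is a minimal sketch of the QK-clip idea described in the Kimi K2 report: after an optimizer step, any attention head whose maximum pre-softmax logit exceeded the threshold has its query and key projections shrunk so future logits fall back under it. This is an illustration only, not MaxText's `use_qk_clip` implementation; the function name and weight shapes are hypothetical.

```python
import jax.numpy as jnp

def qk_clip(w_q, w_k, max_logit_per_head, threshold=100.0):
  """Illustrative QK-clip: rescale per-head Q/K projections whenever a
  head's max attention logit exceeded `threshold` (hypothetical shapes:
  w_q, w_k are [num_heads, d_model, head_dim])."""
  # gamma < 1 only for offending heads; well-behaved heads are untouched.
  gamma = jnp.minimum(1.0, threshold / max_logit_per_head)
  # Split the correction between Q and K so their logit product scales by gamma.
  scale = jnp.sqrt(gamma)[:, None, None]
  return w_q * scale, w_k * scale
```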
## Checkpoint Conversion

1. To get started, download the model from HuggingFace: [moonshotai/Kimi-K2-Instruct](https://huggingface.co/moonshotai/Kimi-K2-Instruct). Weights are provided in FP8.

   ```sh
   hf download moonshotai/Kimi-K2-Instruct --local-dir $LOCAL_FP8_PATH
   ```
2. Convert the weights from FP8 to BF16 on CPU using the script [deepseek_fp8_to_bf16.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/deepseek_fp8_to_bf16.py) (a sketch of the underlying blockwise dequantization follows this list):

   ```sh
   python3 -m maxtext.checkpoint_conversion.standalone_scripts.deepseek_fp8_to_bf16 --input-fp8-hf-path=$LOCAL_FP8_PATH --output-bf16-hf-path=$LOCAL_BF16_PATH
   ```

   Alternatively, you can use the official DeepSeek script [fp8_cast_bf16.py](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py) to convert on GPU.
3. Convert the checkpoint to a MaxText-compatible [Orbax](https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html) format (see the scanned-vs-unscanned illustration after this list):

   - Run [convert_deepseek_family_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_ckpt.py) to convert the checkpoint to the scanned Orbax format used for training and fine-tuning.

     ```sh
     python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_deepseek_family_ckpt --model_size kimi-k2-1t --base_model_path $LOCAL_BF16_PATH --maxtext_model_path $GCS_PATH_TO_SAVE
     ```

   - Run [convert_deepseek_family_unscanned_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py) to convert the checkpoint to the unscanned Orbax format used for decoding.

     ```sh
     python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_deepseek_family_unscanned_ckpt --model_size kimi-k2-1t --base_model_path $LOCAL_BF16_PATH --maxtext_model_path $GCS_PATH_TO_SAVE
     ```
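The FP8-to-BF16 step is conceptually simple blockwise dequantization: DeepSeek-style FP8 checkpoints store one scale per 128x128 weight tile, and each tile is multiplied by its scale before casting to BF16. A rough sketch of the idea; the function and argument names are hypothetical, and only the tile layout is taken from the published DeepSeek-V3 format:

```python
import jax.numpy as jnp

def dequantize_blockwise(w_fp8, scale_inv, block=128):
  """Sketch of blockwise FP8 -> BF16 dequantization: every (block x block)
  tile of `w_fp8` shares a single entry of `scale_inv`."""
  rows, cols = w_fp8.shape
  # Broadcast each per-tile scale over its full tile, then trim to shape.
  scale = jnp.repeat(jnp.repeat(scale_inv, block, axis=0), block, axis=1)
  return (w_fp8.astype(jnp.float32) * scale[:rows, :cols]).astype(jnp.bfloat16)
```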
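The two Orbax layouts differ only in how layer weights are arranged: a scanned checkpoint stacks all decoder layers along a leading axis so training can step through them with `jax.lax.scan` (matching `scan_layers=True` below), while an unscanned checkpoint keeps one entry per layer (matching `scan_layers=False` at decode time). A toy illustration with hypothetical names and shapes, not MaxText's actual pytree layout:

```python
import jax.numpy as jnp

num_layers, d_model = 4, 8  # toy sizes

# Unscanned: one pytree entry per decoder layer (convenient for decoding).
unscanned = {f"decoder/layer_{i}/w": jnp.ones((d_model, d_model))
             for i in range(num_layers)}

# Scanned: a single stacked array with a leading layer axis that
# jax.lax.scan can iterate over during training.
scanned = {"decoder/layers/w": jnp.stack(
    [unscanned[f"decoder/layer_{i}/w"] for i in range(num_layers)])}

assert scanned["decoder/layers/w"].shape == (num_layers, d_model, d_model)
```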
## Pre-training

You can train from scratch to generate a new checkpoint. Here is an example command for pre-training Kimi K2 on tpu7x-512 (adjust the parallelism settings for the 1T-parameter scale):

```sh
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
run_name=kimi_k2_pre_training \
per_device_batch_size=16 \
enable_checkpointing=false \
model_name=kimi-k2-1t \
ici_fsdp_parallelism=64 \
ici_expert_parallelism=8 \
steps=5 \
max_target_length=1024 \
async_checkpointing=false \
tokenizer_type=huggingface \
tokenizer_path=moonshotai/Kimi-K2-Instruct \
attention=flash \
use_tokamax_splash=True \
use_tokamax_gmm=False \
dtype=bfloat16 \
weight_dtype=bfloat16 \
megablox=True \
sparse_matmul=True \
dataset_type=synthetic \
scan_layers=True \
use_ring_of_experts=True \
opt_type=muon \
muon_consistent_rms=0.2 \
muon_weight_decay=0.1 \
use_qk_clip=True \
qk_clip_threshold=100
```
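For intuition about `opt_type=muon`: Muon keeps SGD-style momentum but replaces each 2D weight's raw update with an approximately orthogonalized version computed by a short Newton-Schulz iteration. A sketch of that core step, with coefficients from Keller Jordan's public reference implementation; this is an illustration, not MaxText's optimizer code:

```python
import jax.numpy as jnp

def newton_schulz_orthogonalize(g, steps=5):
  """Approximately orthogonalize a 2D update matrix, as in Muon.
  Quintic coefficients follow the public reference implementation."""
  a, b, c = 3.4445, -4.7750, 2.0315
  x = g / (jnp.linalg.norm(g) + 1e-7)  # normalize so the iteration converges
  transpose = x.shape[0] > x.shape[1]
  if transpose:  # iterate on the wide orientation
    x = x.T
  for _ in range(steps):
    s = x @ x.T
    x = a * x + (b * s + c * (s @ s)) @ x
  return x.T if transpose else x
```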
## Fine-tuning

After you have a MaxText-compatible checkpoint, you can fine-tune Kimi K2. The Kimi team recommends using the **Muon optimizer** during fine-tuning, as it produces the best performance with a Muon-pre-trained checkpoint.

An example command for general fine-tuning on tpu7x-512:

```sh
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
run_name=kimi_k2_fine_tuning \
per_device_batch_size=16 \
enable_checkpointing=true \
model_name=kimi-k2-1t \
ici_fsdp_parallelism=64 \
ici_expert_parallelism=8 \
steps=5 \
max_target_length=1024 \
async_checkpointing=false \
tokenizer_type=huggingface \
tokenizer_path=moonshotai/Kimi-K2-Instruct \
attention=flash \
use_tokamax_splash=True \
use_tokamax_gmm=False \
dtype=bfloat16 \
weight_dtype=bfloat16 \
megablox=True \
sparse_matmul=True \
dataset_path=${DATASET_PATH?} \
scan_layers=True \
load_parameters_path=${SCANNED_CHECKPOINT?} \
use_ring_of_experts=True \
opt_type=muon \
muon_consistent_rms=0.2 \
muon_weight_decay=0.1 \
use_qk_clip=True \
qk_clip_threshold=100
```
## Decoding

Here is an example command to run decoding with Kimi K2. Given the model's 1T size, high tensor parallelism is recommended:

```sh
python3 -m maxtext.inference.decode src/maxtext/configs/base.yml \
base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
load_parameters_path=${CONVERTED_CHECKPOINT?} \
run_name=kimi_decode \
per_device_batch_size=1 \
model_name=kimi-k2-1t \
max_target_length=2048 \
tokenizer_type=huggingface \
tokenizer_path=moonshotai/Kimi-K2-Instruct \
attention=dot_product \
ici_tensor_parallelism=128 \
ici_fsdp_parallelism=1 \
prompt="The primary goal of agentic intelligence is to " \
scan_layers=False
```
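One practical note on slice sizing: MaxText expects the `ici_*` parallelism factors to multiply out to the number of chips in the slice. A quick sanity check against the example commands (assuming tpu7x-512 exposes 512 devices; the decode settings above would then need a 128-chip slice):

```python
from math import prod

def chips_required(*ici_factors):
  """Chips a MaxText command expects: the product of its ici_* factors."""
  return prod(ici_factors)

assert chips_required(64, 8) == 512    # pre-training: fsdp=64, expert=8
assert chips_required(128, 1) == 128   # decoding: tensor=128, fsdp=1
```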
## Correctness

To verify the correctness of the Kimi K2 implementation, we perform two primary validation steps:

* **Logit Comparison**: We compare the logits generated by our implementation against those from a HuggingFace implementation for a given set of prompts.
* **MMLU Score Validation**: We validate the MMLU score against established benchmarks.

### Logit Comparison

Use the following example to generate "golden" logits from the HuggingFace reference model for Kimi K2:

```sh
python3 -m tests.assets.logits_generation.generate_hf_golden_logits \
--model-id=moonshotai/Kimi-K2-Instruct \
--prompts='I love to' \
--output-path=golden_Kimi-K2.jsonl \
--gcs-bucket=${GCS_BUCKET?} \
--hf-model-path=$LOCAL_BF16_PATH \
--hf-load-dtype=bfloat16 \
--trust-remote-code=True
```
Then run MaxText's forward-pass logit checker against the golden logits:

```sh
JAX_PLATFORMS=cpu python3 -m tests.forward_pass_logit_checker \
src/maxtext/configs/base.yml \
base_output_directory=${BASE_OUTPUT_PATH?} \
run_name=forward_logits_check \
model_name=kimi-k2-1t \
load_parameters_path=${UNSCANNED_CKPT_PATH?} \
scan_layers=False \
async_checkpointing=False \
checkpoint_storage_concurrent_gb=1024 \
weight_dtype=bfloat16 \
ici_fsdp_parallelism=1 ici_expert_parallelism=-1 \
attention=dot_product \
per_device_batch_size=1 \
max_prefill_predict_length=4 max_target_length=4 \
sparse_matmul=False \
--golden_logits_path=${GOLDEN_LOGITS_DISK_LOCATION?} \
--atol=1.5 --rtol=1.5 --max_kl_div=0.1 \
--skip_first_token \
skip_jax_distributed_system=True
```
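Among the three tolerances, `--max_kl_div` bounds the per-position KL divergence between the golden and MaxText next-token distributions. A small sketch of that comparison (illustrative; the checker's actual internals may differ):

```python
import jax
import jax.numpy as jnp

def max_kl_divergence(golden_logits, test_logits):
  """Max over positions of KL(golden || test) between next-token
  distributions; the check passes if this stays below --max_kl_div."""
  log_p = jax.nn.log_softmax(golden_logits, axis=-1)
  log_q = jax.nn.log_softmax(test_logits, axis=-1)
  kl = jnp.sum(jnp.exp(log_p) * (log_p - log_q), axis=-1)
  return jnp.max(kl)
```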
To run MMLU benchmarks and validate the model's performance, follow the instructions provided [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/benchmarks/api_server/README.md).
## Supported MoE strategies

* Dropless
  * [MegaBlocks](https://arxiv.org/abs/2211.15841) implementation, enabled with the flags `sparse_matmul=True megablox=True`.
  * [JAX ragged_dot](https://github.com/jax-ml/jax/blob/a8fb0e01f8d083fff337d3c26375bb1b77344a99/jax/_src/lax/lax.py#L2415) implementation, enabled with the flags `sparse_matmul=True megablox=False`.
  * General dense matmul implementation, enabled with the flags `sparse_matmul=False capacity_factor=-1`.
* Dropping implementation, enabled with `sparse_matmul=False` and a reasonable `capacity_factor`, commonly between 1 and 1.25 (see the sketch below).
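For the dropping strategy, `capacity_factor` sets how many tokens each expert may accept before overflow tokens are dropped. A worked example of the usual capacity arithmetic (the formula is the standard one, shown here for illustration; 8 active experts out of 384 matches Kimi K2's routing):

```python
def expert_capacity(num_tokens, num_experts, experts_per_tok, capacity_factor):
  """Tokens each expert can accept before routing drops the overflow."""
  return int(num_tokens * experts_per_tok / num_experts * capacity_factor)

# 16384 tokens per batch, each routed to 8 of Kimi K2's 384 experts:
print(expert_capacity(16384, 384, 8, 1.0))    # 341 (perfectly balanced load)
print(expert_capacity(16384, 384, 8, 1.25))   # 426 (25% headroom)
```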
