|
| 1 | +<!-- |
| 2 | + # Copyright 2023-2026 Google LLC |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + --> |
| 16 | + |
| 17 | +# Kimi |
| 18 | + |
| 19 | +Kimi is a family of high-performance, open-weights sparse MoE models by Moonshot AI designed for agentic intelligence. The currently supported models are **Kimi K2 (1T)**. |
| 20 | + |
| 21 | +* **[Kimi K2](https://arxiv.org/pdf/2507.20534)** features a massive 1.04 trillion total parameters with 32 billion activated parameters. The architecture is similar to DeepSeek-V3. It utilizes **Multi-Head Latent Attention (MLA)** and an ultra-sparse MoE with **384 experts**, optimized for long-context and agentic tasks. |
| 22 | +* **MuonClip Optimizer**: Kimi K2 was trained using the token-efficient [Muon](https://kellerjordan.github.io/posts/muon) optimizer combined with a novel **QK-clip** technique to ensure training stability and eliminate loss spikes during large-scale pre-training. |
| 23 | +* **Agentic Excellence**: K2 is specifically post-trained using a large-scale agentic data synthesis pipeline and Reinforcement Learning (RL), achieving state-of-the-art performance on benchmarks like Tau2-Bench and SWE-Bench. |
| 24 | + |
| 25 | +## Checkpoint Conversion |
| 26 | +1. To get started, download the model from HuggingFace: [moonshotai/Kimi-K2-Instruct](https://huggingface.co/moonshotai/Kimi-K2-Instruct). Weights are provided in FP8. |
| 27 | + |
| 28 | +```sh |
| 29 | +hf download moonshotai/Kimi-K2-Instruct --local-dir $LOCAL_FP8_PATH |
| 30 | +``` |
| 31 | + |
| 32 | +2. Convert the weights from FP8 to BF16 using script [deepseek_fp8_to_bf16.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/deepseek_fp8_to_bf16.py) on CPU |
| 33 | +```sh |
| 34 | +python3 -m maxtext.checkpoint_conversion.standalone_scripts.deepseek_fp8_to_bf16 --input-fp8-hf-path=$LOCAL_FP8_PATH --output-bf16-hf-path=$LOCAL_BF16_PATH |
| 35 | +``` |
| 36 | +Alternatively, we can use the official DeepSeek script [fp8_cast_bf16.py](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py) to convert on GPU. |
| 37 | + |
| 38 | +3. To convert the checkpoint for MaxText compatibility in [Orbax](https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html) |
| 39 | +- Run [convert_deepseek_family_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_ckpt.py) to convert the checkpoint to scanned format in Orbax for training and fine-tuning. |
| 40 | +```sh |
| 41 | +python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_deepseek_family_ckpt --model_size kimi-k2-1t --base_model_path $LOCAL_BF16_PATH --maxtext_model_path $GCS_PATH_TO_SAVE |
| 42 | +``` |
| 43 | +- Run [convert_deepseek_family_unscanned_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py) to convert the checkpoint to unscanned format in Orbax for decoding. |
| 44 | +```sh |
| 45 | +python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_deepseek_family_unscanned_ckpt --model_size kimi-k2-1t --base_model_path $LOCAL_BF16_PATH --maxtext_model_path $GCS_PATH_TO_SAVE |
| 46 | +``` |
| 47 | + |
| 48 | +## Pre-training |
| 49 | +You can train from scratch to generate a new checkpoint. One example command to run pre-training with Kimi K2 on tpu7x-512 (adjust parallelism for the 1T parameter scale). |
| 50 | + |
| 51 | +```sh |
| 52 | +python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \ |
| 53 | + base_output_directory=${BASE_OUTPUT_DIRECTORY?} \ |
| 54 | + run_name=kimi_k2_pre_training \ |
| 55 | + per_device_batch_size=16 \ |
| 56 | + enable_checkpointing=false \ |
| 57 | + model_name=kimi-k2-1t \ |
| 58 | + ici_fsdp_parallelism=64 \ |
| 59 | + ici_expert_parallelism=8 \ |
| 60 | + steps=5 \ |
| 61 | + max_target_length=1024 \ |
| 62 | + async_checkpointing=false \ |
| 63 | + tokenizer_type=huggingface \ |
| 64 | + tokenizer_path=moonshotai/Kimi-K2-Instruct \ |
| 65 | + attention=flash \ |
| 66 | + use_tokamax_splash=True \ |
| 67 | + use_tokamax_gmm=False \ |
| 68 | + dtype=bfloat16 \ |
| 69 | + weight_dtype=bfloat16 \ |
| 70 | + megablox=True \ |
| 71 | + sparse_matmul=True \ |
| 72 | + dataset_type=synthetic \ |
| 73 | + scan_layers=True \ |
| 74 | + use_ring_of_experts=True \ |
| 75 | + opt_type=muon \ |
| 76 | + muon_consistent_rms=0.2 \ |
| 77 | + muon_weight_decay=0.1 \ |
| 78 | + use_qk_clip=True \ |
| 79 | + qk_clip_threshold=100 |
| 80 | +``` |
| 81 | + |
| 82 | +## Fine-tuning |
| 83 | +After you have a MaxText compatible checkpoint, you can fine-tune Kimi K2. The Kimi team recommends using the **Muon optimizer** during fine-tuning, as it produces the best performance with a Muon-pre-trained checkpoint. |
| 84 | + |
| 85 | +Example command for General Fine-tuning on tpu7x-512: |
| 86 | + |
| 87 | +```sh |
| 88 | +python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \ |
| 89 | + base_output_directory=${BASE_OUTPUT_DIRECTORY?} \ |
| 90 | + run_name=kimi_k2_fine_tuning \ |
| 91 | + per_device_batch_size=16 \ |
| 92 | + enable_checkpointing=true \ |
| 93 | + model_name=kimi-k2-1t \ |
| 94 | + ici_fsdp_parallelism=64 \ |
| 95 | + ici_expert_parallelism=8 \ |
| 96 | + steps=5 \ |
| 97 | + max_target_length=1024 \ |
| 98 | + async_checkpointing=false \ |
| 99 | + tokenizer_type=huggingface \ |
| 100 | + tokenizer_path=moonshotai/Kimi-K2-Instruct \ |
| 101 | + attention=flash \ |
| 102 | + use_tokamax_splash=True \ |
| 103 | + use_tokamax_gmm=False \ |
| 104 | + dtype=bfloat16 \ |
| 105 | + weight_dtype=bfloat16 \ |
| 106 | + megablox=True \ |
| 107 | + sparse_matmul=True \ |
| 108 | + dataset_path=${DATASET_PATH?} \ |
| 109 | + scan_layers=True \ |
| 110 | + load_parameters_path=${SCANNED_CHECKPOINT?} \ |
| 111 | + use_ring_of_experts=True \ |
| 112 | + opt_type=muon \ |
| 113 | + muon_consistent_rms=0.2 \ |
| 114 | + muon_weight_decay=0.1 \ |
| 115 | + use_qk_clip=True \ |
| 116 | + qk_clip_threshold=100 |
| 117 | +``` |
| 118 | + |
| 119 | +## Decoding |
| 120 | +Example command to run decoding with Kimi K2. Given its 1T size, high tensor parallelism is recommended. |
| 121 | + |
| 122 | +```sh |
| 123 | +python3 -m maxtext.inference.decode src/maxtext/configs/base.yml \ |
| 124 | + base_output_directory=${BASE_OUTPUT_DIRECTORY?} \ |
| 125 | + load_parameters_path=${CONVERTED_CHECKPOINT?} \ |
| 126 | + run_name=kimi_decode \ |
| 127 | + per_device_batch_size=1 \ |
| 128 | + model_name=kimi-k2-1t \ |
| 129 | + max_target_length=2048 \ |
| 130 | + tokenizer_type=huggingface \ |
| 131 | + tokenizer_path=moonshotai/Kimi-K2-Instruct \ |
| 132 | + attention=dot_product \ |
| 133 | + ici_tensor_parallelism=128 \ |
| 134 | + ici_fsdp_parallelism=1 \ |
| 135 | + prompt="The primary goal of agentic intelligence is to " \ |
| 136 | + scan_layers=False |
| 137 | +``` |
| 138 | + |
| 139 | +## Correctness |
| 140 | + |
| 141 | +To verify the correctness of the Kimi K2 implementation, we perform two primary validation steps: |
| 142 | + |
| 143 | + * **Logit Comparison**: We compare the logits generated by our implementation against those from a HuggingFace implementation for a set of given prompts. |
| 144 | + * **MMLU Score Validation**: We validate the MMLU score against established benchmarks. |
| 145 | + |
| 146 | +### Logit Comparison |
| 147 | + |
| 148 | +Use the following example to generate "golden" logits from the HuggingFace reference model for Kimi K2. |
| 149 | + |
| 150 | +```sh |
| 151 | +python3 -m tests.assets.logits_generation.generate_hf_golden_logits \ |
| 152 | + --model-id=moonshotai/Kimi-K2-Instruct \ |
| 153 | + --prompts='I love to' \ |
| 154 | + --output-path=golden_Kimi-K2.jsonl \ |
| 155 | + --gcs-bucket=$my-gcs-bucket \ |
| 156 | + --hf-model-path=$LOCAL_BF16_PATH \ |
| 157 | + --hf-load-dtype=bfloat16 \ |
| 158 | + --trust-remote-code=True |
| 159 | +``` |
| 160 | + |
| 161 | +```sh |
| 162 | +JAX_PLATFORMS=cpu python3 -m tests.forward_pass_logit_checker \ |
| 163 | + src/maxtext/configs/base.yml \ |
| 164 | + base_output_directory=${BASE_OUTPUT_PATH?} \ |
| 165 | + run_name=forward_logits_check \ |
| 166 | + model_name=kimi-k2-1t \ |
| 167 | + load_parameters_path=${UNSCANNED_CKPT_PATH?} \ |
| 168 | + scan_layers=False \ |
| 169 | + async_checkpointing=False \ |
| 170 | + checkpoint_storage_concurrent_gb=1024 \ |
| 171 | + weight_dtype=bfloat16 \ |
| 172 | + ici_fsdp_parallelism=1 ici_expert_parallelism=-1 \ |
| 173 | + attention=dot_product \ |
| 174 | + per_device_batch_size=1 \ |
| 175 | + max_prefill_predict_length=4 max_target_length=4 \ |
| 176 | + sparse_matmul=False \ |
| 177 | + --golden_logits_path=${GOLDEN_LOGITS_DISK_LOCATION?} \ |
| 178 | + --atol=1.5 --rtol=1.5 --max_kl_div=0.1 \ |
| 179 | + --skip_first_token \ |
| 180 | + skip_jax_distributed_system=True |
| 181 | +``` |
| 182 | + |
| 183 | +To run MMLU benchmarks and validate the model's performance, follow the instructions provided [here]( https://github.com/AI-Hypercomputer/maxtext/blob/main/benchmarks/api_server/README.md). |
| 184 | + |
| 185 | +## Supported MoE strategy |
| 186 | +* Dropless |
| 187 | + * [MegaBlocks](https://arxiv.org/abs/2211.15841) implementation with flag `sparse_matmul=True megablox=True`. |
| 188 | + * [JAX ragged_dot](https://github.com/jax-ml/jax/blob/a8fb0e01f8d083fff337d3c26375bb1b77344a99/jax/_src/lax/lax.py#L2415) implementation with flag `sparse_matmul=True megablox=False`. |
| 189 | + * General dense matmul implementation with flag `sparse_matmul=False capacity_factor=-1`. |
| 190 | +* Dropping implementation with flag `sparse_matmul=False` and reasonable `capacity_factor`, commonly used from 1 to 1.25. |
0 commit comments