Commit cd7a1eb

Merge pull request #3184 from AI-Hypercomputer:shuningjin-ckpt-opt3
PiperOrigin-RevId: 890630208
2 parents: 61dc465 + b1a5feb

14 files changed: 521 additions & 198 deletions


docs/guides/checkpointing_solutions/convert_checkpoint.md

Lines changed: 9 additions & 5 deletions
@@ -16,7 +16,9 @@ The following models are supported:
 | **Qwen3 MoE** | 30B, 235B, 480B | ✅ | ✅ | ✅ | ✅ |
 | **Mixtral** | 8x7B, 8x22B | ✅ | ✅ | ✅ | ✅ |
 | **GPT-OSS** | 20B, 120B | ✅ | ✅ | ✅ | ✅ |
-| **DeepSeek3** | 671B | - | - | ✅ | - |
+| **DeepSeek2** | 16B | ✅ | ✅ | ✅ | ✅ |
+| **DeepSeek3** | 671B | ✅ | ✅ | ✅ | ✅ |
+| **DeepSeek3.2** | 671B | ✅ | ✅ | - | - |
 | **Qwen3 Next** | 80B | ✅ | ✅ | ✅ | ✅ |

 ## Prerequisites
@@ -60,7 +62,8 @@ python3 -m maxtext.checkpoint_conversion.to_maxtext \
     skip_jax_distributed_system=true \
     checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) \
     checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) \
-    --lazy_load_tensors=${LAZY_LOAD_TENSORS?}
+    --lazy_load_tensors=${LAZY_LOAD_TENSORS?} \
+    --save_dtype=bfloat16
 ```

 You can find your converted checkpoint files under `${BASE_OUTPUT_DIRECTORY}/0/items`.
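The new `--save_dtype=bfloat16` flag halves the size of saved weights relative to `float32`. As a rough illustration only (this is not maxtext's actual implementation), the effect amounts to casting each tensor once before serialization:

```python
# Illustrative sketch, not the repository's code: cast every tensor to
# bfloat16 before saving, halving memory/disk use vs. float32.
import numpy as np
import ml_dtypes  # provides the bfloat16 NumPy extension type

def cast_for_saving(params: dict, save_dtype: str = "bfloat16") -> dict:
  """Returns a copy of `params` with each array cast to `save_dtype`."""
  target = ml_dtypes.bfloat16 if save_dtype == "bfloat16" else np.dtype(save_dtype)
  return {name: np.asarray(arr).astype(target) for name, arr in params.items()}

weights = {"embedding": np.ones((4, 8), dtype=np.float32)}
print(cast_for_saving(weights)["embedding"].dtype)  # bfloat16
```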
@@ -74,7 +77,8 @@ You can find your converted checkpoint files under `${BASE_OUTPUT_DIRECTORY}/0/i
 - `hardware=cpu`: The conversion script runs on a CPU machine.
 - `checkpoint_storage_use_zarr3` and `checkpoint_storage_use_ocdbt`: These storage flags enable McJAX compatibility when set to True (the default). For Pathways, these should be False.
 - `--lazy_load_tensors` (Optional): Enables on-demand loading of weights to prevent OOM (Out of Memory) errors. Highly recommended for large models to reduce memory usage during conversion. For example, converting a Llama3.1-70B model with `--lazy_load_tensors=true` uses around 200GB of RAM and completes in ~10 minutes.
-- `--hf_model_path` (Optional): Specifies a local or remote directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/utils.py#L59-L91) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
+- `--hf_model_path` (Optional): Specifies a custom remote or local directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/globals.py) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
+- `--save_dtype` (Optional): Specifies the data type of the saved model weights. Defaults to `bfloat16` to save memory.

 ## MaxText to Hugging Face

@@ -118,7 +122,7 @@ python3 -m maxtext.checkpoint_conversion.to_huggingface \
 - `use_multimodal`: Indicates if multimodality is used, important for Gemma3.
 - `hardware=cpu`: The conversion script runs on a CPU machine.
 - `base_output_directory`: The path where the converted checkpoint will be stored; it can be Google Cloud Storage (GCS), Hugging Face Hub or local.
-- `weight_dtype`: dtype for MaxText weights. It affects the resulting Hugging Face weight dtype. Default value is `float32`. We recommend using `bfloat16` to save memory and speed up conversion.
+- `weight_dtype`: It affects the resulting Hugging Face weight dtype. Default value is `float32`. We recommend using `bfloat16` to save memory and speed up conversion.

 ## Verifying conversion correctness

@@ -226,7 +230,7 @@ To extend conversion support to a new model architecture, you must define its sp

 - In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/param_mapping.py), add the `hook_fn` logic (`def {MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN`). This is the transformation needed per layer.

-2. **Add Hugging Face weights Shape**: In [`utils/hf_shape.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/hf_shape.py), define the tensor shape of Hugging Face format (`def {MODEL}_HF_WEIGHTS_TO_SHAPE`). This is used to ensure the tensor shape is matched after to_huggingface conversion.
+2. **Add Hugging Face weights shape**: In [`utils/hf_shape.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/hf_shape.py), define the tensor shape of the Hugging Face format (`def {MODEL}_HF_WEIGHTS_TO_SHAPE`). This is used to ensure the tensor shape matches after the to_huggingface conversion.

 3. **Register model key**: In [`utils/utils.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/globals.py), add the new model key in `HF_IDS`.
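These three steps are the whole registration surface. As a hedged sketch of the per-layer `hook_fn` idea (the name, parameter key, and signature below are illustrative, not the real maxtext API):

```python
# Hypothetical hook_fn map for a model "EXAMPLE"; the real functions live in
# utils/param_mapping.py and their exact signatures may differ.
import numpy as np

def EXAMPLE_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, scan_layers=False):
  """Maps parameter names to the transform applied during MaxText -> HF export."""

  def transpose_kernel(x):
    # MaxText stores dense kernels as (in_features, out_features);
    # Hugging Face Linear layers expect (out_features, in_features).
    return np.transpose(x)

  return {"mlp.wi.kernel": transpose_kernel}
```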

src/maxtext/checkpoint_conversion/compare_hf_ckpt.py

Lines changed: 4 additions & 8 deletions
@@ -48,7 +48,7 @@
 from safetensors import safe_open

 from maxtext.configs import pyconfig
-from maxtext.checkpoint_conversion.utils.utils import print_ram_usage, get_hf_model
+from maxtext.checkpoint_conversion.utils.utils import print_ram_usage, load_hf_dict_from_transformers
 from maxtext.utils import max_logging
 from maxtext.utils.globals import HF_IDS

@@ -135,8 +135,7 @@ def get_hf_model_state_dict(model_id: str, token: str) -> Dict[str, np.ndarray]:
   """Loads the HuggingFace model state dict and converts to numpy."""
   max_logging.log(f"Loading reference model from HuggingFace: {model_id}...")

-  hf_model = get_hf_model(model_id, token)
-  state_dict = hf_model.state_dict()
+  state_dict = load_hf_dict_from_transformers(model_id, token)
   numpy_state_dict = {k: v.numpy() for k, v in state_dict.items()}

   return numpy_state_dict
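A minimal sketch of what the new helper could do, assuming it simply wraps `transformers` model loading — the real `load_hf_dict_from_transformers` in `utils/utils.py` may well load weights differently (e.g., straight from safetensors to avoid instantiating the model):

```python
# Hypothetical stand-in for load_hf_dict_from_transformers; assumption only.
from transformers import AutoModelForCausalLM

def load_hf_dict_from_transformers(model_id: str, token: str):
  """Loads a Hugging Face model and returns only its state dict."""
  model = AutoModelForCausalLM.from_pretrained(model_id, token=token)
  return model.state_dict()
```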
@@ -261,12 +260,9 @@ def main(args: Sequence[str], test_args: argparse.Namespace) -> None:
       help="Absolute tolerance for numpy.allclose",
   )

-  local_args, _ = parser.parse_known_args()
   logging.set_verbosity(logging.INFO)

-  # Filter args for MaxText config parsing
-  model_args = sys.argv
-  to_remove_args = ["--candidate_path", "--reference_path", "--max_workers", "--rtol", "--atol"]
-  model_args = [s for s in model_args if not any(s.startswith(a) for a in to_remove_args)]
+  local_args, remaining_args = parser.parse_known_args()
+  model_args = [sys.argv[0]] + remaining_args

   main(model_args, local_args)
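The rewrite above replaces a hand-maintained flag blacklist with `argparse.parse_known_args`, which already returns the argv entries it did not consume, so new script flags are forwarded to the MaxText config parser without updating a filter list. A self-contained demonstration (the flag values are illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--rtol", type=float, default=1e-5)

# parse_known_args splits argv into (parsed flags, everything unconsumed).
local_args, remaining_args = parser.parse_known_args(
    ["--rtol", "0.01", "model_name=llama3.1-8b", "scan_layers=true"]
)
print(local_args.rtol)   # 0.01
print(remaining_args)    # ['model_name=llama3.1-8b', 'scan_layers=true']
```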

src/maxtext/checkpoint_conversion/standalone_scripts/llama_or_mistral_ckpt.py

Lines changed: 5 additions & 4 deletions
@@ -1649,7 +1649,8 @@ def shard_checkpoint(jax_weights, device_count, mem_info):
         "WARNING: hardware/simulated device mismatch. "
         f"Actual JAX devices: {len(jax.devices())}, Requested count: {device_count}."
     )
-  max_logging.log(f"shard weights across {len(jax.devices())} devices")
+  max_logging.log(f"Shard weights across {len(jax.devices())} devices")
+  max_logging.log("Note: Axis 0 sharding is the default and will not be logged individually.")
   # Pre-define sharding specs
   mesh = jax.sharding.Mesh(jax.devices(), "checkpoint_sharding_axis")
   # Sharding along axis 0
@@ -1673,13 +1674,13 @@ def checkpoint_device_put(arr):
     arr = np.array(arr)

     if arr.shape[0] % device_count == 0:
-      max_logging.log("sharding axis 0")
+      # Sharding axis 0: omit the log for brevity, per the summary log above.
       return jax.device_put(arr, device=s1)
     elif len(arr.shape) > 1 and arr.shape[1] % device_count == 0:
-      max_logging.log("sharding axis 1")
+      max_logging.log(f"Sharding axis 1. Tensor shape {arr.shape}")
       return jax.device_put(arr, device=s2)
     else:
-      max_logging.log("no sharding was possible, replicating")
+      max_logging.log(f"Not sharding. Tensor shape {arr.shape}")
       return jax.device_put(arr, device=s3)

   # Weight sharding
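The helper picks the first axis whose size divides the device count evenly and replicates otherwise. A minimal runnable sketch of that rule, assuming a 1-D mesh over all local devices; the `s1`/`s2`/`s3` definitions are reconstructed here, not copied from the file:

```python
# Sketch of the axis-selection rule above; sharding specs are assumptions
# derived from the "checkpoint_sharding_axis" mesh created earlier.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("checkpoint_sharding_axis",))
s1 = NamedSharding(mesh, P("checkpoint_sharding_axis"))        # shard axis 0
s2 = NamedSharding(mesh, P(None, "checkpoint_sharding_axis"))  # shard axis 1
s3 = NamedSharding(mesh, P())                                  # replicate

def choose_sharding(arr: np.ndarray, device_count: int):
  if arr.shape[0] % device_count == 0:
    return s1  # rows divide evenly: shard axis 0 (the quiet default)
  if arr.ndim > 1 and arr.shape[1] % device_count == 0:
    return s2  # columns divide evenly: shard axis 1
  return s3    # nothing divides evenly: replicate across devices
```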

src/maxtext/checkpoint_conversion/to_huggingface.py

Lines changed: 19 additions & 1 deletion
@@ -28,6 +28,9 @@
       Defaults to "./mt_output/".
   scan_layers: (bool) Whether the MaxText model was trained with scanned layers.
       This must match the training configuration of the checkpoint.
+  weight_dtype: (Optional) It affects the resulting Hugging Face weight dtype.
+      Default value is `float32`. We recommend using `bfloat16`
+      to save memory and speed up conversion.

 Optional Flags:
   --override_model_architecture: If set, overrides the HF model configuration
@@ -139,13 +142,25 @@ def _validate_or_update_architecture(hf_config, max_config, override: bool):
   attributes_to_check = [
       ("num_attention_heads", "num_query_heads"),
       ("num_key_value_heads", "num_kv_heads"),
-      ("head_dim", "head_dim"),
       ("hidden_size", "emb_dim"),
       ("intermediate_size", "mlp_dim"),
       ("num_hidden_layers", "num_decoder_layers"),
       ("vocab_size", "vocab_size"),
   ]

+  if max_config.attention_type == "mla":
+    attributes_to_check.extend(
+        [
+            ("qk_nope_head_dim", "qk_nope_head_dim"),
+            ("qk_rope_head_dim", "qk_rope_head_dim"),
+            ("v_head_dim", "v_head_dim"),
+            ("kv_lora_rank", "kv_lora_rank"),
+            ("q_lora_rank", "q_lora_rank"),
+        ]
+    )
+  else:
+    attributes_to_check.append(("head_dim", "head_dim"))
+
   mismatches = []

   for hf_attr, mt_attr in attributes_to_check:
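The MLA branch exists because multi-head latent attention (as used by DeepSeek-V2/V3) replaces the single `head_dim` with decomposed RoPE/no-RoPE head dimensions plus low-rank KV/Q projections, so those five fields must be compared instead. A toy check in the same spirit — the config objects here are stand-ins, not maxtext's `pyconfig`:

```python
# Toy mismatch check mirroring the loop in _validate_or_update_architecture;
# SimpleNamespace stands in for the real HF/MaxText config objects.
from types import SimpleNamespace

hf_config = SimpleNamespace(qk_nope_head_dim=128, qk_rope_head_dim=64)
max_config = SimpleNamespace(attention_type="mla", qk_nope_head_dim=128, qk_rope_head_dim=64)

attributes_to_check = []
if max_config.attention_type == "mla":
  attributes_to_check += [("qk_nope_head_dim", "qk_nope_head_dim"),
                          ("qk_rope_head_dim", "qk_rope_head_dim")]
else:
  attributes_to_check.append(("head_dim", "head_dim"))

mismatches = [(hf, mt) for hf, mt in attributes_to_check
              if getattr(hf_config, hf) != getattr(max_config, mt)]
print(mismatches)  # [] means the architectures agree on these fields
```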
@@ -215,6 +230,7 @@ def main(argv: Sequence[str]) -> None:
   checkpoint_dict = load_orbax_checkpoint(config)
   max_logging.log(f"Elapse for checkpoint load: {(time.time() - start) / 60:.2f} min")

+  # Define output directory
   if not config.base_output_directory:
     output_directory = f"tmp/{config.run_name}"
   else:
@@ -269,6 +285,8 @@ def main(argv: Sequence[str]) -> None:
     processed_params = process_maxtext_param(key, weight, param_map, hook_fn_map, shape_map, config)
     processed_params_list.extend(processed_params)

+  max_logging.log(f"Weight dtype after transform: {type(processed_params[0][1].dtype)}")
+
   transformed_hf_weights = dict(processed_params_list)
   max_logging.log(f"Elapse for transform: {(time.time() - start) / 60:.2f} min")
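One nit worth knowing when reading the new log line: `type(x.dtype)` prints the dtype's class rather than the dtype name. A quick comparison (NumPy used for illustration):

```python
import numpy as np

w = np.zeros((2, 2), dtype=np.float32)
print(type(w.dtype))  # e.g. <class 'numpy.dtype[float32]'> — the dtype's class
print(w.dtype)        # float32 — the value most logs actually want
```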
