Commit f82cf83

Merge pull request #3635 from AI-Hypercomputer:fix_docs_2
PiperOrigin-RevId: 897809712
2 parents: b7dbba7 + a74b875

6 files changed: 90 additions & 90 deletions

File tree

docs/guides/checkpointing_solutions/convert_checkpoint.md

Lines changed: 14 additions & 14 deletions
````diff
@@ -6,20 +6,20 @@ This guide provides instructions to use [checkpoint conversion scripts](https://
 
 The following models are supported:
 
-| Model Family | Sizes | HF $\\to$ Orbax (scan) | HF $\\to$ Orbax (unscan) | Orbax (scan) $\\to$ HF | Orbax (unscan) $\\to$ HF |
-| :---------------------- | :--------------------- | :--------------------: | :----------------------: | :--------------------: | :----------------------: |
-| **Gemma2** | 2B, 9B, 27B | | | | |
-| **Gemma3** (Multimodal) | 4B, 12B, 27B | | | | |
-| **Llama3.1** | 8B, 70B, 405B | | | | |
-| **Qwen2.5** | 1.5B, 7B, 14B | | | | |
-| **Qwen3** | 0.6B, 4B, 8B, 14B, 32B | | | | |
-| **Qwen3 MoE** | 30B, 235B, 480B | | | | |
-| **Mixtral** | 8x7B, 8x22B | | | | |
-| **GPT-OSS** | 20B, 120B | | | | |
-| **DeepSeek2** | 16B | | | | |
-| **DeepSeek3** | 671B | | | | |
-| **DeepSeek3.2** | 671B | | | - | - |
-| **Qwen3 Next** | 80B | | | | |
+| Model Family | Sizes | HF $\to$ Orbax (scan) | HF $\to$ Orbax (unscan) | Orbax (scan) $\to$ HF | Orbax (unscan) $\to$ HF |
+| :---------------------- | :--------------------- | :-------------------: | :---------------------: | :-------------------: | :---------------------: |
+| **Gemma2** | 2B, 9B, 27B |||||
+| **Gemma3** (Multimodal) | 4B, 12B, 27B |||||
+| **Llama3.1** | 8B, 70B, 405B |||||
+| **Qwen2.5** | 1.5B, 7B, 14B |||||
+| **Qwen3** | 0.6B, 4B, 8B, 14B, 32B |||||
+| **Qwen3 MoE** | 30B, 235B, 480B |||||
+| **Mixtral** | 8x7B, 8x22B |||||
+| **GPT-OSS** | 20B, 120B |||||
+| **DeepSeek2** | 16B |||||
+| **DeepSeek3** | 671B |||||
+| **DeepSeek3.2** | 671B ||| - | - |
+| **Qwen3 Next** | 80B |||||
 
 ## Prerequisites
 
````
docs/guides/monitoring_and_debugging/understand_logs_and_metrics.md

Lines changed: 12 additions & 12 deletions
````diff
@@ -197,7 +197,7 @@ The **model FLOPs** are the floating point operations to perform model computati
 - The number of model FLOPs is dependent on model architecture, input size (batch size, sequence length), and gradient accumulation steps. It does not include optimization operations.
 - We break down the FLOPs into two parts:
 - "Learnable weight FLOPs" are matmuls between activations and learnable weights. Specifically, this occurs in embedding, feed forward networks, attention-related projections, and unembedding.
-- "Attention FLOPs" are matmuls in attention score computation like $\\mathrm{softmax}{\\left(\\frac{QK^\\top}{\\sqrt{d}}\\right)} V$.
+- "Attention FLOPs" are matmuls in attention score computation like $\mathrm{softmax}{\left(\frac{QK^\top}{\sqrt{d}}\right)} V$.
 
 One **TFLOP** (TeraFLOP) is equal to $10^{12}$ FLOPs. The log shows the theoretical estimate of **model TFLOP per device**:
 
````
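The attention score computation discussed in this hunk contributes two matmuls per head. A minimal sketch of counting their FLOPs (the shapes and the 2-FLOPs-per-multiply-add convention are illustrative assumptions, not the MaxText implementation, and this covers the forward pass only):

```python
def attention_matmul_flops(batch, heads, seq_len, head_dim):
    """Rough forward-pass FLOPs for softmax(Q K^T / sqrt(d)) V:
    two matmuls per head, counting 2 FLOPs per multiply-add."""
    qk = 2 * batch * heads * seq_len * seq_len * head_dim  # Q @ K^T
    av = 2 * batch * heads * seq_len * seq_len * head_dim  # scores @ V
    return qk + av

# Toy shapes: batch 1, 8 heads, sequence 2048, head dim 128
print(attention_matmul_flops(1, 8, 2048, 128) / 1e12, "TFLOPs")
```

Note the quadratic dependence on sequence length, which is why the attention share of total FLOPs grows as `max_target_length` increases.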

````diff
@@ -207,7 +207,7 @@ Per train step:
 split as 94.54% learnable weight flops and 5.46% attention flops
 ```
 
-In this example, given `model=deepseek2-16b`, `per_device_batch_size=24`, `max_target_length=2048`, and no gradient accumulation, we have $\\text{model tflop per device} \\approx 764.67$.
+In this example, given `model=deepseek2-16b`, `per_device_batch_size=24`, `max_target_length=2048`, and no gradient accumulation, we have $\text{model tflop per device} \approx 764.67$.
 
 - 94.54% of the TFLOPs are attributed to learnable weight and 5.46% are attributed to attention.
 - As you will see next, this number is important for calculating performance metrics, such as TFLOP/s/device and Model FLOPs Utilization (MFU).
````
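For intuition, the logged 764.67 TFLOP per device can be loosely cross-checked against the common "6 × parameters × tokens" rule of thumb for training FLOPs. Both the heuristic and the ~2.4B activated-parameter count for deepseek2-16b are assumptions for this sketch, not MaxText's exact accounting:

```python
# Rule-of-thumb training FLOPs for weight matmuls: ~6 per parameter per token
# (2 forward + 4 backward). Ignores attention-score FLOPs entirely.
activated_params = 2.4e9        # assumed activated params for deepseek2-16b (MoE)
tokens_per_device = 24 * 2048   # per_device_batch_size * max_target_length

weight_tflops = 6 * activated_params * tokens_per_device / 1e12
print(round(weight_tflops, 1))  # same ballpark as the logged 764.67
```

The estimate lands in the same ballpark; the gap is expected since the heuristic omits attention FLOPs and approximates the parameter count.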
````diff
@@ -233,8 +233,8 @@ completed step: 9, seconds: 5.667, TFLOP/s/device: 134.924, Tokens/s/device: 867
 
 Before we dive deep here, recall a few numbers from previous sections:
 
-- $\\text{max target length} = 2048$, $\\text{per device batch size} = 24$
-- $\\text{model tflop per device} \\approx 764.67$ (rounded), $\\text{number of devices} = 4$
+- $\text{max target length} = 2048$, $\text{per device batch size} = 24$
+- $\text{model tflop per device} \approx 764.67$ (rounded), $\text{number of devices} = 4$
 
 ### 4.1. Performance metrics
 
````
````diff
@@ -244,38 +244,38 @@ The performance metrics fluctuate at the beginning, and become stable towards th
 completed step: 9, seconds: 5.667, TFLOP/s/device: 134.924, Tokens/s/device: 8672.758, total_weights: 196608, loss: 10.374
 ```
 
-As shown in `seconds: 5.667`, $\\text{measured step time in seconds} \\approx 5.667$ (rounded).
+As shown in `seconds: 5.667`, $\text{measured step time in seconds} \approx 5.667$ (rounded).
 
 **TFLOP per second per device**
 
 - It is [computed](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/metric_logger.py#L211-L213) as
 
-$$\\text{tflop/s/device} = \\frac{\\text{model tflop per device}}{\\text{measured step time in seconds}}$$
+$$\text{tflop/s/device} = \frac{\text{model tflop per device}}{\text{measured step time in seconds}}$$
 
 - Here we have `TFLOP/s/device: 134.924`. Let's try to verify manually: $764.67 / 5.667 = 134.934$. Not exactly the same but close, since both the tflop and time are rounded in the log.
 - Further, we can calculate **Model FLOPs Utilization (MFU)** from this:
 
-$$\\text{MFU} = \\frac{\\text{tflop/s/device}}{\\text{peak hardware tflop/s}}$$
+$$\text{MFU} = \frac{\text{tflop/s/device}}{\text{peak hardware tflop/s}}$$
 
-For TPU v5p, $\\text{peak hardware tflop/s}=459$. Thus, $134.924 / 459 = 29.40$%. Note this is an example for explanation with small batch size and sequence length, so the MFU is not optimal.
+For TPU v5p, $\text{peak hardware tflop/s}=459$. Thus, $134.924 / 459 = 29.40$%. Note this is an example for explanation with small batch size and sequence length, so the MFU is not optimal.
 
 **Tokens per second per device (throughput)**
 
 - It is [computed](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/metric_logger.py#L215-L217) as
 
-$$\\text{token/s/device} = \\frac{\\text{number of tokens per device}}{\\text{measured step time in seconds}}$$
+$$\text{token/s/device} = \frac{\text{number of tokens per device}}{\text{measured step time in seconds}}$$
 
 - The numerator is from [calculate_tokens_training_per_device](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/maxtext_utils.py#L148)
 
-$$\\text{number of tokens per device} = \\text{per device batch size} \\times \\text{max target length}$$
+$$\text{number of tokens per device} = \text{per device batch size} \times \text{max target length}$$
 
-- Here we have `Tokens/s/device: 8672.758`. Let's try to verify manually: $24 \\times 2048 / 5.667 = 8673.372$. Not exactly the same but close, since the time is rounded in the log.
+- Here we have `Tokens/s/device: 8672.758`. Let's try to verify manually: $24 \times 2048 / 5.667 = 8673.372$. Not exactly the same but close, since the time is rounded in the log.
 
 ### 4.2. Learning metrics
 
 **Loss**. The loss is the key indicator of learning progress, which should decrease over training steps. In this example, the loss is `12.038` at Step 0 and decreases to `10.374` at Step 9. Ideally, we want the loss to converge to a small value with sufficiently large training steps.
 
-**Total weights**. When discussing the throughput, we have $\\text{number of tokens} = \\text{per device batch size} \\times \\text{max target length} \\times \\text{number of device}$. In this example, $\\text{number of tokens} = 24 \\times 2048 \\times 4 = 196608$. There are two types of tokens: real tokens and pad tokens. The pad tokens are placeholders introduced by data preprocessing: We truncate or pad each sentence to max target length. Only real tokens contribute to the learning signal (i.e., loss). Therefore, we monitor $\\text{number of real tokens}$, which is shown as [total weights](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/train.py#L151).
+**Total weights**. When discussing the throughput, we have $\text{number of tokens} = \text{per device batch size} \times \text{max target length} \times \text{number of device}$. In this example, $\text{number of tokens} = 24 \times 2048 \times 4 = 196608$. There are two types of tokens: real tokens and pad tokens. The pad tokens are placeholders introduced by data preprocessing: We truncate or pad each sentence to max target length. Only real tokens contribute to the learning signal (i.e., loss). Therefore, we monitor $\text{number of real tokens}$, which is shown as [total weights](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/train.py#L151).
 
 - Here we see `total_weights: 196608` for all steps. This is because we are using `dataset_type=synthetic`, where all sentences are generated with a length of `max_target_length=2048`. As a result, there are no pad tokens and total weights = number of tokens.
 - However, in real datasets, sentences can have variable lengths and total weights < number of tokens. For example, we can set `dataset_type=tfds dataset_path=gs://maxtext-dataset dataset_name='c4/en:3.0.1'`, and will see total weights smaller than `196608`:
````
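The arithmetic in the metrics walkthrough above can be reproduced directly from the logged values (all numbers are copied from the example log and config; small deviations come from rounding in the log itself):

```python
# Values from the example log and config
model_tflop_per_device = 764.67  # theoretical model TFLOP per device
step_time_s = 5.667              # "seconds: 5.667"
per_device_batch_size = 24
max_target_length = 2048
num_devices = 4
peak_tflops_v5p = 459            # TPU v5p peak TFLOP/s, per the doc

tflops_per_device = model_tflop_per_device / step_time_s                       # ~134.934
mfu_percent = 100 * tflops_per_device / peak_tflops_v5p                        # ~29.4
tokens_per_s_device = per_device_batch_size * max_target_length / step_time_s  # ~8673.37
total_tokens = per_device_batch_size * max_target_length * num_devices         # 196608

print(round(tflops_per_device, 3), round(mfu_percent, 2),
      round(tokens_per_s_device, 3), total_tokens)
```

This matches the logged `TFLOP/s/device: 134.924` and `Tokens/s/device: 8672.758` to within rounding error, and `total_weights: 196608` exactly for the synthetic dataset.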
