The **model FLOPs** are the floating point operations to perform model computation.
- The number of model FLOPs depends on the model architecture, input size (batch size, sequence length), and gradient accumulation steps. It does not include optimization operations.
- We break down the FLOPs into two parts:
  - "Learnable weight FLOPs" are matmuls between activations and learnable weights. Specifically, these occur in the embedding, feed forward networks, attention-related projections, and unembedding.
  - "Attention FLOPs" are matmuls in the attention score computation $\mathrm{softmax}{\left(\frac{QK^\top}{\sqrt{d}}\right)} V$, i.e., the $QK^\top$ product and the multiplication with $V$.
One **TFLOP** (TeraFLOP) is equal to $10^{12}$ FLOPs. The log shows the theoretical estimate of **model TFLOP per device**:
```
Per train step:
...
split as 94.54% learnable weight flops and 5.46% attention flops
```
In this example, given `model=deepseek2-16b`, `per_device_batch_size=24`, `max_target_length=2048`, and no gradient accumulation, we have $\text{model tflop per device} \approx 764.67$.
- 94.54% of the TFLOPs are attributed to learnable weights and 5.46% are attributed to attention.
- As you will see next, this number is important for calculating performance metrics, such as TFLOP/s/device and Model FLOPs Utilization (MFU).
As shown in the per-step log (`seconds: 5.667`), $\text{measured step time in seconds} \approx 5.667$ (rounded).
**TFLOP per second per device**
- It is [computed](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/metric_logger.py#L211-L213) as
$$\text{tflop/s/device} = \frac{\text{model tflop per device}}{\text{measured step time in seconds}}$$
- Here we have `TFLOP/s/device: 134.924`. Let's verify manually: $764.67 / 5.667 = 134.934$. Not exactly the same, but close, since both the TFLOP count and the step time are rounded in the log.
- Further, we can calculate **Model FLOPs Utilization (MFU)** from this:

$$\text{MFU} = \frac{\text{tflop/s/device}}{\text{peak hardware tflop/s}}$$
For TPU v5p, $\text{peak hardware tflop/s} = 459$. Thus, $134.924 / 459 \approx 29.40\%$. Note this is an example for explanation with a small batch size and sequence length, so the MFU is not optimal.
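As a quick sanity check, both performance metrics can be reproduced in a few lines of Python from the values quoted in this example:

```python
# Recompute TFLOP/s/device and MFU from the logged example values.
model_tflop_per_device = 764.67  # theoretical model TFLOPs per device per step
step_time_seconds = 5.667        # measured step time from the log (rounded)
peak_tflops_v5p = 459            # TPU v5p peak hardware TFLOP/s

tflops_per_sec = model_tflop_per_device / step_time_seconds
mfu = tflops_per_sec / peak_tflops_v5p
print(f"TFLOP/s/device: {tflops_per_sec:.3f}")  # ~134.934 (log: 134.924)
print(f"MFU: {mfu:.2%}")                        # ~29.40%
```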
**Tokens per second per device (throughput)**
- It is [computed](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/metric_logger.py#L215-L217) as
$$\text{token/s/device} = \frac{\text{number of tokens per device}}{\text{measured step time in seconds}}$$
- The numerator is from [calculate_tokens_training_per_device](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/maxtext_utils.py#L148)
$$\text{number of tokens per device} = \text{per device batch size} \times \text{max target length}$$
- Here we have `Tokens/s/device: 8672.758`. Let's verify manually: $24 \times 2048 / 5.667 = 8673.372$. Not exactly the same, but close, since the step time is rounded in the log.
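The same kind of check works for throughput, using the run configuration from this example:

```python
# Recompute tokens/s/device from the example configuration and step time.
per_device_batch_size = 24
max_target_length = 2048
step_time_seconds = 5.667

tokens_per_device = per_device_batch_size * max_target_length  # 49152
tokens_per_sec = tokens_per_device / step_time_seconds
print(f"Tokens/s/device: {tokens_per_sec:.3f}")  # ~8673.372 (log: 8672.758)
```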
### 4.2. Learning metrics
**Loss**. The loss is the key indicator of learning progress, which should decrease over training steps. In this example, the loss is `12.038` at Step 0 and decreases to `10.374` at Step 9. Ideally, we want the loss to converge to a small value with sufficiently large training steps.
**Total weights**. When discussing the throughput, we have $\text{number of tokens} = \text{per device batch size} \times \text{max target length} \times \text{number of devices}$. In this example, $\text{number of tokens} = 24 \times 2048 \times 4 = 196608$. There are two types of tokens: real tokens and pad tokens. Pad tokens are placeholders introduced by data preprocessing: we truncate or pad each sentence to the max target length. Only real tokens contribute to the learning signal (i.e., the loss). Therefore, we monitor the $\text{number of real tokens}$, which is shown as [total weights](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/train.py#L151).
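Here is a minimal sketch of how total weights can be derived from a padded batch. The pad token id of 0 is an assumption for illustration; MaxText derives the actual weights inside the training step linked above.

```python
import numpy as np

# Count real (non-pad) tokens in a padded batch (illustrative sketch).
PAD_ID = 0  # assumed pad token id for this example
batch = np.array([
    [5, 17, 42, 9, 0, 0],    # short sentence padded to length 6
    [8, 3, 21, 7, 11, 2],    # full-length sentence, no padding
])
weights = (batch != PAD_ID).astype(np.int32)  # 1 for real tokens, 0 for pads
total_weights = int(weights.sum())
print(total_weights)  # 10 real tokens out of 12 positions
```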
- Here we see `total_weights: 196608` for all steps. This is because we are using `dataset_type=synthetic`, where all sentences are generated with a length of `max_target_length=2048`. As a result, there are no pad tokens and total weights = number of tokens.
- However, in real datasets, sentences can have variable lengths and total weights < number of tokens. For example, we can set `dataset_type=tfds dataset_path=gs://maxtext-dataset dataset_name='c4/en:3.0.1'`, and will see total weights smaller than `196608`: