Skip to content

Commit a86e2fb

Browse files
Merge pull request #3440 from Steboss:main
PiperOrigin-RevId: 889346776
2 parents 2052c22 + 5548c58 commit a86e2fb

4 files changed

Lines changed: 107 additions & 0 deletions

File tree

src/maxtext/common/metric_logger.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import json
2020
import os
21+
import sys
2122
import queue
2223
import enum
2324

@@ -123,6 +124,9 @@ def write_metrics(self, metrics, step, is_training=True):
123124
if self.config.managed_mldiagnostics:
124125
self.write_metrics_to_managed_mldiagnostics(metrics, step)
125126

127+
if is_training:
128+
self._maybe_abort_after_write_metrics(metrics)
129+
126130
def log_metrics(self, metrics, step, is_training):
127131
"""Logs metrics via max_logging."""
128132
if is_training:
@@ -214,6 +218,16 @@ def _is_profiler_boundary_step(self, step):
214218
}
215219
return step in boundary_steps
216220

221+
def _maybe_abort_after_write_metrics(self, metrics):
222+
""" This function checks whether we have nan or inf values in training"""
223+
loss = metrics["scalar"].get("learning/loss")
224+
if self.config.abort_on_nan_loss and np.isnan(loss):
225+
max_logging.log("Aborting training due to NaN loss.")
226+
sys.exit(1)
227+
if self.config.abort_on_inf_loss and np.isinf(loss):
228+
max_logging.log("Aborting training due to Inf loss.")
229+
sys.exit(1)
230+
217231
def write_metrics_locally(self, metrics, step):
218232
"""Writes metrics locally for testing."""
219233
with open(self.config.metrics_file, "a", encoding="utf8") as local_metrics_file:

src/maxtext/configs/base.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,8 @@ decode_sampling_temperature: 1.
877877
eval_interval: -1 # the number of train steps between evaluation steps
878878
eval_steps: -1 # run this number of steps for eval, recommend setting this to prevent error due to running out of eval data
879879
target_eval_loss: 0. # early stop once reaching target eval_loss
880+
abort_on_nan_loss: True # Check for NaN and abort if found in training loss
881+
abort_on_inf_loss: True # Check for Inf and abort if found in training loss
880882

881883
# Goodput parameters
882884
enable_goodput_recording: False

src/maxtext/configs/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,6 +1143,8 @@ class TrainingLoop(BaseModel):
11431143
0.0,
11441144
description="If set, training will stop early when this evaluation loss is reached.",
11451145
)
1146+
abort_on_nan_loss: bool = Field(True, description="Check for NaN values and abort training.")
1147+
abort_on_inf_loss: bool = Field(True, description="Check for Inf values and abort training.")
11461148
enable_dropout: bool = Field(True, description="Enables dropout in the model.")
11471149
dropout_rate: float = Field(0.0, ge=0.0, le=1.0, description="The dropout rate.")
11481150
enable_data_shuffling: bool = Field(True, description="Enables shuffling of the training data.")
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for monitoring metrics"""
16+
import unittest
17+
from types import SimpleNamespace
18+
from unittest import mock
19+
20+
import numpy as np
21+
22+
from maxtext.common.metric_logger import MetricLogger
23+
24+
25+
class MetricLoggerAbortTest(unittest.TestCase):
  """Verifies that MetricLogger.write_metrics aborts on NaN/Inf training loss."""

  # Every writer method write_metrics may invoke before the NaN/Inf check runs.
  _ALL_WRITERS = (
      "log_metrics",
      "write_metrics_to_tensorboard",
      "write_metrics_locally",
      "write_metrics_for_gcs",
      "write_metrics_to_managed_mldiagnostics",
  )

  def _make_logger(self, abort_on_nan_loss, abort_on_inf_loss):
    """Builds a MetricLogger without running __init__, backed by a stub config."""
    logger = MetricLogger.__new__(MetricLogger)  # skip __init__
    logger.config = SimpleNamespace(
        abort_on_nan_loss=abort_on_nan_loss,
        abort_on_inf_loss=abort_on_inf_loss,
        enable_tensorboard=True,
        metrics_file="/tmp/fake_metrics.jsonl",
        gcs_metrics=True,
        managed_mldiagnostics=True,
    )
    return logger

  def _metrics(self, loss):
    """Returns a metrics payload carrying *loss* under learning/loss."""
    return {"scalar": {"learning/loss": loss}}

  def _stub_writers(self, logger, names):
    """Returns a patcher replacing each named writer on *logger* with a mock."""
    return mock.patch.multiple(logger, **{name: mock.DEFAULT for name in names})

  @mock.patch("jax.process_index", return_value=0)
  def test_abort_on_nan_exits_after_writes(self, _):
    logger = self._make_logger(True, False)
    with self._stub_writers(logger, self._ALL_WRITERS) as writers:
      with self.assertRaises(SystemExit) as ctx:
        logger.write_metrics(self._metrics(np.nan), step=1, is_training=True)
      self.assertEqual(ctx.exception.code, 1)
      # Every writer ran before the abort fired.
      for writer in writers.values():
        writer.assert_called_once()

  @mock.patch("jax.process_index", return_value=0)
  def test_abort_on_inf_exits_after_writes(self, _):
    logger = self._make_logger(False, True)
    with self._stub_writers(logger, self._ALL_WRITERS):
      with self.assertRaises(SystemExit):
        logger.write_metrics(self._metrics(np.inf), step=1, is_training=True)

  @mock.patch("jax.process_index", return_value=1)  # non-zero index skips the gcs branch
  def test_finite_loss_does_not_exit(self, _):
    logger = self._make_logger(True, True)
    stubbed = tuple(name for name in self._ALL_WRITERS if name != "write_metrics_for_gcs")
    with self._stub_writers(logger, stubbed):
      logger.write_metrics(self._metrics(1.23), step=1, is_training=True)

  @mock.patch("jax.process_index", return_value=1)
  def test_abort_flags_disabled_does_not_exit(self, _):
    logger = self._make_logger(False, False)
    stubbed = tuple(name for name in self._ALL_WRITERS if name != "write_metrics_for_gcs")
    with self._stub_writers(logger, stubbed):
      logger.write_metrics(self._metrics(np.nan), step=1, is_training=True)

0 commit comments

Comments
 (0)