Skip to content

Commit 1963688

Browse files
Add elastic pause/resume functionality to MaxText.
PiperOrigin-RevId: 895636178
1 parent 381dcdd commit 1963688

6 files changed

Lines changed: 469 additions & 2 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1194,3 +1194,9 @@ distill_temperature: 1.0
11941194
# 0.0 value disables this feature.
11951195
distill_beta: 0.0
11961196
distill_layer_indices: None
1197+
1198+
##### Elastic training parameters
1199+
# Elastic training is Pathways-specific and does not work on McJAX.
1200+
elastic_enabled: false
1201+
elastic_timeout_seconds: 300
1202+
elastic_max_retries: 10

src/maxtext/configs/types.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1551,6 +1551,26 @@ class Goodput(BaseModel):
15511551
enable_gcp_step_deviation_metrics: bool = Field(True, description="Enable GCP step deviation metrics.")
15521552

15531553

1554+
class ElasticTraining(BaseModel):
1555+
"""Configuration for elastic training and fault tolerance.
1556+
1557+
Elastic training is Pathways-specific and does not work on McJAX.
1558+
"""
1559+
1560+
elastic_enabled: bool = Field(False, description="Whether to enable elastic training.")
1561+
elastic_timeout_seconds: int = Field(
1562+
300,
1563+
description=(
1564+
"The maximum number of seconds to wait for `elastic_minimum_slice_count` slices to become active. If this"
1565+
" timeout is reached during any retry attempt, a `TimeoutError` is raised and training fails."
1566+
),
1567+
)
1568+
elastic_max_retries: int = Field(
1569+
10,
1570+
description="The maximum number of times to retry training when a slice failure occurs or when scaling up.",
1571+
)
1572+
1573+
15541574
class GcpMonitoring(BaseModel):
15551575
"""Configuration for GCP-specific workload monitoring."""
15561576

@@ -1948,6 +1968,7 @@ class MaxTextConfig(
19481968
Checkpointing,
19491969
OrbaxStorage,
19501970
EmergencyCheckpointing,
1971+
ElasticTraining,
19511972
# Data Types and Quantization
19521973
DataTypes,
19531974
Quantization,
@@ -2457,6 +2478,8 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
24572478
# H. RUN ALL CROSS-FIELD VALIDATIONS
24582479
if self.load_parameters_path and self.load_full_state_path:
24592480
raise ValueError("At most one of `load_parameters_path` or `load_full_state_path` should be set.")
2481+
if self.elastic_enabled and not self.enable_single_controller:
2482+
raise ValueError("Elastic training is only supported with Pathways (`enable_single_controller=True`).")
24602483
if (self.load_parameters_path or self.load_full_state_path) and not self.enable_checkpointing:
24612484
raise ValueError("You must set enable_checkpointing=True to load a checkpoint.")
24622485
if self.enable_multi_tier_checkpointing:

src/maxtext/trainers/pre_train/train.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from maxtext.configs import pyconfig
4242
from maxtext.common.common_types import ShardMode
4343
from maxtext.utils.globals import EPS
44+
from maxtext.utils import elastic_utils
4445
# Placeholder: internal
4546

4647
# pylint: disable=too-many-positional-arguments
@@ -675,11 +676,35 @@ def run(config, recorder, diagnostic_config):
675676
train_loop(config, recorder)
676677

677678

679+
def get_train_func(config, recorder, diagnostic_config, argv):
680+
"""Returns the train function, wrapping in elastic_retry if elastic training is enabled."""
681+
if config.elastic_enabled:
682+
max_logging.log("Elastic utils: Elastic training enabled.")
683+
684+
def elastic_train_wrapper(argv: Sequence[str]) -> None:
685+
"""Wrapper for elastic training initializes variables and runs the train loop."""
686+
elastic_config, elastic_recorder, elastic_diagnostic_config = initialize(argv)
687+
run(
688+
elastic_config,
689+
elastic_recorder,
690+
elastic_diagnostic_config,
691+
)
692+
693+
train_func = elastic_utils.elastic_retry(config)(functools.partial(elastic_train_wrapper, argv=argv))
694+
else:
695+
# Use the already initialized variables
696+
def train_func():
697+
run(config, recorder, diagnostic_config)
698+
699+
return train_func
700+
701+
678702
def main(argv: Sequence[str]) -> None:
679703
config, recorder, diagnostic_config = initialize(argv)
680704
record_goodput(recorder, RECORD_JOB_START_TIME)
705+
train_func = get_train_func(config, recorder, diagnostic_config, argv)
681706
with maybe_monitor_goodput(config):
682-
run(config, recorder, diagnostic_config)
707+
train_func()
683708

684709

685710
if __name__ == "__main__":

src/maxtext/utils/elastic_utils.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Utility functions for Elastic Training."""
16+
17+
import functools
18+
import jax
19+
from maxtext.utils import gcs_utils
20+
from maxtext.utils import max_logging
21+
import pathwaysutils
22+
from pathwaysutils.elastic import manager
23+
24+
25+
elastic_manager: manager.Manager | None = None
26+
27+
28+
def elastic_enabled(config) -> bool:
29+
"""Returns whether elastic mode is enabled."""
30+
return pathwaysutils.is_pathways_backend_used() and config.elastic_enabled
31+
32+
33+
def clean_up_checkpoints(checkpoint_dir: str):
34+
"""Cleans up incomplete checkpoints after an elastic event."""
35+
max_logging.log("Elastic utils: Checking for incomplete checkpoint after an elastic event...")
36+
checkpoint_dir = gcs_utils.add_trailing_slash(checkpoint_dir)
37+
38+
# 1. List the "directories" (steps)
39+
checkpoints = gcs_utils.gcs_list_directories(checkpoint_dir)
40+
41+
# 2. Filter for directories that are numbers
42+
checkpoints = [cp for cp in checkpoints if cp.isdigit()]
43+
44+
if not checkpoints:
45+
max_logging.log("Found no existing checkpoints. Continuing")
46+
return
47+
48+
# Sort naturally (numerical sort) and get the last one
49+
checkpoints.sort(key=int)
50+
latest_checkpoint_name = checkpoints[-1]
51+
latest_checkpoint_path = f"{checkpoint_dir}{latest_checkpoint_name}/"
52+
53+
max_logging.log(f"Checking latest checkpoint: {latest_checkpoint_path}")
54+
55+
# 3. Check for commit_success file
56+
success_markers = gcs_utils.gcs_glob_pattern(f"{latest_checkpoint_path}commit_success*")
57+
58+
if not success_markers:
59+
max_logging.log(f"No commit_success file found. Deleting {latest_checkpoint_path}...")
60+
# TODO: Use Orbax 'Cancel Ongoing Checkpointing' API when available to
61+
# prevent deleting a checkpoint that is currently being written.
62+
gcs_utils.gcs_delete_directory(latest_checkpoint_path)
63+
else:
64+
max_logging.log(f"Found commit_success file. Keeping {latest_checkpoint_path}.")
65+
66+
67+
def live_devices():
68+
"""Returns the list of live devices."""
69+
global elastic_manager
70+
# If pathways is not used or elastic_manager is not initialized, return all devices
71+
if pathwaysutils.is_pathways_backend_used():
72+
if elastic_manager is None:
73+
elastic_manager = manager.Manager()
74+
# Filter devices that are in active slices
75+
return [d for d in jax.devices() if d.slice_index in elastic_manager.active_slice_indices]
76+
return jax.devices()
77+
78+
79+
def chain_callbacks(*funcs):
80+
"""Helper function to chain callbacks."""
81+
82+
def wrapper():
83+
for func in funcs:
84+
func()
85+
86+
return wrapper
87+
88+
89+
def elastic_retry(config, callback_fn=None):
90+
"""Decorator for elastic retry.
91+
92+
If an elastic event occurs, the decorator will retry the decorated function
93+
up to `config.elastic_max_retries` times.
94+
Before each retry, it cleans up partial checkpoints by calling
95+
`clean_up_checkpoints`. If `callback_fn` is provided, it is
96+
called after `clean_up_checkpoints`.
97+
98+
Args:
99+
config: Config object.
100+
callback_fn: Optional callback function to be called after
101+
`clean_up_checkpoints` on an elastic event.
102+
103+
Returns:
104+
A decorator for elastic retry.
105+
"""
106+
global elastic_manager
107+
if not elastic_enabled(config):
108+
msg = (
109+
"Elastic training requires the Pathways backend, and elastic_enabled"
110+
" must be set to True: current config.elastic_enabled:"
111+
f" {config.elastic_enabled}, pathways backend used:"
112+
f" {pathwaysutils.is_pathways_backend_used()}"
113+
)
114+
raise ValueError(msg)
115+
116+
max_logging.log("Elastic Retry Enabled")
117+
if elastic_manager is None:
118+
elastic_manager = manager.Manager()
119+
120+
cleanup_partial = functools.partial(clean_up_checkpoints, config.checkpoint_dir)
121+
122+
if callback_fn is None:
123+
effective_callback = cleanup_partial
124+
else:
125+
effective_callback = chain_callbacks(cleanup_partial, callback_fn)
126+
127+
return elastic_manager.elastic_retry(
128+
max_retries=config.elastic_max_retries,
129+
timeout=config.elastic_timeout_seconds,
130+
on_elastic_event_callback=effective_callback,
131+
)

src/maxtext/utils/gcs_utils.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
""" Common GCS Utils needed by multiple modules"""
15+
"""Common GCS Utils needed by multiple modules"""
1616
import shutil
1717
import json
1818
import os
1919
import socket
2020
from pathlib import Path
2121
from etils import epath
2222
import uuid
23+
from concurrent.futures import ThreadPoolExecutor
2324

2425
import yaml
2526

@@ -168,6 +169,35 @@ def gcs_list_directories(directory_path):
168169
return directories
169170

170171

172+
def gcs_delete_directory(directory_path: str):
173+
"""Deletes a "directory" (all blobs with the prefix) from GCS.
174+
175+
Args:
176+
directory_path: The GCS path (gs://...) representing the "directory" to delete.
177+
"""
178+
if not _gcs_guard("gcs_delete_directory"):
179+
return
180+
storage_client = storage.Client()
181+
bucket_name, directory_prefix = parse_gcs_bucket_and_prefix(directory_path)
182+
bucket = storage_client.bucket(bucket_name)
183+
184+
# Ensures the prefix has a trailing slash to avoid deleting more than intended.
185+
if not directory_prefix.endswith("/"):
186+
directory_prefix += "/"
187+
188+
blobs = list(bucket.list_blobs(prefix=directory_prefix))
189+
if blobs:
190+
# Uses a ThreadPoolExecutor to delete blobs in parallel to match gsutil -m performance.
191+
def _delete_blob(blob):
192+
try:
193+
blob.delete()
194+
except Exception as e: # pylint: disable=broad-except
195+
max_logging.log(f"Error deleting blob {blob.name}: {e}")
196+
197+
with ThreadPoolExecutor(max_workers=32) as executor:
198+
executor.map(_delete_blob, blobs)
199+
200+
171201
def gcs_glob_pattern(pattern):
172202
"""
173203
Globs GCS files and returns a list of full GCS paths.

0 commit comments

Comments
 (0)