
Commit 1e343fe

Add checkpoint deletion options to configuration and checkpoint manager
PiperOrigin-RevId: 896170926
1 parent cb2ef64 commit 1e343fe

5 files changed

Lines changed: 49 additions & 26 deletions


docs/guides/checkpointing_solutions/gcs_checkpointing.md

Lines changed: 34 additions & 26 deletions
@@ -9,41 +9,49 @@ bucket.
 
 ## Checkpoint loading priority
 
-The system follows a specific order when deciding which checkpoint to load at startup. The first valid condition met is the one executed:
+The system follows a specific order when deciding which checkpoint to load at
+startup. The first valid condition met is the one executed:
 
 1. **Resume Current Run**: If a checkpoint already exists for the current
    `run_name`, the system loads the latest fully-saved checkpoint. This is the
    default behavior to ensure minimal state loss when resuming after an
    interruption.
 2. **Load from Specific Path**: The system checks for a user-specified path.
-   - If `load_parameters_path` is set, we load a parameter only checkpoint from that path..
-   - If `load_full_state_path` is set, we load a full state checkpoint from that path.
-   - **Note**: These two options are mutually exclusive and will cause an error if both are set.
-3. **Initialize from Scratch**: We don't load a checkpoint and initialize state instead.
+   - If `load_parameters_path` is set, we load a parameter-only checkpoint
+     from that path.
+   - If `load_full_state_path` is set, we load a full-state checkpoint from
+     that path.
+   - **Note**: These two options are mutually exclusive and will cause an
+     error if both are set.
+3. **Initialize from Scratch**: We don't load a checkpoint and initialize
+   state instead.

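The decision order above can be sketched as a small helper. This is an illustrative sketch with hypothetical names, not MaxText's actual implementation, which lives in its checkpointing utilities:

```python
def choose_checkpoint_source(latest_step, load_parameters_path, load_full_state_path):
    """Return which source startup would use, mirroring the documented priority."""
    if load_parameters_path and load_full_state_path:
        # The two explicit-path options are mutually exclusive.
        raise ValueError("load_parameters_path and load_full_state_path are mutually exclusive")
    if latest_step is not None:
        # 1. A checkpoint already exists for this run_name: resume from it.
        return ("resume", latest_step)
    if load_parameters_path:
        # 2a. Parameter-only checkpoint from an explicit path.
        return ("parameters", load_parameters_path)
    if load_full_state_path:
        # 2b. Full-state checkpoint (params + optimizer state + step count).
        return ("full_state", load_full_state_path)
    # 3. Nothing to load: initialize from scratch.
    return ("from_scratch", None)
```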
 ### MaxText configuration
 
-| Flag | Description | Type | Default |
-| :--- | :--- | :--- | :--- |
-| `enable_checkpointing` | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False` |
-| `async_checkpointing` | When set to (`True`), this flag makes checkpoint saving asynchronous. The training step is only blocked for the minimal time needed to capture the model's state, and the actual writing to storage happens in a background thread. This is highly recommended for performance. It's enabled by default. | `boolean` | `True` |
-| `checkpoint_period` | The interval, in training steps, for how often a checkpoint is saved. | `integer` | `10000` |
-| `enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage.<br>**Note**: This feature is only compatible with training jobs that utilize a Distributed Data Parallel (DDP) strategy. | `boolean` | `False` |
-| `load_parameters_path` | Specifies a path to a checkpoint directory to load a parameter only checkpoint.<br>**Example**: `"gs://my-bucket/my-previous-run/checkpoints/items/1000"` | `string` | `""` (disabled) |
-| `load_full_state_path` | Specifies a path to a checkpoint directory to load a full checkpoint including optimizer state and step count from a specific directory.<br>**Example**: `"gs://my-bucket/my-interrupted-run/checkpoints/items/500"` | `string` | `""` (disabled) |
-| `lora_input_adapters_path` | Specifies a parent directory containing LoRA (Low-Rank Adaptation) adapters. | `string` | `""` (disabled) |
-| `force_unroll` | If `True`, unrolls the loop when generating a parameter-only checkpoint. | `boolean` | `False` |
+Flag | Description | Type | Default
+:--- | :--- | :--- | :---
+`enable_checkpointing` | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False`
+`async_checkpointing` | If `True`, checkpoint saving is asynchronous: the training step is blocked only for the minimal time needed to capture the model's state, and the actual write to storage happens in a background thread. This is highly recommended for performance. | `boolean` | `True`
+`checkpoint_period` | The interval, in training steps, at which checkpoints are saved. | `integer` | `10000`
+`enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage.<br>**Note**: This feature is only compatible with training jobs that utilize a Distributed Data Parallel (DDP) strategy. | `boolean` | `False`
+`checkpoint_todelete_subdir` | Subdirectory to move checkpoints to before deletion. For example: `".todelete"`. Ignored if the directory is prefixed with `gs://`. | `string` | `""`
+`checkpoint_todelete_full_path` | Full path to move checkpoints to before deletion. | `string` | `""`
+`load_parameters_path` | Path to a checkpoint directory from which to load a parameter-only checkpoint.<br>**Example**: `"gs://my-bucket/my-previous-run/checkpoints/items/1000"` | `string` | `""` (disabled)
+`load_full_state_path` | Path to a checkpoint directory from which to load a full checkpoint, including optimizer state and step count.<br>**Example**: `"gs://my-bucket/my-interrupted-run/checkpoints/items/500"` | `string` | `""` (disabled)
+`lora_input_adapters_path` | Parent directory containing LoRA (Low-Rank Adaptation) adapters. | `string` | `""` (disabled)
+`force_unroll` | If `True`, unrolls the loop when generating a parameter-only checkpoint. | `boolean` | `False`
 
 ## Storage and format configuration
 
-These settings control the underlying storage mechanism ([Orbax](https://orbax.readthedocs.io)) for performance and compatibility.
-
-| Flag | Description | Type | Default |
-| :--- | :--- | :--- | :--- |
-| `checkpoint_storage_target_data_file_size_bytes` | Sets a target file size for Orbax to chunk large arrays into smaller physical files. This can dramatically speed up loading over a network and in distributed environments. | `integer` | `2147483648` (2 GB) |
-| `checkpoint_storage_use_ocdbt` | If `True`, uses the TensorStore **OCDBT** (Optionally-Cooperative Distributed B+ Tree)) key-value store as the underlying storage format for checkpointing. Set to `0` for Pathways. | `boolean` | `True` |
-| `checkpoint_storage_use_zarr3` | If `True`, uses the Zarr v3 storage format within Orbax, which is optimized for chunked, compressed, N-dimensional arrays. Set to `0` for Pathways. | `boolean` | `True` |
-| `checkpoint_storage_concurrent_gb` | Controls the concurrent I/O limit in gigabytes for the checkpointer. Larger models may require increasing this value to avoid I/O bottlenecks. | `integer` | `96` |
-| `enable_orbax_v1` | A boolean flag to explicitly enable features and behaviors from Orbax version 1. | `boolean` | `False` |
-| `source_checkpoint_layout` | Specifies the format of the checkpoint being **loaded**. This tells the system how to interpret the files at the source path.<br>**Options**: `"orbax"`, `"safetensors"` | `string` | `"orbax"` |
-| `checkpoint_conversion_fn` | A user-defined function to process a loaded checkpoint dictionary into a format that the model can understand. This is essential for loading checkpoints from different frameworks or formats (e.g., converting keys from a Hugging Face SafeTensors file). | `function` or `None` | `None` |
+These settings control the underlying storage mechanism
+([Orbax](https://orbax.readthedocs.io)) for performance and compatibility.
+
+Flag | Description | Type | Default
+:--- | :--- | :--- | :---
+`checkpoint_storage_target_data_file_size_bytes` | Target file size at which Orbax chunks large arrays into smaller physical files. This can dramatically speed up loading over a network and in distributed environments. | `integer` | `2147483648` (2 GB)
+`checkpoint_storage_use_ocdbt` | If `True`, uses the TensorStore **OCDBT** (Optionally-Cooperative Distributed B+ Tree) key-value store as the underlying storage format for checkpointing. Set to `0` for Pathways. | `boolean` | `True`
+`checkpoint_storage_use_zarr3` | If `True`, uses the Zarr v3 storage format within Orbax, which is optimized for chunked, compressed, N-dimensional arrays. Set to `0` for Pathways. | `boolean` | `True`
+`checkpoint_storage_concurrent_gb` | Concurrent I/O limit in gigabytes for the checkpointer. Larger models may require increasing this value to avoid I/O bottlenecks. | `integer` | `96`
+`enable_orbax_v1` | Explicitly enables features and behaviors from Orbax version 1. | `boolean` | `False`
+`source_checkpoint_layout` | Format of the checkpoint being **loaded**; tells the system how to interpret the files at the source path.<br>**Options**: `"orbax"`, `"safetensors"` | `string` | `"orbax"`
+`checkpoint_conversion_fn` | A user-defined function that converts a loaded checkpoint dictionary into a format the model can understand. Essential for loading checkpoints from other frameworks or formats (e.g., converting keys from a Hugging Face SafeTensors file). | `function` or `None` | `None`
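For example, a run that stages old checkpoints for out-of-band deletion might combine the flags like this (a config sketch with illustrative values, following the style of `base.yml`):

```yaml
# Fragment of a MaxText run config (illustrative values)
enable_checkpointing: True
async_checkpointing: True
checkpoint_period: 1000
# Old checkpoints are renamed into this sibling subdirectory instead of being
# deleted inline; a separate job can garbage-collect the subdirectory later.
checkpoint_todelete_subdir: ".todelete"
```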

src/maxtext/common/checkpointing.py

Lines changed: 4 additions & 0 deletions
@@ -222,6 +222,8 @@ def create_orbax_checkpoint_manager(
     colocated_python_checkpointing: bool = False,
     enable_single_replica_ckpt_restoring: bool = False,
     enable_autocheckpoint: bool = False,
+    todelete_subdir: str | None = None,
+    todelete_full_path: str | None = None,
 ):
   """Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled."""
   if not enable_checkpointing:
@@ -279,6 +281,8 @@ def create_orbax_checkpoint_manager(
         save_decision_policy=save_decision_policy,
         preservation_policy=preservation_policy,
         async_options=async_options,
+        todelete_subdir=todelete_subdir,
+        todelete_full_path=todelete_full_path,
     ),
     logger=orbax_logger,
 )

src/maxtext/configs/base.yml

Lines changed: 4 additions & 0 deletions
@@ -59,6 +59,10 @@ max_num_checkpoints_to_keep: None
 enable_continuous_checkpointing: False
 # enables one replica to read the ckpt then broadcast to the rest
 enable_single_replica_ckpt_restoring: False
+# Subdirectory to move checkpoints to before deletion. For example: ".todelete" (Ignored if directory is prefixed with gs://)
+checkpoint_todelete_subdir: None
+# Full path to move checkpoints to before deletion.
+checkpoint_todelete_full_path: None
 
 force_unroll: False # during generate_param_only_checkpoint should we unroll the loop?

src/maxtext/configs/types.py

Lines changed: 5 additions & 0 deletions
@@ -313,6 +313,11 @@ class Checkpointing(BaseModel):
   enable_single_replica_ckpt_restoring: bool = Field(
       False, description="One replica reads and broadcasts the checkpoint."
   )
+  checkpoint_todelete_subdir: str | None = Field(
+      None,
+      description="Subdirectory to move checkpoints to before deletion. (Ignored if directory is prefixed with gs://)",
+  )
+  checkpoint_todelete_full_path: str | None = Field(None, description="Full path to move checkpoints to before deletion.")
   force_unroll: bool = Field(
       False,
       description="During param-only checkpoint generation, whether to unroll the loop.",

src/maxtext/utils/train_utils.py

Lines changed: 2 additions & 0 deletions
@@ -83,6 +83,8 @@ def create_training_tools(config, model, mesh):
       config.colocated_python_checkpointing,
       config.enable_single_replica_ckpt_restoring,
       config.enable_autocheckpoint,
+      config.checkpoint_todelete_subdir,
+      config.checkpoint_todelete_full_path,
   )
 
   return init_rng, checkpoint_manager, learning_rate_schedule, tx
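Note that `create_training_tools` threads the two new values positionally, so their order must match the parameter order of `create_orbax_checkpoint_manager` exactly. A keyword-only stub (hypothetical, for illustration only) shows the calling pattern that avoids this coupling:

```python
def create_manager(*, enable_autocheckpoint=False, todelete_subdir=None, todelete_full_path=None):
    """Stub standing in for a manager factory. Keyword-only parameters
    (the bare * in the signature) keep call sites correct as the
    signature grows, since arguments cannot be passed positionally."""
    return {
        "enable_autocheckpoint": enable_autocheckpoint,
        "todelete_subdir": todelete_subdir,
        "todelete_full_path": todelete_full_path,
    }

# Argument order at the call site no longer matters:
opts = create_manager(todelete_full_path="/tmp/ckpt-trash", todelete_subdir=".todelete")
```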
