Skip to content

Commit 6a65d66

Browse files
author
Charles Li
committed
Move gcs_benchmarks/standalone tools to src/maxtext/utils
1 parent 5905690 commit 6a65d66

5 files changed

Lines changed: 14 additions & 42 deletions

File tree

RESTRUCTURE.md

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -299,9 +299,6 @@ comments, or questions by creating a new
299299
├── dev/
300300
│ ├── code_style.sh
301301
│ └── unit_test_and_lint.sh
302-
├── gcs_benchmarks/
303-
│ ├── standalone_checkpointer.py
304-
│ └── standalone_dataloader.py
305302
├── orchestration/
306303
│ ├── gpu_multi_process_run.sh
307304
│ ├── multihost_job.py

tools/gcs_benchmarks/standalone_checkpointer.py renamed to src/maxtext/utils/standalone_checkpointer.py

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023–2025 Google LLC
1+
# Copyright 2023–2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -53,13 +53,9 @@ def checkpoint_loop(config, state=None):
5353
"""
5454
model = from_config(config)
5555
mesh = model.mesh
56-
init_rng, checkpoint_manager, _, tx = train_utils.create_training_tools(
57-
config, model, mesh
58-
)
56+
init_rng, checkpoint_manager, _, tx = train_utils.create_training_tools(config, model, mesh)
5957

60-
unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(
61-
model, tx, config, init_rng, mesh, is_training=True
62-
)
58+
unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True)
6359
# A barrier to sync all hosts before starting to restore checkpoint
6460
jax.experimental.multihost_utils.sync_global_devices("Barrier before load")
6561
checkpoint_load_start = datetime.datetime.now()
@@ -82,30 +78,24 @@ def checkpoint_loop(config, state=None):
8278
if state is not None: # Checkpoint was available for restore
8379
if jax.process_index() == 0:
8480
max_logging.log(
85-
"STANDALONE CHECKPOINTER : Checkpoint restored in :"
86-
f" {checkpoint_load_end - checkpoint_load_start}"
81+
"STANDALONE CHECKPOINTER : Checkpoint restored in :" f" {checkpoint_load_end - checkpoint_load_start}"
8782
)
8883
else: # Checkpoint was unavailable, state needs to be initialized
89-
state, _, _, _ = maxtext_utils.setup_training_state(
90-
model, None, tx, config, init_rng, mesh, checkpoint_manager
91-
)
84+
state, _, _, _ = maxtext_utils.setup_training_state(model, None, tx, config, init_rng, mesh, checkpoint_manager)
9285
state = add_entropy_to_checkpoint(state)
9386

9487
start_step = get_first_step(state) # this is the start_step for training
9588
for step in np.arange(start_step, config.steps):
9689
if checkpoint_manager is not None:
9790
start_time = datetime.datetime.now()
9891
# A barrier to sync all hosts before starting to save checkpoint
99-
jax.experimental.multihost_utils.sync_global_devices(
100-
"Barrier before save"
101-
)
92+
jax.experimental.multihost_utils.sync_global_devices("Barrier before save")
10293
if checkpointing.save_checkpoint(checkpoint_manager, int(step), state):
10394
checkpoint_manager.wait_until_finished()
10495
end_time = datetime.datetime.now()
10596
if jax.process_index() == 0:
10697
max_logging.log(
107-
"STANDALONE CHECKPOINTER : Checkpoint saved in"
108-
f" {end_time - start_time} ,step {step}, on host 0"
98+
"STANDALONE CHECKPOINTER : Checkpoint saved in" f" {end_time - start_time} ,step {step}, on host 0"
10999
)
110100

111101
return state
@@ -123,12 +113,8 @@ def add_entropy_to_checkpoint(state):
123113
state: Returns state with entropy added to the optimizer state.
124114
"""
125115
opt_0 = state.opt_state[0]
126-
opt_0 = opt_0._replace(
127-
mu=jax.tree_util.tree_map(lambda k: jnp.cos(1000 * k), state.params)
128-
)
129-
opt_0 = opt_0._replace(
130-
nu=jax.tree_util.tree_map(lambda k: jnp.sin(1000 * k), state.params)
131-
)
116+
opt_0 = opt_0._replace(mu=jax.tree_util.tree_map(lambda k: jnp.cos(1000 * k), state.params))
117+
opt_0 = opt_0._replace(nu=jax.tree_util.tree_map(lambda k: jnp.sin(1000 * k), state.params))
132118
new_opt = [opt_0] + list(state.opt_state[1:])
133119
state = state.replace(opt_state=new_opt)
134120
return state

tools/gcs_benchmarks/standalone_dataloader.py renamed to src/maxtext/utils/standalone_dataloader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023–2025 Google LLC
1+
# Copyright 2023–2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

tests/integration/standalone_dl_ckpt_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
""" Tests for the standalone_checkpointer.py """
1616
import unittest
1717
import pytest
18-
from tools.gcs_benchmarks.standalone_checkpointer import main as sckpt_main
19-
from tools.gcs_benchmarks.standalone_dataloader import main as sdl_main
18+
from maxtext.utils.standalone_checkpointer import main as sckpt_main
19+
from maxtext.utils.standalone_dataloader import main as sdl_main
2020
from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT
2121
from maxtext.common.gcloud_stub import is_decoupled
2222

@@ -50,6 +50,7 @@ def _get_random_test_name(self, test_name):
5050

5151
@pytest.mark.integration_test
5252
@pytest.mark.tpu_only
53+
@pytest.mark.scheduled_only
5354
def test_standalone_dataloader(self):
5455
random_run_name = self._get_random_test_name("standalone_dataloader")
5556
sdl_main(
@@ -68,6 +69,7 @@ def test_standalone_dataloader(self):
6869

6970
@pytest.mark.integration_test
7071
@pytest.mark.tpu_only
72+
@pytest.mark.scheduled_only
7173
def test_standalone_checkpointer(self):
7274
random_run_name = self._get_random_test_name("standalone_checkpointer")
7375
# checkpoint at 50

tools/gcs_benchmarks/__init__.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

0 commit comments

Comments (0)