
Commit 655c103

Merge pull request #3556 from CIeNET-International:charlesli/gcs_benchmarks_move

PiperOrigin-RevId: 893778063
Parents: 5905690 + ebd77bb

8 files changed: 20 additions & 46 deletions


RESTRUCTURE.md (0 additions & 3 deletions)

```diff
@@ -299,9 +299,6 @@ comments, or questions by creating a new
 ├── dev/
 │   ├── code_style.sh
 │   └── unit_test_and_lint.sh
-├── gcs_benchmarks/
-│   ├── standalone_checkpointer.py
-│   └── standalone_dataloader.py
 ├── orchestration/
 │   ├── gpu_multi_process_run.sh
 │   ├── multihost_job.py
```

docs/conf.py (2 additions & 0 deletions)

```diff
@@ -119,6 +119,8 @@
     os.path.join("run_maxtext", "run_maxtext_via_multihost_job.md"),
     os.path.join("run_maxtext", "run_maxtext_via_multihost_runner.md"),
     os.path.join("reference", "core_concepts", "llm_calculator.ipynb"),
+    os.path.join("reference", "api.rst"),
+    os.path.join("reference", "api_generated", "MaxText*.rst"),
     os.path.join("reference", "api_generated", "modules.rst"),
     os.path.join("reference", "api_generated", "dependencies.github_deps.rst"),
     os.path.join("reference", "api_generated", "dependencies.github_deps.install_pre_train_deps.rst"),
```
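The two new entries extend Sphinx's `exclude_patterns`, which are matched as glob patterns relative to the docs source directory, so `MaxText*.rst` covers every generated module page with that prefix. A small stdlib sketch of that matching behavior (using `fnmatch`, which implements the same `*` wildcard semantics; the candidate file names are hypothetical):

```python
import fnmatch
import os

# The two patterns this commit adds to exclude_patterns in docs/conf.py.
patterns = [
    os.path.join("reference", "api.rst"),
    os.path.join("reference", "api_generated", "MaxText*.rst"),
]

# Hypothetical files Sphinx might encounter while collecting sources.
candidates = [
    os.path.join("reference", "api.rst"),
    os.path.join("reference", "api_generated", "MaxText.layers.rst"),
    os.path.join("reference", "api_generated", "modules.rst"),
]

# A file is excluded if any pattern matches it.
excluded = [f for f in candidates if any(fnmatch.fnmatch(f, p) for p in patterns)]
print(excluded)  # the first two candidates match; modules.rst needs its own entry
```

This is why `modules.rst` keeps a separate entry below the new ones: it does not start with `MaxText`.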

docs/guides/data_input_pipeline/data_input_grain.md (3 additions & 3 deletions)

````diff
@@ -34,10 +34,10 @@ Grain ensures determinism in data input pipelines by saving the pipeline's state
 
 1. Grain currently supports three data formats: [ArrayRecord](https://github.com/google/array_record) (random access), [Parquet](https://arrow.apache.org/docs/python/parquet.html) (partial random access through row groups) and [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) (sequential access). Only the ArrayRecord format supports the global shuffle mentioned above. For converting a dataset into ArrayRecord, see [Apache Beam Integration for ArrayRecord](https://github.com/google/array_record/tree/main/beam). Additionally, other random-access data sources can be supported via a custom [data source](https://google-grain.readthedocs.io/en/latest/data_sources/protocol.html) class.
    - **Community Resource**: The MaxText community has created an [ArrayRecord Documentation](https://array-record.readthedocs.io/). Note: we appreciate the contribution from the community, but as of now it has not been verified by the MaxText or ArrayRecord developers.
-2. If the dataset is hosted on a Cloud Storage bucket, the `gs://` path can be provided directly. However, for the best performance, it is recommended to read the bucket through [Cloud Storage FUSE](https://cloud.google.com/storage/docs/gcs-fuse). This significantly improves performance for the ArrayRecord format, as metadata caching speeds up random access. The installation of Cloud Storage FUSE is included in [setup.sh](https://github.com/google/maxtext/blob/main/src/dependencies/scripts/setup.sh). The user then needs to mount the Cloud Storage bucket to a local path on each worker, using the script [setup_gcsfuse.sh](https://github.com/AI-Hypercomputer/maxtext/blob/4e44e065cc6379e76f9f1ac4785f81c05cafb58f/src/dependencies/scripts/setup_gcsfuse.sh). The script configures some parameters for the mount.
+2. If the dataset is hosted on a Cloud Storage bucket, the `gs://` path can be provided directly. However, for the best performance, it is recommended to read the bucket through [Cloud Storage FUSE](https://cloud.google.com/storage/docs/gcs-fuse). This significantly improves performance for the ArrayRecord format, as metadata caching speeds up random access. The installation of Cloud Storage FUSE is included in [setup.sh](https://github.com/google/maxtext/blob/main/src/dependencies/scripts/setup.sh). The user then needs to mount the Cloud Storage bucket to a local path on each worker, using the script [setup_gcsfuse.sh](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/dependencies/scripts/setup_gcsfuse.sh). The script configures some parameters for the mount.
 
 ```sh
-bash tools/setup/setup_gcsfuse.sh \
+bash src/dependencies/scripts/setup_gcsfuse.sh \
 DATASET_GCS_BUCKET=${BUCKET_NAME?} \
 MOUNT_PATH=${MOUNT_PATH?} \
 [FILE_PATH=${MOUNT_PATH?}/my_dataset]
@@ -47,7 +47,7 @@ Note that `FILE_PATH` is optional; when provided, the script runs `ls -R` for pr
 
 1. Set `dataset_type=grain`, `grain_file_type={arrayrecord|parquet|tfrecord}`, and `grain_train_files` in `src/maxtext/configs/base.yml` or through command-line arguments to match the file pattern on the mounted local path.
 
-2. Tune `grain_worker_count` for performance. This parameter controls the number of child processes used by Grain (more details in [behind_the_scenes](https://google-grain.readthedocs.io/en/latest/behind_the_scenes.html)). If you use a large number of workers, check your gcsfuse config in [setup_gcsfuse.sh](https://github.com/AI-Hypercomputer/maxtext/blob/4e44e065cc6379e76f9f1ac4785f81c05cafb58f/src/dependencies/scripts/setup_gcsfuse.sh) to avoid gcsfuse throttling.
+2. Tune `grain_worker_count` for performance. This parameter controls the number of child processes used by Grain (more details in [behind_the_scenes](https://google-grain.readthedocs.io/en/latest/behind_the_scenes.html)). If you use a large number of workers, check your gcsfuse config in [setup_gcsfuse.sh](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/dependencies/scripts/setup_gcsfuse.sh) to avoid gcsfuse throttling.
 
 3. ArrayRecord only: for multi-source blending, you can specify multiple data sources with their respective weights, using a semicolon (;) to separate sources and a comma (,) to separate weights. The weights are automatically normalized to sum to 1.0. For example:
 
````
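The multi-source blending note in the hunk above says raw weights are normalized to sum to 1.0. A minimal sketch of that normalization step (plain Python, not MaxText's actual parser; the raw weight values are hypothetical):

```python
def normalize_weights(weights):
  """Scale raw blending weights so they sum to 1.0."""
  total = sum(weights)
  return [w / total for w in weights]

# Hypothetical two-source blend with raw weights 3 and 1:
raw = [3.0, 1.0]
print(normalize_weights(raw))  # -> [0.75, 0.25]
```

Because of this normalization, only the ratio between the weights matters, not their absolute scale.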
docs/tutorials/pretraining.md (1 addition & 1 deletion)

````diff
@@ -87,7 +87,7 @@ eval metrics after step: 9, loss=9.420, total_weights=75264.0
 
 Grain is a library for reading data for training and evaluating JAX models. It is the recommended input pipeline for determinism and resilience! It supports data formats like ArrayRecord and Parquet. You can check [Grain pipeline](../guides/data_input_pipeline/data_input_grain.md) for more details.
 
-**Data preparation**: You need to download data to a Cloud Storage bucket and read it via Cloud Storage FUSE with [setup_gcsfuse.sh](https://github.com/AI-Hypercomputer/maxtext/blob/4e44e065cc6379e76f9f1ac4785f81c05cafb58f/src/dependencies/scripts/setup_gcsfuse.sh).
+**Data preparation**: You need to download data to a Cloud Storage bucket and read it via Cloud Storage FUSE with [setup_gcsfuse.sh](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/dependencies/scripts/setup_gcsfuse.sh).
 
 - For example, we can mount the bucket `gs://maxtext-dataset` on the local path `/tmp/gcsfuse` before training:
 ```bash
````

tools/gcs_benchmarks/standalone_checkpointer.py renamed to src/maxtext/utils/standalone_checkpointer.py (9 additions & 23 deletions)

```diff
@@ -1,4 +1,4 @@
-# Copyright 2023–2025 Google LLC
+# Copyright 2023–2026 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -53,13 +53,9 @@ def checkpoint_loop(config, state=None):
   """
   model = from_config(config)
   mesh = model.mesh
-  init_rng, checkpoint_manager, _, tx = train_utils.create_training_tools(
-      config, model, mesh
-  )
+  init_rng, checkpoint_manager, _, tx = train_utils.create_training_tools(config, model, mesh)
 
-  unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(
-      model, tx, config, init_rng, mesh, is_training=True
-  )
+  unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True)
   # A barrier to sync all hosts before starting to restore checkpoint
   jax.experimental.multihost_utils.sync_global_devices("Barrier before load")
   checkpoint_load_start = datetime.datetime.now()
@@ -82,30 +78,24 @@ def checkpoint_loop(config, state=None):
   if state is not None:  # Checkpoint was available for restore
     if jax.process_index() == 0:
       max_logging.log(
-          "STANDALONE CHECKPOINTER : Checkpoint restored in :"
-          f" {checkpoint_load_end - checkpoint_load_start}"
+          "STANDALONE CHECKPOINTER : Checkpoint restored in :" f" {checkpoint_load_end - checkpoint_load_start}"
       )
   else:  # Checkpoint was unavailable, state needs to be initialized
-    state, _, _, _ = maxtext_utils.setup_training_state(
-        model, None, tx, config, init_rng, mesh, checkpoint_manager
-    )
+    state, _, _, _ = maxtext_utils.setup_training_state(model, None, tx, config, init_rng, mesh, checkpoint_manager)
     state = add_entropy_to_checkpoint(state)
 
   start_step = get_first_step(state)  # this is the start_step for training
   for step in np.arange(start_step, config.steps):
     if checkpoint_manager is not None:
       start_time = datetime.datetime.now()
       # A barrier to sync all hosts before starting to save checkpoint
-      jax.experimental.multihost_utils.sync_global_devices(
-          "Barrier before save"
-      )
+      jax.experimental.multihost_utils.sync_global_devices("Barrier before save")
       if checkpointing.save_checkpoint(checkpoint_manager, int(step), state):
         checkpoint_manager.wait_until_finished()
       end_time = datetime.datetime.now()
       if jax.process_index() == 0:
         max_logging.log(
-            "STANDALONE CHECKPOINTER : Checkpoint saved in"
-            f" {end_time - start_time} ,step {step}, on host 0"
+            "STANDALONE CHECKPOINTER : Checkpoint saved in" f" {end_time - start_time} ,step {step}, on host 0"
         )
 
   return state
@@ -123,12 +113,8 @@ def add_entropy_to_checkpoint(state):
     state: Returns state with entropy added to the optimizer state.
   """
   opt_0 = state.opt_state[0]
-  opt_0 = opt_0._replace(
-      mu=jax.tree_util.tree_map(lambda k: jnp.cos(1000 * k), state.params)
-  )
-  opt_0 = opt_0._replace(
-      nu=jax.tree_util.tree_map(lambda k: jnp.sin(1000 * k), state.params)
-  )
+  opt_0 = opt_0._replace(mu=jax.tree_util.tree_map(lambda k: jnp.cos(1000 * k), state.params))
+  opt_0 = opt_0._replace(nu=jax.tree_util.tree_map(lambda k: jnp.sin(1000 * k), state.params))
   new_opt = [opt_0] + list(state.opt_state[1:])
   state = state.replace(opt_state=new_opt)
   return state
```
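The `add_entropy_to_checkpoint` helper in the diff above fills the optimizer's `mu`/`nu` moments with deterministic cos/sin transforms of the parameters, so saved checkpoints carry non-trivial but reproducible data. A stdlib-only sketch of the same pattern (the `tree_map` here is a toy stand-in for `jax.tree_util.tree_map`, not the JAX API; the parameter tree is hypothetical):

```python
import math

def tree_map(fn, tree):
  """Toy tree_map over nested dicts/lists of floats."""
  if isinstance(tree, dict):
    return {k: tree_map(fn, v) for k, v in tree.items()}
  if isinstance(tree, list):
    return [tree_map(fn, v) for v in tree]
  return fn(tree)

# Derive "entropy" for the optimizer moments from the params themselves,
# mirroring the cos/sin transforms in add_entropy_to_checkpoint.
params = {"dense": {"kernel": [0.1, 0.2], "bias": [0.0]}}
mu = tree_map(lambda k: math.cos(1000 * k), params)
nu = tree_map(lambda k: math.sin(1000 * k), params)
```

Because the transforms are pure functions of the parameters, every host computes identical moments, which keeps the benchmark's checkpoint contents deterministic across restarts.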

tools/gcs_benchmarks/standalone_dataloader.py renamed to src/maxtext/utils/standalone_dataloader.py (1 addition & 1 deletion)

```diff
@@ -1,4 +1,4 @@
-# Copyright 2023–2025 Google LLC
+# Copyright 2023–2026 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
```

tests/integration/standalone_dl_ckpt_test.py (4 additions & 2 deletions)

```diff
@@ -15,8 +15,8 @@
 """ Tests for the standalone_checkpointer.py """
 import unittest
 import pytest
-from tools.gcs_benchmarks.standalone_checkpointer import main as sckpt_main
-from tools.gcs_benchmarks.standalone_dataloader import main as sdl_main
+from maxtext.utils.standalone_checkpointer import main as sckpt_main
+from maxtext.utils.standalone_dataloader import main as sdl_main
 from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT
 from maxtext.common.gcloud_stub import is_decoupled
 
@@ -50,6 +50,7 @@ def _get_random_test_name(self, test_name):
 
   @pytest.mark.integration_test
   @pytest.mark.tpu_only
+  @pytest.mark.scheduled_only
   def test_standalone_dataloader(self):
     random_run_name = self._get_random_test_name("standalone_dataloader")
     sdl_main(
@@ -68,6 +69,7 @@ def test_standalone_dataloader(self):
 
   @pytest.mark.integration_test
   @pytest.mark.tpu_only
+  @pytest.mark.scheduled_only
   def test_standalone_checkpointer(self):
     random_run_name = self._get_random_test_name("standalone_checkpointer")
     # checkpoint at 50
```
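The new `@pytest.mark.scheduled_only` marker is a custom one, so it would typically be registered in the project's pytest configuration to avoid unknown-mark warnings. A hypothetical `conftest.py` sketch (marker names come from the test above; the descriptions and registration location are assumptions, not part of this commit):

```python
# Hypothetical conftest.py fragment registering the custom markers used in
# standalone_dl_ckpt_test.py so pytest does not warn about unknown marks.
CUSTOM_MARKERS = {
    "integration_test": "long-running integration test",
    "tpu_only": "requires TPU hardware",
    "scheduled_only": "runs only in scheduled CI, skipped on presubmit",
}

def pytest_configure(config):
  # addinivalue_line("markers", ...) is pytest's documented registration hook.
  for name, description in CUSTOM_MARKERS.items():
    config.addinivalue_line("markers", f"{name}: {description}")
```

A scheduled CI job could then select or skip these tests with `pytest -m scheduled_only` or `pytest -m "not scheduled_only"`.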

tools/gcs_benchmarks/__init__.py (0 additions & 13 deletions)

This file was deleted.
