Skip to content

Commit a4b4b90

Browse files
Merge pull request #3571 from ROCm:fix-decoupled-ut-behavior
PiperOrigin-RevId: 895960198
2 parents 1963688 + 56348c5 commit a4b4b90

4 files changed

Lines changed: 23 additions & 8 deletions

File tree

src/maxtext/common/gcloud_stub.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,9 @@ def get_bucket(self, *a, **k): # pylint: disable=unused-argument
309309
def bucket(self, *a, **k): # pylint: disable=unused-argument
310310
return _StubBucket()
311311

312+
def list_blobs(self, *a, **k): # pylint: disable=unused-argument
313+
return iter([])
314+
312315
return SimpleNamespace(Client=_StubClient, _IS_STUB=True)
313316

314317

tests/integration/smoke/train_tokenizer_smoke_test.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
import unittest
1919
import pytest
2020

21+
from maxtext.common.gcloud_stub import is_decoupled
2122
from maxtext.input_pipeline import input_pipeline_utils
2223
from maxtext.trainers.tokenizer import train_tokenizer
24+
from tests.utils.test_helpers import get_test_dataset_path
2325

2426

2527
class TrainTokenizerFormatTest(unittest.TestCase):
@@ -49,17 +51,26 @@ def _run_format_test(self, file_pattern, file_type):
4951

5052
@pytest.mark.cpu_only
5153
def test_parquet(self):
52-
self._run_format_test("gs://maxtext-dataset/hf/c4/c4-train-00000-of-01637.parquet", "parquet")
54+
path = os.path.join(get_test_dataset_path(), "hf", "c4", "c4-train-00000-of-01637.parquet")
55+
self._run_format_test(path, "parquet")
5356

5457
@pytest.mark.cpu_only
5558
def test_arrayrecord(self):
56-
self._run_format_test(
57-
"gs://maxtext-dataset/array-record/c4/en/3.0.1/c4-train.array_record-00000-of-01024", "arrayrecord"
58-
)
59+
dataset_root = get_test_dataset_path()
60+
if is_decoupled():
61+
path = os.path.join(dataset_root, "c4", "en", "3.0.1", "c4-train.array_record-00000-of-00008")
62+
else:
63+
path = os.path.join(dataset_root, "array-record", "c4", "en", "3.0.1", "c4-train.array_record-00000-of-01024")
64+
self._run_format_test(path, "arrayrecord")
5965

6066
@pytest.mark.cpu_only
6167
def test_tfrecord(self):
62-
self._run_format_test("gs://maxtext-dataset/c4/en/3.0.1/c4-train.tfrecord-00000-of-01024", "tfrecord")
68+
dataset_root = get_test_dataset_path()
69+
if is_decoupled():
70+
path = os.path.join(dataset_root, "c4", "en", "3.0.1", "__local_c4_builder-train.tfrecord-00000-of-00008")
71+
else:
72+
path = os.path.join(dataset_root, "c4", "en", "3.0.1", "c4-train.tfrecord-00000-of-01024")
73+
self._run_format_test(path, "tfrecord")
6374

6475

6576
if __name__ == "__main__":

tests/unit/diloco_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ def loss_fn(params, batch):
269269
chex.assert_trees_all_equal(diloco_test_state.params, step_three_outer_params)
270270

271271
@pytest.mark.cpu_only
272+
@pytest.mark.tpu_backend
272273
def test_diloco_qwen3_moe_two_slices(self):
273274
temp_dir = gettempdir()
274275
compiled_trainstep_file = os.path.join(temp_dir, "test_compiled_diloco_qwen3_moe.pickle")

tests/unit/grain_data_processing_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,9 @@ def setUp(self):
103103
grain_train_files = os.path.join(
104104
dataset_root,
105105
"c4",
106-
"array-record",
107106
"en",
108107
"3.0.1",
109-
"c4-train.array_record-00000-of-01024",
108+
"c4-train.array_record-00000-of-00008",
110109
)
111110
base_output_directory = get_test_base_output_directory()
112111
else:
@@ -384,7 +383,7 @@ def setUp(self):
384383
"c4",
385384
"en",
386385
"3.0.1",
387-
"c4-train.tfrecord-00000-of-01024",
386+
"__local_c4_builder-train.tfrecord-00000-of-00008",
388387
)
389388
base_output_directory = get_test_base_output_directory()
390389
else:
@@ -427,6 +426,7 @@ def setUp(self):
427426
self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices)
428427

429428

429+
@pytest.mark.external_training
430430
class GrainSFTParquetProcessingTest(unittest.TestCase):
431431
"""Tests the SFT pipeline end-to-end using the real ultrachat_200k parquet dataset."""
432432

0 commit comments

Comments
 (0)