|
18 | 18 | import unittest |
19 | 19 | import pytest |
20 | 20 |
|
| 21 | +from maxtext.common.gcloud_stub import is_decoupled |
21 | 22 | from maxtext.input_pipeline import input_pipeline_utils |
22 | 23 | from maxtext.trainers.tokenizer import train_tokenizer |
| 24 | +from tests.utils.test_helpers import get_test_dataset_path |
23 | 25 |
|
24 | 26 |
|
25 | 27 | class TrainTokenizerFormatTest(unittest.TestCase): |
@@ -49,17 +51,26 @@ def _run_format_test(self, file_pattern, file_type): |
49 | 51 |
|
50 | 52 | @pytest.mark.cpu_only |
51 | 53 | def test_parquet(self): |
52 | | - self._run_format_test("gs://maxtext-dataset/hf/c4/c4-train-00000-of-01637.parquet", "parquet") |
| 54 | + path = os.path.join(get_test_dataset_path(), "hf", "c4", "c4-train-00000-of-01637.parquet") |
| 55 | + self._run_format_test(path, "parquet") |
53 | 56 |
|
54 | 57 | @pytest.mark.cpu_only |
55 | 58 | def test_arrayrecord(self): |
56 | | - self._run_format_test( |
57 | | - "gs://maxtext-dataset/array-record/c4/en/3.0.1/c4-train.array_record-00000-of-01024", "arrayrecord" |
58 | | - ) |
| 59 | + dataset_root = get_test_dataset_path() |
| 60 | + if is_decoupled(): |
| 61 | + path = os.path.join(dataset_root, "c4", "en", "3.0.1", "c4-train.array_record-00000-of-00008") |
| 62 | + else: |
| 63 | + path = os.path.join(dataset_root, "array-record", "c4", "en", "3.0.1", "c4-train.array_record-00000-of-01024") |
| 64 | + self._run_format_test(path, "arrayrecord") |
59 | 65 |
|
60 | 66 | @pytest.mark.cpu_only |
61 | 67 | def test_tfrecord(self): |
62 | | - self._run_format_test("gs://maxtext-dataset/c4/en/3.0.1/c4-train.tfrecord-00000-of-01024", "tfrecord") |
| 68 | + dataset_root = get_test_dataset_path() |
| 69 | + if is_decoupled(): |
| 70 | + path = os.path.join(dataset_root, "c4", "en", "3.0.1", "__local_c4_builder-train.tfrecord-00000-of-00008") |
| 71 | + else: |
| 72 | + path = os.path.join(dataset_root, "c4", "en", "3.0.1", "c4-train.tfrecord-00000-of-01024") |
| 73 | + self._run_format_test(path, "tfrecord") |
63 | 74 |
|
64 | 75 |
|
65 | 76 | if __name__ == "__main__": |
|
0 commit comments