
Commit 30dd96e

Support Grain with Tfrecord in train_tokenizer
1 parent 5478bad commit 30dd96e

2 files changed: 79 additions & 16 deletions


src/maxtext/trainers/tokenizer/train_tokenizer.py

Lines changed: 13 additions & 16 deletions
@@ -42,6 +42,7 @@
 import jax
 import grain.python as grain
 
+from maxtext.input_pipeline import input_pipeline_utils
 from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT
 from maxtext.utils import gcs_utils
 
@@ -50,7 +51,7 @@
     "grain_train_files", None, "File pattern for training data (local or gs://)", required=True
 )
 _GRAIN_FILE_TYPE = flags.DEFINE_string(
-    "grain_file_type", "parquet", "Type of data files. Supported: 'parquet', 'arrayrecord'."
+    "grain_file_type", "parquet", "Type of data files. Supported: 'parquet', 'arrayrecord', 'tfrecord'."
 )
 _DATA_COLUMN = flags.DEFINE_string("data_column", "text", "Column name to extract text from (used for arrayrecord).")
 _VOCAB_SIZE = flags.DEFINE_integer("vocab_size", 32_768, "Vocab size")
@@ -82,27 +83,23 @@ def build_grain_iterator(data_file_pattern: str, data_file_type: str, data_keys:
     dataset = grain.MapDataset.source(data_files)
     dataset = dataset.map(grain.experimental.ParquetIterDataset)
     dataset = grain.experimental.InterleaveIterDataset(dataset, cycle_length=len(data_files))
+    dataset = dataset.map(input_pipeline_utils.KeepFeatures(feature_names=list(data_keys)))
     return iter(dataset)
   elif data_file_type == "arrayrecord":
-    from maxtext.input_pipeline.protos import example_pb2  # pylint: disable=import-outside-toplevel
-
     source = grain.ArrayRecordDataSource(data_files)
     dataset = grain.MapDataset.source(source)
-
-    def _parse_example(raw_bytes):
-      example = example_pb2.Example()
-      example.ParseFromString(raw_bytes)
-      features = example.features.feature
-      parsed = {}
-      for col in data_keys:
-        if col in features:
-          parsed[col] = features[col].bytes_list.value[0]
-      return parsed
-
-    dataset = dataset.map(_parse_example)
+    dataset = dataset.map(input_pipeline_utils.ParseFeatures(list(data_keys), tokenize=True))
+    dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(list(data_keys), tokenize=True))
+    return iter(dataset)
+  elif data_file_type == "tfrecord":
+    dataset = grain.MapDataset.source(data_files)
+    dataset = dataset.map(input_pipeline_utils.make_tfrecord_iter_dataset)
+    dataset = grain.experimental.InterleaveIterDataset(dataset, cycle_length=len(data_files))
+    dataset = dataset.map(input_pipeline_utils.ParseFeatures(list(data_keys), tokenize=True))
+    dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(list(data_keys), tokenize=True))
     return iter(dataset)
   else:
-    raise ValueError(f"Unsupported grain_file_type: {data_file_type!r}. Use 'parquet' or 'arrayrecord'.")
+    raise ValueError(f"Unsupported grain_file_type: {data_file_type!r}. Use 'parquet', 'arrayrecord', or 'tfrecord'.")
 
 
 def _dump_chars_to_textfile(dataset_iter: Iterator, maxchars: int = int(1e7), data_keys=("text",)) -> tuple[str, int]:
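
Note on the refactor above: the arrayrecord branch previously decoded serialized tf.train.Example records with an inline _parse_example helper; this commit swaps that for the shared input_pipeline_utils.ParseFeatures and NormalizeFeatures transforms, which the new tfrecord branch reuses. For reference, a standalone sketch of what the removed helper did (it assumes byte-valued features, as the original code did):

from maxtext.input_pipeline.protos import example_pb2


def parse_example(raw_bytes: bytes, data_keys=("text",)) -> dict:
  """Decode a serialized tf.train.Example, keeping the requested byte columns."""
  example = example_pb2.Example()
  example.ParseFromString(raw_bytes)
  features = example.features.feature
  # Take the first bytes value of each requested column that is present.
  return {col: features[col].bytes_list.value[0] for col in data_keys if col in features}

ParseFeatures presumably generalizes this logic, adding the tokenize-aware handling seen in the diff.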
Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+# Copyright 2023–2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Smoke tests for train_tokenizer file format support."""
+
+import os
+import unittest
+import pytest
+
+from maxtext.input_pipeline import input_pipeline_utils
+from maxtext.trainers.tokenizer import train_tokenizer
+
+
+class TrainTokenizerFormatTest(unittest.TestCase):
+  """Smoke-tests that train_tokenizer runs end-to-end for each supported file format."""
+
+  def _run_format_test(self, file_pattern, file_type):
+    """Uses a tiny corpus; the resulting tokenizer is not stored — only verify
+    it can be loaded and used for encode/decode.
+    """
+    output_path = os.path.join("tests", f"test_tokenizer_{file_type}")
+    try:
+      dataset_iter = train_tokenizer.build_grain_iterator(file_pattern, file_type)
+      train_tokenizer.train_tokenizer(
+          dataset_iter,
+          vocab_path=output_path,
+          vocab_size=512,
+          max_corpus_chars=10_000,
+      )
+      tok = input_pipeline_utils.get_tokenizer(output_path, "sentencepiece", add_bos=False, add_eos=False)
+      text = "This is a test"
+      tokens = tok.encode(text)
+      self.assertGreater(len(tokens), 0)
+      self.assertEqual(tok.decode(tokens), text)
+    finally:
+      if os.path.exists(output_path):
+        os.remove(output_path)
+
+  @pytest.mark.cpu_only
+  def test_parquet(self):
+    self._run_format_test("gs://maxtext-dataset/hf/c4/c4-train-00000-of-01637.parquet", "parquet")
+
+  @pytest.mark.cpu_only
+  def test_arrayrecord(self):
+    self._run_format_test(
+        "gs://maxtext-dataset/array-record/c4/en/3.0.1/c4-train.array_record-00000-of-01024", "arrayrecord"
+    )
+
+  @pytest.mark.cpu_only
+  def test_tfrecord(self):
+    self._run_format_test("gs://maxtext-dataset/c4/en/3.0.1/c4-train.tfrecord-00000-of-01024", "tfrecord")
+
+
+if __name__ == "__main__":
+  unittest.main()
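
A minimal end-to-end sketch of the new tfrecord path, mirroring _run_format_test above; the output path is illustrative, and the gs:// shard is the one used by test_tfrecord:

from maxtext.input_pipeline import input_pipeline_utils
from maxtext.trainers.tokenizer import train_tokenizer

# Stream one C4 TFRecord shard through Grain and train a tiny vocabulary.
dataset_iter = train_tokenizer.build_grain_iterator(
    "gs://maxtext-dataset/c4/en/3.0.1/c4-train.tfrecord-00000-of-01024", "tfrecord"
)
train_tokenizer.train_tokenizer(
    dataset_iter,
    vocab_path="/tmp/test_tokenizer_tfrecord",  # illustrative output location
    vocab_size=512,
    max_corpus_chars=10_000,
)

# Sanity-check that the trained model round-trips a short string.
tok = input_pipeline_utils.get_tokenizer("/tmp/test_tokenizer_tfrecord", "sentencepiece", add_bos=False, add_eos=False)
assert tok.decode(tok.encode("This is a test")) == "This is a test"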
