|
42 | 42 | import jax |
43 | 43 | import grain.python as grain |
44 | 44 |
|
| 45 | +from maxtext.input_pipeline import input_pipeline_utils |
45 | 46 | from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT |
46 | 47 | from maxtext.utils import gcs_utils |
47 | 48 |
|
|
50 | 51 | "grain_train_files", None, "File pattern for training data (local or gs://)", required=True |
51 | 52 | ) |
52 | 53 | _GRAIN_FILE_TYPE = flags.DEFINE_string( |
53 | | - "grain_file_type", "parquet", "Type of data files. Supported: 'parquet', 'arrayrecord'." |
| 54 | + "grain_file_type", "parquet", "Type of data files. Supported: 'parquet', 'arrayrecord', 'tfrecord'." |
54 | 55 | ) |
55 | 56 | _DATA_COLUMN = flags.DEFINE_string("data_column", "text", "Column name to extract text from (used for arrayrecord).") |
56 | 57 | _VOCAB_SIZE = flags.DEFINE_integer("vocab_size", 32_768, "Vocab size") |
@@ -82,27 +83,23 @@ def build_grain_iterator(data_file_pattern: str, data_file_type: str, data_keys: |
82 | 83 | dataset = grain.MapDataset.source(data_files) |
83 | 84 | dataset = dataset.map(grain.experimental.ParquetIterDataset) |
84 | 85 | dataset = grain.experimental.InterleaveIterDataset(dataset, cycle_length=len(data_files)) |
| 86 | + dataset = dataset.map(input_pipeline_utils.KeepFeatures(feature_names=list(data_keys))) |
85 | 87 | return iter(dataset) |
86 | 88 | elif data_file_type == "arrayrecord": |
87 | | - from maxtext.input_pipeline.protos import example_pb2 # pylint: disable=import-outside-toplevel |
88 | | - |
89 | 89 | source = grain.ArrayRecordDataSource(data_files) |
90 | 90 | dataset = grain.MapDataset.source(source) |
91 | | - |
92 | | - def _parse_example(raw_bytes): |
93 | | - example = example_pb2.Example() |
94 | | - example.ParseFromString(raw_bytes) |
95 | | - features = example.features.feature |
96 | | - parsed = {} |
97 | | - for col in data_keys: |
98 | | - if col in features: |
99 | | - parsed[col] = features[col].bytes_list.value[0] |
100 | | - return parsed |
101 | | - |
102 | | - dataset = dataset.map(_parse_example) |
| 91 | + dataset = dataset.map(input_pipeline_utils.ParseFeatures(list(data_keys), tokenize=True)) |
| 92 | + dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(list(data_keys), tokenize=True)) |
| 93 | + return iter(dataset) |
| 94 | + elif data_file_type == "tfrecord": |
| 95 | + dataset = grain.MapDataset.source(data_files) |
| 96 | + dataset = dataset.map(input_pipeline_utils.make_tfrecord_iter_dataset) |
| 97 | + dataset = grain.experimental.InterleaveIterDataset(dataset, cycle_length=len(data_files)) |
| 98 | + dataset = dataset.map(input_pipeline_utils.ParseFeatures(list(data_keys), tokenize=True)) |
| 99 | + dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(list(data_keys), tokenize=True)) |
103 | 100 | return iter(dataset) |
104 | 101 | else: |
105 | | - raise ValueError(f"Unsupported grain_file_type: {data_file_type!r}. Use 'parquet' or 'arrayrecord'.") |
| 102 | + raise ValueError(f"Unsupported grain_file_type: {data_file_type!r}. Use 'parquet', 'arrayrecord', or 'tfrecord'.") |
106 | 103 |
|
107 | 104 |
|
108 | 105 | def _dump_chars_to_textfile(dataset_iter: Iterator, maxchars: int = int(1e7), data_keys=("text",)) -> tuple[str, int]: |
|
0 commit comments