|
19 | 19 | import tensorflow.experimental.numpy as tnp |
20 | 20 | from datasets import load_dataset, load_from_disk |
21 | 21 | import jax |
22 | | -from maxdiffusion import multihost_dataloading |
| 22 | +from maxdiffusion import multihost_dataloading, max_logging |
23 | 23 |
|
24 | 24 | AUTOTUNE = tf.data.AUTOTUNE |
25 | 25 |
|
@@ -78,92 +78,91 @@ def make_tf_iterator( |
78 | 78 | train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh) |
79 | 79 | return train_iter |
80 | 80 |
|
81 | | - |
82 | | -def make_cached_tfrecord_iterator( |
83 | | - config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn |
84 | | -): |
85 | | - """ |
86 | | - New iterator for TFRecords that contain the full 4 pre-computed latents and embeddings: |
87 | | - latents, input_ids, prompt_embeds, and text_embeds. |
88 | | - """ |
89 | | - |
90 | | - def _parse_tfrecord_fn(example): |
91 | | - return tf.io.parse_single_example(example, feature_description) |
92 | | - |
93 | | - # This pipeline reads the sharded files and applies the parsing and preparation. |
94 | | - filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*")) |
95 | | - |
96 | | - train_ds = ( |
97 | | - tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE) |
98 | | - .shard(num_shards=dataloading_host_count, index=dataloading_host_index) |
99 | | - .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE) |
100 | | - .map(prepare_sample_fn, num_parallel_calls=AUTOTUNE) |
101 | | - .shuffle(global_batch_size * 10) |
102 | | - .batch(global_batch_size // dataloading_host_count, drop_remainder=True) |
103 | | - .repeat(-1) |
104 | | - .prefetch(AUTOTUNE) |
105 | | - ) |
106 | | - |
107 | | - # This wraps the tf.data.Dataset for use in the multi-host JAX environment. |
108 | | - train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh) |
109 | | - return train_iter |
110 | | - |
111 | | - |
112 | 81 | # TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py |
def _make_tfrecord_iterator(
    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description_fn, prepare_sample_fn, dataset_path, is_training: bool
):
  """Builds a multi-host iterator over a sharded TFRecord dataset.

  Handles both the "cached" layout (fully pre-computed latents/embeddings,
  parsed with the externally supplied `feature_description_fn` /
  `prepare_sample_fn`) and the default layout (moments + clip embeddings,
  parsed with the internal logic below).

  Args:
    config: run configuration; must expose `dataset_save_location`,
      `cache_latents_text_encoder_outputs`, `get_keys()` and (optionally)
      `load_tfrecord_cached`.
    dataloading_host_index: index of this host among the dataloading hosts.
    dataloading_host_count: total number of dataloading hosts.
    mesh: JAX device mesh for `MultiHostDataLoadIterator`.
    global_batch_size: batch size summed across all hosts.
    feature_description_fn: feature spec for `tf.io.parse_single_example`
      when the cached layout is used. NOTE(review): despite the `_fn` suffix
      this is used as a feature-description dict, not a callable — name kept
      for interface compatibility.
    prepare_sample_fn: per-sample transform used for the cached layout.
    dataset_path: directory (local or gs://) containing the TFRecord shards.
    is_training: True loads the (shuffled, repeated) training pipeline;
      False loads the evaluation pipeline, padded so every host sees full
      batches.

  Returns:
    A `multihost_dataloading.MultiHostDataLoadIterator` over the dataset.
  """
  # set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset.
  # pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord.
  # Dataset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator.

  # Checks that the dataset path is valid. In case of gcs, the existence of the dir is not checked.
  is_dataset_dir_valid = "gs://" in config.dataset_save_location or os.path.isdir(config.dataset_save_location)

  # Determine whether to use the "cached" dataset, which requires externally
  # provided parsing functions, or the default one with its internal parsing logic.
  use_cached_tfrecord_parsing = (
      config.cache_latents_text_encoder_outputs
      and is_dataset_dir_valid
      and "load_tfrecord_cached" in config.get_keys()
      and config.load_tfrecord_cached
  )

  # Default (non-cached) layout: serialized latent moments + CLIP embeddings.
  feature_description = {
      "moments": tf.io.FixedLenFeature([], tf.string),
      "clip_embeddings": tf.io.FixedLenFeature([], tf.string),
  }

  used_feature_description = feature_description_fn if use_cached_tfrecord_parsing else feature_description

  def _parse_tfrecord_fn(example):
    # Parse a single serialized Example with whichever spec applies.
    return tf.io.parse_single_example(example, used_feature_description)

  def prepare_sample(features):
    # Default-layout decode: both fields are serialized float32 tensors.
    moments = tf.io.parse_tensor(tnp.asarray(features["moments"]), out_type=tf.float32)
    clip_embeddings = tf.io.parse_tensor(tnp.asarray(features["clip_embeddings"]), out_type=tf.float32)
    return {"pixel_values": moments, "input_ids": clip_embeddings}

  filenames = tf.io.gfile.glob(os.path.join(dataset_path, "*"))
  ds = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)

  # --- PADDING LOGIC FOR EVALUATION ---
  if not is_training:
    # TFRecord cardinality is unknown, so count with a full eager pass.
    # NOTE(review): this reads the whole eval dataset once up front — acceptable
    # for small eval sets, but worth revisiting for large ones.
    num_eval_samples = sum(1 for _ in ds)

    remainder = num_eval_samples % global_batch_size
    if remainder != 0:
      num_to_pad = global_batch_size - remainder
      # Reuse samples from the beginning as padding so the total is a
      # multiple of global_batch_size.
      # NOTE(review): if the eval set is smaller than one global batch,
      # take() cannot supply enough padding and the final batch stays
      # partial — confirm eval sets are at least one global batch.
      padding_ds = ds.take(num_to_pad)
      ds = ds.concatenate(padding_ds)
      max_logging.log(f"Padded evaluation dataset with {num_to_pad} samples.")

  used_prepare_sample = prepare_sample_fn if use_cached_tfrecord_parsing else prepare_sample
  ds = (
      ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
      .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
      .map(used_prepare_sample, num_parallel_calls=AUTOTUNE)
  )
  if is_training:
    # Training: shuffle, drop ragged batches, repeat forever.
    ds = (
        ds.shuffle(global_batch_size * 10)
        .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
        .repeat(-1)
        .prefetch(AUTOTUNE)
    )
  else:
    # Evaluation: single epoch, keep every sample (padding above ensures
    # full batches when possible).
    ds = (
        ds.batch(global_batch_size // dataloading_host_count, drop_remainder=False)
        .prefetch(AUTOTUNE)
    )

  # Renamed from `iter` to avoid shadowing the builtin.
  data_iter = multihost_dataloading.MultiHostDataLoadIterator(ds, mesh)
  return data_iter
| 157 | + |
def make_tfrecord_iterator(
    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, is_training=True
):
  """Iterator for TFRecord format. For Laion dataset,
  check out preparation script
  maxdiffusion/pedagogical_examples/to_tfrecords.py

  `is_training` defaults to True so existing callers of the original
  7-argument signature keep loading the training dataset unchanged.
  """
  # Evaluation is currently only supported for the TFRecord format, so the
  # training/eval split is decided here rather than in a shared code path.
  # TODO: refactor to support evaluation on all dataset formats.
  dataset_path = config.train_data_dir if is_training else config.eval_data_dir
  return _make_tfrecord_iterator(
      config,
      dataloading_host_index,
      dataloading_host_count,
      mesh,
      global_batch_size,
      feature_description,
      prepare_sample_fn,
      dataset_path,
      is_training,
  )
0 commit comments