Skip to content

Commit a769016

Browse files
committed
accept gcs path for data.
1 parent 5ad4ef4 commit a769016

1 file changed

Lines changed: 5 additions & 2 deletions

File tree

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,16 @@ def make_tfrecord_iterator(
117117
check out preparation script
118118
maxdiffusion/pedagogical_examples/to_tfrecords.py
119119
"""
120-
121120
# set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset.
122121
# pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord.
123122
# Dataset cache in github runner test doesn't contain all the features since it's shared, so use the default tfrecord iterator.
123+
124+
# Checks that the dataset path is valid. In case of GCS, the existence of the dir is not checked.
125+
is_dataset_dir_valid = "gs://" in config.dataset_save_location or os.path.isdir(config.dataset_save_location)
126+
124127
if (
125128
config.cache_latents_text_encoder_outputs
126-
and os.path.isdir(config.dataset_save_location)
129+
and is_dataset_dir_valid
127130
and "load_tfrecord_cached" in config.get_keys()
128131
and config.load_tfrecord_cached
129132
):

0 commit comments

Comments
 (0)