File tree Expand file tree Collapse file tree
src/maxdiffusion/trainers Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -97,7 +97,7 @@ def calculate_tflops(self, pipeline):
9797 return 0
9898
9999 def get_data_shardings (self , mesh ):
100- data_sharding = jax .sharding .NamedSharding (mesh , P (* self .config .data_sharding ))
100+ data_sharding = jax .sharding .NamedSharding (mesh , P (* self .config .data_sharding [ 0 ] ))
101101 data_sharding = {
102102 "latents" : data_sharding ,
103103 "encoder_hidden_states" : data_sharding
@@ -146,7 +146,7 @@ def start_training(self):
146146 # del pipeline.vae
147147
148148 # Generate a sample before training to compare against generated sample after training.
149- pretrained_video_path = generate_sample (self .config , pipeline , filename_prefix = "pre-training-" )
149+ # pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
150150 mesh = pipeline .mesh
151151 data_iterator = self .load_dataset (mesh )
152152
You can’t perform that action at this time.
0 commit comments