Skip to content

Commit 340d7c4

Browse files
committed
set data sharding correctly for gbs < 1
1 parent d3b50c8 commit 340d7c4

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -97,7 +97,7 @@ def calculate_tflops(self, pipeline):
97 97
return 0
98 98

99 99
def get_data_shardings(self, mesh):
100-
data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding))
100+
data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding[0]))
101 101
data_sharding = {
102 102
"latents" : data_sharding,
103 103
"encoder_hidden_states" : data_sharding
@@ -146,7 +146,7 @@ def start_training(self):
146 146
# del pipeline.vae
147 147

148 148
# Generate a sample before training to compare against generated sample after training.
149-
pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
149+
#pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
150 150
mesh = pipeline.mesh
151 151
data_iterator = self.load_dataset(mesh)
152 152

0 commit comments

Comments
 (0)