
Commit 42c3920 ("linting.")
Parent: 3b729af

6 files changed: 15 additions, 25 deletions

src/maxdiffusion/checkpointing/wan_checkpointer.py (0 additions, 1 deletion)

@@ -15,7 +15,6 @@
 """
 
 from abc import ABC
-from flax import nnx
 from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
 from ..pipelines.wan.wan_pipeline import WanPipeline
 from .. import max_logging, max_utils

src/maxdiffusion/models/attention_flax.py (1 addition, 1 deletion)

@@ -202,7 +202,7 @@ def _tpu_flash_attention(
   def wrap_flash_attention(query, key, value):
     mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
     multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
-    # make_splash_mha is wrapped around shardmap and seq and head is already
+    # make_splash_mha is wrapped around shardmap and seq and head is already
     # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1.
     splash_kernel = splash_attention_kernel.make_splash_mha(
         mask=multi_head_mask,

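A minimal, self-contained sketch (not part of this commit; the head count and sequence lengths are assumptions) of how the mask objects touched in this hunk are built before make_splash_mha is invoked with head_shards=1 and q_seq_shards=1, as the comment describes:

# Hedged sketch: values below are made up; the module paths are JAX's public
# Pallas TPU splash-attention modules that attention_flax.py already uses.
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel,
    splash_attention_mask,
)

num_heads, q_len, kv_len = 8, 1024, 1024  # assumed shapes

# Dense (all-positions-visible) mask over the query/key grid, replicated per head.
mask = splash_attention_mask.FullMask(_shape=(q_len, kv_len))
multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * num_heads)

# Head/sequence sharding is handled outside via shard_map in_specs,
# so the kernel is built with a single head shard and a single q-sequence shard.
splash_kernel = splash_attention_kernel.make_splash_mha(
    mask=multi_head_mask,
    head_shards=1,
    q_seq_shards=1,
)
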
src/maxdiffusion/models/gradient_checkpoint.py (2 additions, 1 deletion)

@@ -6,7 +6,8 @@
 
 SKIP_GRADIENT_CHECKPOINT_KEY = "skip"
 
-# This class only works with NNX modules.
+
+# This class only works with NNX modules.
 class GradientCheckpointType(Enum):
   """
   Defines the type of the gradient checkpoint we will have

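As a reference point only (a hypothetical sketch, not the contents of gradient_checkpoint.py; the member names and the mapping are assumptions), a remat-policy enum of this kind typically maps its members onto jax.checkpoint_policies so the NNX model can be rematerialized under the chosen policy:

# Hypothetical sketch, not the repository's implementation.
from enum import Enum

import jax

SKIP_GRADIENT_CHECKPOINT_KEY = "skip"


class GradientCheckpointType(Enum):
  NONE = "none"  # save everything; no recomputation in the backward pass
  FULL = "full"  # save nothing; recompute activations in the backward pass

  def to_jax_policy(self):
    if self is GradientCheckpointType.FULL:
      return jax.checkpoint_policies.nothing_saveable
    return jax.checkpoint_policies.everything_saveable
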
src/maxdiffusion/models/wan/transformers/transformer_wan.py (2 additions, 8 deletions)

@@ -364,7 +364,7 @@ def __init__(
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
       attention: str = "dot_product",
-      remat_policy: str = "None"
+      remat_policy: str = "None",
   ):
     inner_dim = num_attention_heads * attention_head_dim
     out_channels = out_channels or in_channels
@@ -383,13 +383,7 @@ def __init__(
         precision=precision,
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
-            (
-                None,
-                None,
-                None,
-                None,
-                "conv_out"
-            ),
+            (None, None, None, None, "conv_out"),
         ),
     )
 

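For readers unfamiliar with the partitioned initializer that the second hunk compacts: a minimal sketch (layer sizes are assumptions; only the kernel_init pattern mirrors the diff) of nnx.with_partitioning on a 3D convolution, whose kernel shape (kD, kH, kW, in_features, out_features) has five axes to match the five entries of the spec, with only the last axis annotated with the "conv_out" logical axis:

# Minimal sketch with assumed sizes; not the transformer's real configuration.
from flax import nnx

conv = nnx.Conv(
    in_features=16,
    out_features=32,
    kernel_size=(3, 3, 3),
    kernel_init=nnx.with_partitioning(
        nnx.initializers.xavier_uniform(),
        (None, None, None, None, "conv_out"),  # one entry per kernel axis
    ),
    rngs=nnx.Rngs(0),
)
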
src/maxdiffusion/multihost_dataloading.py (1 addition, 1 deletion)

@@ -114,4 +114,4 @@ def __iter__(self):
     return self
 
   def __next__(self):
-    return get_next_batch_sharded(self.local_iterator, self.global_mesh)
+    return get_next_batch_sharded(self.local_iterator, self.global_mesh)

src/maxdiffusion/trainers/wan_trainer.py (9 additions, 13 deletions)

@@ -37,15 +37,18 @@
 from skimage.metrics import structural_similarity as ssim
 from flax.training import train_state
 
+
 class TrainState(train_state.TrainState):
   graphdef: nnx.GraphDef
   rest_of_state: nnx.State
 
+
 def _to_array(x):
   if not isinstance(x, jax.Array):
     x = jnp.asarray(x)
   return x
 
+
 def generate_sample(config, pipeline, filename_prefix):
   """
   Generates a video to validate training did not corrupt the model
@@ -79,7 +82,6 @@ def __init__(self, config):
     if config.train_text_encoder:
       raise ValueError("this script currently doesn't support training text_encoders")
 
-    #self.global_batch_size = self.config.per_device_batch_size * jax.device_count()
     self.global_batch_size = config.per_device_batch_size * jax.device_count()
 
   def post_training_steps(self, pipeline, params, train_states, msg=""):
@@ -95,13 +97,10 @@ def create_scheduler(self):
   def calculate_tflops(self, pipeline):
     max_logging.log("WARNING : Calculting tflops is not implemented in Wan 2.1. Returning 0...")
     return 0
-
+
   def get_data_shardings(self, mesh):
     data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding))
-    data_sharding = {
-      "latents" : data_sharding,
-      "encoder_hidden_states" : data_sharding
-    }
+    data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding}
     return data_sharding
 
   def load_dataset(self, mesh):
@@ -167,11 +166,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
 
     with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       state = TrainState.create(
-          apply_fn=graphdef.apply,
-          params=params,
-          tx=optimizer,
-          graphdef=graphdef,
-          rest_of_state=rest_of_state
+          apply_fn=graphdef.apply, params=params, tx=optimizer, graphdef=graphdef, rest_of_state=rest_of_state
       )
       state = jax.tree.map(_to_array, state)
       state_spec = nnx.get_partition_spec(state)
@@ -196,8 +191,8 @@
 
     p_train_step = jax.jit(
         functools.partial(train_step, scheduler=pipeline.scheduler, config=self.config),
-        in_shardings = (state_shardings, data_shardings, None, None),
-        out_shardings = (state_shardings, None, None, None),
+        in_shardings=(state_shardings, data_shardings, None, None),
+        out_shardings=(state_shardings, None, None, None),
         donate_argnums=(0,),
     )
     rng = jax.random.key(self.config.seed)
@@ -284,6 +279,7 @@ def loss_fn(params):
     loss = jnp.mean(loss)
 
     return loss
+
   grad_fn = nnx.value_and_grad(loss_fn)
   loss, grads = grad_fn(state.params)
   new_state = state.apply_gradients(grads=grads)

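To make the sharding hunks above easier to follow, here is a hedged, self-contained sketch (toy mesh, toy batch keys, and a stand-in loss; none of it is the trainer's real code) of the pattern being reformatted: one NamedSharding per batch key, and a jitted step pinned to explicit in_shardings/out_shardings.

# Hedged sketch: the mesh axis name, batch keys, shapes, and loss are assumptions.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("data",))
data_sharding = NamedSharding(mesh, P("data"))
data_shardings = {"latents": data_sharding, "encoder_hidden_states": data_sharding}

def train_step(batch):
  # Stand-in for the real loss/gradient/update computation.
  return jnp.mean(batch["latents"]) + jnp.mean(batch["encoder_hidden_states"])

p_train_step = jax.jit(
    train_step,
    in_shardings=(data_shardings,),  # one sharding pytree per positional argument
    out_shardings=None,              # scalar loss; let JAX choose its placement
)

batch = {
    "latents": jnp.ones((8, 4)),
    "encoder_hidden_states": jnp.ones((8, 4)),
}
loss = p_train_step(batch)

In the real trainer, donate_argnums=(0,) additionally lets XLA reuse the donated train-state buffers from one step to the next.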