
Commit 28dbe57

Add remat policy. Remove sharding for shard_map splash kernel to lower memory footprint
1 parent 3f2a800 commit 28dbe57

5 files changed

Lines changed: 99 additions & 25 deletions


src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 8 additions & 0 deletions

@@ -182,6 +182,14 @@ transform_images_num_proc: 4
 reuse_example_batch: False
 enable_data_shuffling: True
 
+# Defines the type of gradient checkpointing (rematerialization) to enable.
+# NONE - no gradient checkpointing
+# FULL - full gradient checkpointing wherever possible (minimum memory usage)
+# MATMUL_WITHOUT_BATCH - checkpoint every linear/matmul operation except those
+# that involve the batch dimension; all attention and projection layers are
+# checkpointed, but not the backward pass with respect to the parameters
+remat_policy: "NONE"
+
 # checkpoint every number of samples, -1 means don't checkpoint.
 checkpoint_every: -1
 # enables one replica to read the ckpt then broadcast to the rest
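
For background, gradient checkpointing (rematerialization) drops selected intermediate activations in the forward pass and recomputes them during the backward pass, trading compute for memory. A minimal sketch outside maxdiffusion, using plain jax.checkpoint with the same JAX policy that MATMUL_WITHOUT_BATCH maps to (the toy function and shapes are illustrative assumptions, not part of the commit):

import jax
import jax.numpy as jnp

def toy_mlp(w1, w2, x):
  h = jnp.tanh(x @ w1)  # intermediate activation that remat may recompute
  return jnp.sum(jnp.tanh(h @ w2))

# The policy keeps (does not recompute) the outputs of dot operations with no
# batch dimensions and rematerializes everything else during the backward pass.
remat_mlp = jax.checkpoint(
    toy_mlp, policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
)

w1, w2, x = jnp.ones((8, 16)), jnp.ones((16, 4)), jnp.ones((2, 8))
grads = jax.grad(remat_mlp, argnums=(0, 1))(w1, w2, x)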

src/maxdiffusion/models/attention_flax.py

Lines changed: 14 additions & 24 deletions

@@ -187,44 +187,34 @@ def _tpu_flash_attention(
   value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards)
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
-  flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH)
+  flash_axis_names_splash_kernel: AxisNames = (HEAD, KV_LENGTH)
   axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel)
   named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel)
 
   shard_head_size = mesh.shape["tensor"]
 
-  @functools.partial(
-      jax.jit,
-      static_argnames=["multi_head_mask", "shard_head_size"],
-  )
-  def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
-    splash_kernel = splash_attention_kernel.make_splash_mha(
-        mask=multi_head_mask,
-        head_shards=shard_head_size,  # size of the axis sharded over heads
-        q_seq_shards=1,  # size of the axis sharded over seq_len
-        block_sizes=block_sizes,
-    )
-    return splash_kernel
-
-  mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-
-  multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
-  splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size))
-  segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)
-
   @functools.partial(
       shard_map.shard_map,
       mesh=mesh,
       in_specs=(
          q_axis_names,
          kv_axis_names,
          kv_axis_names,
-         segment_axis_names_splash_kernel,
      ),
      out_specs=q_axis_names,
      check_rep=False,
   )
-  def wrap_flash_attention(query, key, value, splash_kernel):
+  def wrap_flash_attention(query, key, value):
+    mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+    # make_splash_mha is called inside shard_map, where the head and sequence axes
+    # are already sharded according to in_specs, so head_shards=1 and q_seq_shards=1.
+    splash_kernel = splash_attention_kernel.make_splash_mha(
+        mask=multi_head_mask,
+        head_shards=1,  # size of the axis sharded over heads
+        q_seq_shards=1,  # size of the axis sharded over seq_len
+        block_sizes=block_sizes,
+    )
     attention_output = jax.vmap(splash_kernel)(query, key, value)
     return attention_output

@@ -236,7 +226,7 @@ def wrap_flash_attention(query, key, value, splash_kernel):
       "Warning, batch dimension should be shardable among the devices in data and fsdp"
       f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}"
   )
-  x = wrap_flash_attention(query, key, value, splash_kernel)
+  x = wrap_flash_attention(query, key, value)
   x = x[:, :, :query_seq_len, :kv_size]
   x = _reshape_heads_to_head_dim(x)
 
@@ -632,7 +622,7 @@ def __init__(
       use_memory_efficient_attention: bool = False,
       split_head_dim: bool = False,
       attention_kernel: str = "flash",
-      flash_min_seq_length: int = 4096,
+      flash_min_seq_length: int = 0,
       flash_block_sizes: BlockSizes = None,
       mesh: jax.sharding.Mesh = None,
      dtype: jnp.dtype = jnp.float32,
src/maxdiffusion/models/gradient_checkpoint.py

Lines changed: 70 additions & 0 deletions

@@ -0,0 +1,70 @@
+from enum import Enum, auto
+from typing import Optional
+
+import jax
+from flax import nnx
+
+SKIP_GRADIENT_CHECKPOINT_KEY = "skip"
+
+# This class only works with NNX modules.
+class GradientCheckpointType(Enum):
+  """
+  Defines the type of gradient checkpointing to use.
+
+  NONE - no gradient checkpointing
+  FULL - full gradient checkpointing wherever possible (minimum memory usage)
+  MATMUL_WITHOUT_BATCH - checkpoint every linear/matmul operation except those
+  that involve the batch dimension; all attention and projection layers are
+  checkpointed, but not the backward pass with respect to the parameters
+  """
+
+  NONE = auto()
+  FULL = auto()
+  MATMUL_WITHOUT_BATCH = auto()
+
+  @classmethod
+  def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType":
+    """
+    Constructs the gradient checkpoint type from a string.
+
+    Args:
+      s (Optional[str], optional): The name of the gradient checkpointing policy. Defaults to None.
+
+    Returns:
+      GradientCheckpointType: The policy that corresponds to the string.
+    """
+    if s is None:
+      s = "none"
+    return GradientCheckpointType[s.upper()]
+
+  def to_jax_policy(self):
+    """
+    Converts the gradient checkpoint type to a JAX rematerialization policy.
+    """
+    match self:
+      case GradientCheckpointType.NONE:
+        return SKIP_GRADIENT_CHECKPOINT_KEY
+      case GradientCheckpointType.FULL:
+        return None
+      case GradientCheckpointType.MATMUL_WITHOUT_BATCH:
+        return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
+
+  def apply(self, module: nnx.Module) -> nnx.Module:
+    """
+    Applies the gradient checkpointing policy to a module;
+    if no policy is needed, returns the module as is.
+
+    Args:
+      module (nnx.Module): the module to apply the policy to
+
+    Returns:
+      nnx.Module: the module with the policy applied
+    """
+    policy = self.to_jax_policy()
+    if policy == SKIP_GRADIENT_CHECKPOINT_KEY:
+      return module
+    return nnx.remat(  # pylint: disable=invalid-name
+        module,
+        prevent_cse=False,
+        policy=policy,
+    )
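
A hedged usage sketch of the new class (the toy block, loss function, and shapes are illustrative assumptions; the import path assumes the file lives under maxdiffusion/models/, and in this commit the policy is actually applied to the scan body of the WAN transformer, shown in the next file):

import jax.numpy as jnp
from flax import nnx
# from maxdiffusion.models.gradient_checkpoint import GradientCheckpointType  # assumed path

class ToyBlock(nnx.Module):  # stand-in for a real transformer block
  def __init__(self, dim: int, rngs: nnx.Rngs):
    self.fc1 = nnx.Linear(dim, 4 * dim, rngs=rngs)
    self.fc2 = nnx.Linear(4 * dim, dim, rngs=rngs)

  def __call__(self, x):
    return self.fc2(nnx.gelu(self.fc1(x)))

def loss_fn(block, x):
  return jnp.sum(block(x) ** 2)

block = ToyBlock(16, rngs=nnx.Rngs(0))
ckpt = GradientCheckpointType.from_str("MATMUL_WITHOUT_BATCH")

# apply() wraps the callable in nnx.remat under the selected policy;
# with remat_policy "NONE" it returns the callable unchanged.
rematted_loss = ckpt.apply(loss_fn)
grads = nnx.grad(rematted_loss)(block, jnp.ones((2, 16)))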

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 6 additions & 1 deletion

@@ -31,6 +31,7 @@
 )
 from ...normalization_flax import FP32LayerNorm
 from ...attention_flax import FlaxWanAttention
+from ...gradient_checkpoint import GradientCheckpointType
 
 BlockSizes = common_types.BlockSizes
 
@@ -356,6 +357,7 @@ def __init__(
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
       attention: str = "dot_product",
+      remat_policy: str = "None"
   ):
     inner_dim = num_attention_heads * attention_head_dim
     out_channels = out_channels or in_channels
@@ -417,6 +419,8 @@ def init_block(rngs):
           attention=attention,
       )
 
+    self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
+
     self.blocks = init_block(rngs)
 
     self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False)
@@ -469,8 +473,9 @@ def scan_fn(carry, block):
       return (hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
 
     initial_carry = (hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+    rematted_block_forward = self.gradient_checkpoint.apply(scan_fn)
     final_carry = nnx.scan(
-        scan_fn,
+        rematted_block_forward,
         length=self.num_layers,
         in_axes=(nnx.Carry, 0),
         out_axes=nnx.Carry,
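
The scan body is rematted before it is handed to nnx.scan, so each of the num_layers iterations recomputes its activations during the backward pass instead of keeping them all live. A toy sketch of that wiring, scanning over stacked per-layer weight matrices rather than the actual nnx block modules (all names and shapes are illustrative; GradientCheckpointType is the class added in gradient_checkpoint.py above):

import jax.numpy as jnp
from flax import nnx

num_layers, dim = 4, 8
stacked_w = jnp.stack([jnp.eye(dim)] * num_layers)  # one weight matrix per layer

def scan_fn(carry, w):
  # One layer of the loop; only the carry is threaded through, matching the
  # diff (out_axes=nnx.Carry means the body returns just the new carry).
  return jnp.tanh(carry @ w)

ckpt = GradientCheckpointType.from_str("FULL")
rematted_block_forward = ckpt.apply(scan_fn)  # nnx.remat around the scan body

final_carry = nnx.scan(
    rematted_block_forward,
    length=num_layers,
    in_axes=(nnx.Carry, 0),
    out_axes=nnx.Carry,
)(jnp.ones((2, dim)), stacked_w)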

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 1 addition & 0 deletions

@@ -78,6 +78,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
   wan_config["attention"] = config.attention
   wan_config["precision"] = get_precision(config)
   wan_config["flash_block_sizes"] = get_flash_block_sizes(config)
+  wan_config["remat_policy"] = config.remat_policy
 
   # 2. eval_shape - will not use flops or create weights on device
   # thus not using HBM memory.
