Commit 125dcfa

fix padding, remove extra mask.
1 parent 69a93b9 commit 125dcfa

1 file changed: src/maxdiffusion/models/attention_flax.py

Lines changed: 25 additions & 21 deletions
@@ -110,31 +110,40 @@ def _unflatten_heads(tensor, heads):
   return tensor


-def _reshape_data_for_flash(tensor, heads, flash_block_size):
+def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
   """
   Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
+  Pads seq_len to a multiple of flash_block_size, and ensures the resulting number of
+  blocks is divisible by the number of shards.
   """
   if tensor.ndim != 4:
     tensor = _unflatten_heads(tensor, heads)

-  # pad head_dim to 128 if less than that.
+  # Pad head_dim to 128 if less than that.
   kv_size = tensor.shape[-1]
   head_dim_pad = 0
   if kv_size < 128:
     head_dim_pad = 128 - kv_size

-  # pad seq_len to a multiple of flash_block_size if needed.
+  # Pad seq_len with sharding constraints.
   seq_len = tensor.shape[2]
-  # remainder
+
+  # 1. First, pad seq_len to be a multiple of flash_block_size
   rem = seq_len % flash_block_size
-  seq_len_pad = 0
   if rem != 0:
-    # multiplier
-    mul = seq_len // flash_block_size
-    # pad to the closest multiplier of flash_block_size
-    seq_len_pad = (mul + 1) * flash_block_size - seq_len
+    seq_len_padded_pre = seq_len + (flash_block_size - rem)
+  else:
+    seq_len_padded_pre = seq_len
+
+  # 2. Ensure num_blocks is divisible by num_shards
+  num_blocks = seq_len_padded_pre // flash_block_size
+  if num_blocks % num_shards != 0:
+    num_blocks += (num_shards - (num_blocks % num_shards))

-  if kv_size < 128 or rem != 0:
+  final_padded_len = num_blocks * flash_block_size
+  seq_len_pad = final_padded_len - seq_len
+
+  if kv_size < 128 or seq_len_pad != 0:
     npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad))
     tensor = jnp.pad(tensor, npad)

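For illustration, a minimal sketch of the new padding arithmetic above (not part of the commit; the values for seq_len, flash_block_size, and num_shards are made up):

import math  # not needed, shown pure-Python for clarity

def padded_seq_len(seq_len, flash_block_size, num_shards=1):
  # 1. Round seq_len up to a multiple of flash_block_size.
  rem = seq_len % flash_block_size
  padded = seq_len + (flash_block_size - rem) if rem != 0 else seq_len
  # 2. Round the block count up to a multiple of num_shards.
  num_blocks = padded // flash_block_size
  if num_blocks % num_shards != 0:
    num_blocks += num_shards - (num_blocks % num_shards)
  return num_blocks * flash_block_size

print(padded_seq_len(1000, 512))     # 1024 -> 2 blocks
print(padded_seq_len(1000, 512, 4))  # 2048 -> 4 blocks, divisible by 4 shards

Step 2 is what lets the padded sequence be split evenly across shards, which the updated call sites below rely on.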
@@ -153,7 +162,7 @@ def _tpu_flash_attention(
 ) -> jax.Array:
   """TPU Flash Attention"""

-  max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  max_block_size = 768#1024 if dtype == jnp.bfloat16 else 512
   if flash_block_sizes:
     block_sizes = flash_block_sizes
   else:
@@ -168,17 +177,17 @@ def _tpu_flash_attention(
       block_kv_dq=min(max_block_size, query.shape[2]),
   )

-  query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q)
-  key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute)
-  value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute)
-
+  num_fsdp_shards = mesh.shape["fsdp"]
+  query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q, num_fsdp_shards)
+  key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute, num_fsdp_shards)
+  value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards)
   axis_names = nn.logical_to_mesh_axes(flash_axis_names)
   kv_axis_names = nn.logical_to_mesh_axes((BATCH, HEAD, None, D_KV))
   flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH)
   axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel)
   named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel)

-  cp_size=8
+  cp_size=1

   @functools.partial(
       jax.jit,
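A minimal sketch, under assumed mesh axes, of where mesh.shape["fsdp"] comes from; the real mesh and axis layout are defined elsewhere in the repo:

import numpy as np
import jax
from jax.sharding import Mesh

# Hypothetical single-axis mesh over all local devices.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("fsdp",))
num_fsdp_shards = mesh.shape["fsdp"]  # Mesh.shape maps axis name -> axis size
print(num_fsdp_shards)                # e.g. the number of devices on the "fsdp" axis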
@@ -198,11 +207,6 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):

   shard_head_size = 1
   mask = splash_attention_mask.FullMask(_shape=(query.shape[2], query.shape[2]))
-  mask &= splash_attention_mask.LocalMask(
-      shape=(query.shape[2], key.shape[2]),
-      window_size=(query.shape[2], query.shape[2]),
-      offset=0
-  )
   multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
   splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size))
   segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)
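On the removed LocalMask: with window_size equal to the full query length and offset=0 it admits every (query, key) pair, so AND-ing it into a FullMask appears redundant. A small sketch of that equivalence under assumed toy sizes (seq_len=8, heads=2), not taken from the repo:

from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask

seq_len, heads = 8, 2
full = splash_attention_mask.FullMask(_shape=(seq_len, seq_len))
# A local window spanning the whole sequence keeps every (q, kv) position,
# so combining it with the full mask selects the same positions as `full` alone.
local = splash_attention_mask.LocalMask(
    shape=(seq_len, seq_len), window_size=(seq_len, seq_len), offset=0
)
combined = full & local
multi_head = splash_attention_mask.MultiHeadMask(masks=(combined,) * heads)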
