@@ -110,31 +110,40 @@ def _unflatten_heads(tensor, heads):
   return tensor


-def _reshape_data_for_flash(tensor, heads, flash_block_size):
+def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
   """
   Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
+  Pads seq_len to a multiple of flash_block_size, and ensures the resulting number of
+  blocks is divisible by the number of shards.
   """
   if tensor.ndim != 4:
     tensor = _unflatten_heads(tensor, heads)

-  # pad head_dim to 128 if less than that.
+  # Pad head_dim to 128 if less than that.
   kv_size = tensor.shape[-1]
   head_dim_pad = 0
   if kv_size < 128:
     head_dim_pad = 128 - kv_size

-  # pad seq_len to a multiple of flash_block_size if needed.
+  # Pad seq_len with sharding constraints.
   seq_len = tensor.shape[2]
-  # remainder
+
+  # 1. First, pad seq_len to be a multiple of flash_block_size
   rem = seq_len % flash_block_size
-  seq_len_pad = 0
   if rem != 0:
-    # multiplier
-    mul = seq_len // flash_block_size
-    # pad to the closest multiplier of flash_block_size
-    seq_len_pad = (mul + 1) * flash_block_size - seq_len
+    seq_len_padded_pre = seq_len + (flash_block_size - rem)
+  else:
+    seq_len_padded_pre = seq_len
+
+  # 2. Ensure num_blocks is divisible by num_shards
+  num_blocks = seq_len_padded_pre // flash_block_size
+  if num_blocks % num_shards != 0:
+    num_blocks += (num_shards - (num_blocks % num_shards))

-  if kv_size < 128 or rem != 0:
+  final_padded_len = num_blocks * flash_block_size
+  seq_len_pad = final_padded_len - seq_len
+
+  if kv_size < 128 or seq_len_pad != 0:
     npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad))
     tensor = jnp.pad(tensor, npad)
@@ -153,7 +162,7 @@ def _tpu_flash_attention(
 ) -> jax.Array:
   """TPU Flash Attention"""

-  max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  max_block_size = 768  # 1024 if dtype == jnp.bfloat16 else 512
   if flash_block_sizes:
     block_sizes = flash_block_sizes
   else:
@@ -168,17 +177,17 @@ def _tpu_flash_attention(
         block_kv_dq=min(max_block_size, query.shape[2]),
     )

-  query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q)
-  key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute)
-  value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute)
-
+  num_fsdp_shards = mesh.shape["fsdp"]
+  query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q, num_fsdp_shards)
+  key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute, num_fsdp_shards)
+  value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards)
   axis_names = nn.logical_to_mesh_axes(flash_axis_names)
   kv_axis_names = nn.logical_to_mesh_axes((BATCH, HEAD, None, D_KV))
   flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH)
   axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel)
   named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel)

-  cp_size = 8
+  cp_size = 1

   @functools.partial(
       jax.jit,
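
The new num_fsdp_shards lookup relies on jax.sharding.Mesh.shape, which maps axis names to axis sizes. A minimal sketch on a toy mesh; the one-axis layout here is an assumption for illustration:

import jax
import numpy as np
from jax.sharding import Mesh

# Put all local devices on a single "fsdp" axis (illustrative layout).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("fsdp",))
num_fsdp_shards = mesh.shape["fsdp"]  # e.g. 4 on a 4-chip host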
@@ -198,11 +207,6 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):

   shard_head_size = 1
   mask = splash_attention_mask.FullMask(_shape=(query.shape[2], query.shape[2]))
-  mask &= splash_attention_mask.LocalMask(
-      shape=(query.shape[2], key.shape[2]),
-      window_size=(query.shape[2], query.shape[2]),
-      offset=0
-  )
   multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
   splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size))
   segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)
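
For reference, the surviving mask construction written out as a minimal standalone sketch; the shapes are illustrative, and the import path is JAX's splash-attention mask module already used in this file:

from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask

seq_len, num_heads = 1024, 8  # illustrative shapes
# Dense attention: every query position may attend to every kv position.
mask = splash_attention_mask.FullMask(_shape=(seq_len, seq_len))
# Replicate the same single-head mask across all attention heads.
multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * num_heads)

Note that the removed LocalMask used a window spanning the full sequence, so intersecting it with FullMask was presumably a no-op, which is why dropping it should not change the attention pattern.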