@@ -187,44 +187,34 @@ def _tpu_flash_attention(
   value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards)
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
-  flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH)
+  flash_axis_names_splash_kernel: AxisNames = (HEAD, KV_LENGTH)
   axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel)
   named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel)
 
   shard_head_size = mesh.shape["tensor"]
 
-  @functools.partial(
-      jax.jit,
-      static_argnames=["multi_head_mask", "shard_head_size"],
-  )
-  def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
-    splash_kernel = splash_attention_kernel.make_splash_mha(
-        mask=multi_head_mask,
-        head_shards=shard_head_size,  # the sizes of the axis is sharding over heads
-        q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
-        block_sizes=block_sizes,
-    )
-    return splash_kernel
-
-  mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-
-  multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
-  splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size))
-  segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)
-
   @functools.partial(
       shard_map.shard_map,
       mesh=mesh,
       in_specs=(
           q_axis_names,
           kv_axis_names,
           kv_axis_names,
-          segment_axis_names_splash_kernel,
       ),
       out_specs=q_axis_names,
       check_rep=False,
   )
-  def wrap_flash_attention(query, key, value, splash_kernel):
+  def wrap_flash_attention(query, key, value):
+    mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+    # make_splash_mha runs inside the shard_map-wrapped function, so the head and
+    # sequence axes are already sharded per in_specs; head_shards and q_seq_shards stay at 1.
+    splash_kernel = splash_attention_kernel.make_splash_mha(
+        mask=multi_head_mask,
+        head_shards=1,  # size of the sharded heads axis seen by the kernel
+        q_seq_shards=1,  # size of the sharded seq_len axis seen by the kernel
+        block_sizes=block_sizes,
+    )
     attention_output = jax.vmap(splash_kernel)(query, key, value)
     return attention_output
 
@@ -236,7 +226,7 @@ def wrap_flash_attention(query, key, value, splash_kernel):
236226 "Warning, batch dimension should be shardable among the devices in data and fsdp"
237227 f" axis, batch dimension: { query .shape [0 ]} , devices_in_data_fsdp: { devices_in_data_fsdp } "
238228 )
239- x = wrap_flash_attention (query , key , value , splash_kernel )
229+ x = wrap_flash_attention (query , key , value )
240230 x = x [:, :, :query_seq_len , :kv_size ]
241231 x = _reshape_heads_to_head_dim (x )
242232
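Editor's note: the sketch below is a minimal, self-contained illustration of the pattern the two hunks above converge on — the mask and splash kernel are built inside the shard_map-wrapped function with head_shards=1 and q_seq_shards=1, then vmapped over the batch axis. The mesh layout, axis names, shapes, and dtypes here are assumptions chosen for illustration, not the values the module derives from its config.

```python
# Hypothetical standalone sketch; mesh axes, shapes, and dtypes are assumed,
# not taken from the module's config.
import functools

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils, shard_map
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel,
    splash_attention_mask,
)

# Assumed 1 x N device mesh: batch over "data", heads over "tensor".
devices = mesh_utils.create_device_mesh((1, jax.device_count()))
mesh = jax.sharding.Mesh(devices, ("data", "tensor"))
q_spec = jax.sharding.PartitionSpec("data", "tensor", None, None)   # [batch, heads, q_len, head_dim]
kv_spec = jax.sharding.PartitionSpec("data", "tensor", None, None)  # [batch, heads, kv_len, head_dim]


@functools.partial(
    shard_map.shard_map,
    mesh=mesh,
    in_specs=(q_spec, kv_spec, kv_spec),
    out_specs=q_spec,
    check_rep=False,
)
def wrap_flash_attention(query, key, value):
  # Inside shard_map these are the per-shard shapes, so the mask matches the
  # local block and make_splash_mha needs no head/sequence sharding of its own.
  mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
  multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
  splash_kernel = splash_attention_kernel.make_splash_mha(
      mask=multi_head_mask,
      head_shards=1,
      q_seq_shards=1,
  )
  # The kernel operates on [heads, seq, head_dim]; vmap restores the batch axis.
  return jax.vmap(splash_kernel)(query, key, value)


# Usage sketch (TPU only): heads must divide evenly across the "tensor" axis.
q = jnp.zeros((4, 8, 1024, 128), dtype=jnp.bfloat16)
k = jnp.zeros((4, 8, 1024, 128), dtype=jnp.bfloat16)
v = jnp.zeros((4, 8, 1024, 128), dtype=jnp.bfloat16)
out = wrap_flash_attention(q, k, v)  # [4, 8, 1024, 128]
```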
@@ -632,7 +622,7 @@ def __init__(
       use_memory_efficient_attention: bool = False,
       split_head_dim: bool = False,
       attention_kernel: str = "flash",
-      flash_min_seq_length: int = 4096,
+      flash_min_seq_length: int = 0,
       flash_block_sizes: BlockSizes = None,
       mesh: jax.sharding.Mesh = None,
       dtype: jnp.dtype = jnp.float32,
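Editor's note: dropping the default from 4096 to 0 means short sequences are no longer routed away from the flash path unless a caller raises the threshold. A hedged sketch of the kind of gate this parameter typically feeds (the helper name and check are hypothetical, not the module's actual dispatch):

```python
# Hypothetical gate; the module's real kernel selection may differ.
def use_flash_attention(seq_len: int, flash_min_seq_length: int = 0) -> bool:
  # With the new default of 0, every sequence length qualifies for the flash path.
  return seq_len >= flash_min_seq_length
```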