@@ -248,9 +248,10 @@ def patched_sdpa_mask_recent_torch(
248248 return attention_mask
249249
250250 def patched_sdpa_mask (
251- batch_size : int ,
252- cache_position : torch .Tensor ,
253- kv_length : int ,
251+ batch_size : int = 0 ,
252+ q_length : int = 0 ,
253+ kv_length : int = 0 ,
254+ q_offset : int = 0 ,
254255 kv_offset : int = 0 ,
255256 mask_function : Callable = causal_mask_function ,
256257 attention_mask : torch .Tensor | None = None ,
@@ -262,7 +263,79 @@ def patched_sdpa_mask(
262263 ** kwargs ,
263264 ) -> torch .Tensor | None :
264265 """manual patch for function ``transformers.masking_utils.sdpa_mask``."""
265- q_length = cache_position .shape [0 ]
266+ if isinstance (q_length , torch .Tensor ):
267+ # `cache_position` is deprecated as an argument
268+ # and will be removed in Transformers v5.6. Please use `q_length` and
269+ # `q_offset` instead, similarly to `kv_length` and `kv_offset`.
270+ cache_position = q_length
271+ device = q_length .device
272+ q_length = q_length .shape [0 ]
273+
274+ # Potentially pad the 2D mask
275+ padding_mask = prepare_padding_mask (attention_mask , kv_length , kv_offset )
276+
277+ # Under specific conditions, we can avoid materializing the mask
278+ # 1. Causal masks can rely on the `is_causal` argument
279+ # 2. Bidirectional do not need any further processing (no bias)
280+ if allow_is_causal_skip and _ignore_causal_mask_sdpa (
281+ padding_mask , q_length , kv_length , kv_offset , local_size
282+ ):
283+ return None
284+ if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa (
285+ padding_mask , kv_length , local_size
286+ ):
287+ return None
288+
289+ # Potentially add the padding 2D mask
290+ if padding_mask is not None :
291+ mask_function = and_masks (mask_function , padding_mask_function (padding_mask ))
292+
293+ batch_arange = torch .arange (batch_size , device = device )
294+ head_arange = torch .arange (1 , device = device )
295+ # Similar to `kv_arange = torch.arange(start=kv_offset,
296+ # end=kv_offset + kv_length, device=cache_position.device)`
297+ # but without data-dependent slicing (i.e. torch.compile friendly)
298+ kv_arange = torch .arange (kv_length , device = device ) + kv_offset
299+
300+ # Actual mask creation
301+ # Option 1: Fast non-vmap mask creation (default)
302+ # PATCHED
303+ use_vmap = False
304+ if not use_vmap :
305+ # Apply mask function element-wise through broadcasting
306+ attention_mask = mask_function (
307+ * _non_vmap_expansion_sdpa (
308+ batch_arange , head_arange , cache_position , kv_arange
309+ )
310+ )
311+ # Expand the mask to match batch size
312+ # and query length if they weren't used in the mask function
313+ attention_mask = attention_mask .expand (batch_size , - 1 , q_length , kv_length )
314+
315+ # Option 2: Vmap mask creation (torch>=2.6 and custom patterns)
316+ # elif _is_torch_greater_or_equal_than_2_6:
317+ # This creates the 4D mask easily.
318+ # Note that we need this context manager as vmap
319+ # cannot handle slicing a tensor from
320+ # scalar tensor (it internally calls `.item()` which vmap does not allow,
321+ # but this context works around it
322+ # We don't need to add an offset to the mask_function either,
323+ # as we vmap directly over the correct q and kv indices
324+ # with TransformGetItemToIndex():
325+ # attention_mask = _vmap_expansion_sdpa(mask_function)(
326+ # batch_arange, head_arange, cache_position, kv_arange
327+ # )
328+
329+ # Option 3: Error out since it indicates that the user did something custom,
330+ # which they shouldn't have (torch<2.6)
331+ else :
332+ raise ValueError (
333+ "The vmap functionality for mask creation "
334+ "is only supported from torch>=2.6. "
335+ "Please update your torch version or use "
336+ "`use_vmap=False` with index-based masks."
337+ )
338+ return attention_mask
266339
267340 # Potentially pad the 2D mask
268341 padding_mask = prepare_padding_mask (attention_mask , kv_length , kv_offset )
@@ -283,46 +356,17 @@ def patched_sdpa_mask(
283356 if padding_mask is not None :
284357 mask_function = and_masks (mask_function , padding_mask_function (padding_mask ))
285358
286- batch_arange = torch .arange (batch_size , device = cache_position .device )
287- head_arange = torch .arange (1 , device = cache_position .device )
288- # Similar to `kv_arange = torch.arange(start=kv_offset,
289- # end=kv_offset + kv_length, device=cache_position.device)`
290- # but without data-dependent slicing (i.e. torch.compile friendly)
291- kv_arange = torch .arange (kv_length , device = cache_position .device ) + kv_offset
359+ batch_arange = torch .arange (batch_size , device = device )
360+ head_arange = torch .arange (1 , device = device )
361+ q_arange = torch .arange (q_length , device = device ) + q_offset
362+ kv_arange = torch .arange (kv_length , device = device ) + kv_offset
363+
364+ # Apply mask function element-wise through broadcasting
365+ attention_mask = mask_function (
366+ * _non_vmap_expansion_sdpa (batch_arange , head_arange , q_arange , kv_arange )
367+ )
368+ # Expand the mask to match batch size and query
369+ # length if they weren't used in the mask function
370+ attention_mask = attention_mask .expand (batch_size , - 1 , q_length , kv_length )
292371
293- # Actual mask creation
294- # Option 1: Fast non-vmap mask creation (default)
295- # PATCHED
296- use_vmap = False
297- if not use_vmap :
298- # Apply mask function element-wise through broadcasting
299- attention_mask = mask_function (
300- * _non_vmap_expansion_sdpa (batch_arange , head_arange , cache_position , kv_arange )
301- )
302- # Expand the mask to match batch size
303- # and query length if they weren't used in the mask function
304- attention_mask = attention_mask .expand (batch_size , - 1 , q_length , kv_length )
305-
306- # Option 2: Vmap mask creation (torch>=2.6 and custom patterns)
307- # elif _is_torch_greater_or_equal_than_2_6:
308- # This creates the 4D mask easily.
309- # Note that we need this context manager as vmap cannot handle slicing a tensor from
310- # scalar tensor (it internally calls `.item()` which vmap does not allow,
311- # but this context works around it
312- # We don't need to add an offset to the mask_function either,
313- # as we vmap directly the correct indices for k and kv indices
314- # with TransformGetItemToIndex():
315- # attention_mask = _vmap_expansion_sdpa(mask_function)(
316- # batch_arange, head_arange, cache_position, kv_arange
317- # )
318-
319- # Option 3: Error out since it indicates that the user did something custom,
320- # which they shouldn't have (torch<2.6)
321- else :
322- raise ValueError (
323- "The vmap functionality for mask creation "
324- "is only supported from torch>=2.6. "
325- "Please update your torch version or use "
326- "`use_vmap=False` with index-based masks."
327- )
328372 return attention_mask
0 commit comments