@@ -146,61 +146,106 @@ def patched_eager_mask(
146146 return mask
147147
def patched_sdpa_mask_recent_torch(
    batch_size: int = 0,
    q_length: int = 0,
    kv_length: int = 0,
    q_offset: int = 0,
    kv_offset: int = 0,
    mask_function: Callable = causal_mask_function,
    attention_mask: Optional[torch.Tensor] = None,
    local_size: Optional[int] = None,
    allow_is_causal_skip: bool = True,
    allow_is_bidirectional_skip: bool = False,
    use_vmap: bool = False,
    device: torch.device | str = "cpu",
    **kwargs,
) -> Optional[torch.Tensor]:
    """Manual patch for ``transformers.masking_utils.sdpa_mask_recent_torch``.

    Builds (or skips) the 4D boolean attention mask consumed by SDPA.

    :param batch_size: batch dimension of the mask to build
    :param q_length: number of query positions; for backward compatibility a
        ``cache_position`` tensor may still be passed here (deprecated, see below)
    :param kv_length: number of key/value positions
    :param q_offset: index of the first query position
    :param kv_offset: index of the first key/value position
    :param mask_function: elementwise ``(batch, head, q_idx, kv_idx) -> bool``
        mask predicate
    :param attention_mask: optional 2D padding mask
    :param local_size: sliding-window size, forwarded to the causal-skip check
    :param allow_is_causal_skip: allow returning ``None`` so SDPA can use its
        ``is_causal`` fast path instead of a materialized mask
    :param allow_is_bidirectional_skip: allow returning ``None`` for a
        bidirectional mask with no padding
    :param use_vmap: when True, build the mask through the patched
        ``vmap_for_bhqkv`` path; otherwise use broadcasting (default)
    :param device: device on which the index tensors are created
    :return: a boolean mask of shape ``(batch_size, 1, q_length, kv_length)``
        (or a view expandable to it), or ``None`` when the mask can be skipped
    """
    if isinstance(q_length, torch.Tensor):
        # Backward compatibility: callers used to pass `cache_position` here.
        # It is deprecated and will be removed in Transformers v5.6; use
        # `q_length`/`q_offset` instead, similarly to `kv_length`/`kv_offset`.
        # BUGFIX: read `.device` from the tensor BEFORE rebinding `q_length`
        # to an int — the previous code did `device = q_length.device` after
        # the rebind, which raised AttributeError on the int.
        cache_position = q_length
        device = cache_position.device
        q_length = cache_position.shape[0]
        q_offset = cache_position[0].to(device)

    padding_mask = prepare_padding_mask(
        attention_mask, kv_length, kv_offset, **_prepare_padding_mask_kwargs
    )

    # Under specific conditions we can avoid materializing the mask:
    # 1. causal masks can rely on SDPA's `is_causal` argument,
    # 2. bidirectional masks without padding need no bias at all.
    if allow_is_causal_skip and _ignore_causal_mask_sdpa(
        padding_mask, q_length, kv_length, kv_offset, local_size
    ):
        return None
    if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa:
        # transformers<=5.0 takes 1 parameter, transformers>5.0 takes 3;
        # introspect the signature so both versions are supported.
        n_parameters = len(
            inspect.signature(_ignore_bidirectional_mask_sdpa).parameters
        )
        if _ignore_bidirectional_mask_sdpa(
            *[padding_mask, kv_length, kv_offset][:n_parameters]
        ):
            return None

    if mask_function is bidirectional_mask_function:
        if padding_mask is not None:
            # Index tensor used for slicing without data-dependent slicing.
            mask_indices = torch.arange(kv_length, device=device) + kv_offset
            return padding_mask[:, None, None, mask_indices].expand(
                -1, -1, q_length, -1
            )
        # No padding: a bidirectional mask is all-ones.
        return torch.ones(
            batch_size, 1, q_length, kv_length, dtype=torch.bool, device=device
        )

    # Potentially fold the 2D padding mask into the mask function.
    if padding_mask is not None:
        mask_function = and_masks(mask_function, padding_mask_function(padding_mask))

    batch_arange = torch.arange(batch_size, device=device)
    head_arange = torch.arange(1, device=device)
    q_arange = torch.arange(q_length, device=device) + q_offset
    kv_arange = torch.arange(kv_length, device=device) + kv_offset

    if use_vmap:
        # PATCHED: calls the patched version of vmap_for_bhqkv.
        # BUGFIX: the previous code passed the int `q_length` where a tensor
        # of query indices is expected (pre-patch it was `cache_position`);
        # pass `q_arange` instead. It also returned unconditionally here,
        # leaving the non-vmap path below unreachable and `use_vmap` unused.
        return patched__vmap_for_bhqkv(mask_function)(
            batch_arange, head_arange, q_arange, kv_arange
        )

    # Fast non-vmap mask creation (default): apply the mask function
    # element-wise through broadcasting.
    attention_mask = mask_function(
        *_non_vmap_expansion_sdpa(batch_arange, head_arange, q_arange, kv_arange)
    )
    # Expand to the full batch size and query length in case the mask
    # function did not use those dimensions.
    attention_mask = attention_mask.expand(batch_size, -1, q_length, kv_length)
    return attention_mask
204249
205250 def patched_sdpa_mask (
206251 batch_size : int ,
0 commit comments