Commit 5d39ae1: fix

1 parent 37c5164 commit 5d39ae1

3 files changed: 92 additions & 48 deletions

_unittests/ut_export/test_api.py

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ def forward(self, x, y):
 
     @hide_stdout()
     @ignore_warnings(FutureWarning)
-    @requires_transformers("4.50")
+    @requires_transformers("4.57")
    def test_tiny_llm_to_onnx(self):
        import onnxruntime
 
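The only change in this file bumps the version gate on the test from transformers 4.50 to 4.57, matching the newer masking API that the patch below targets. `requires_transformers` is this repo's own decorator; as a rough sketch (an assumption, not the repo's actual implementation), such a version gate can be built from `unittest.skipUnless` and `packaging`:

# Hypothetical sketch of a version-gate decorator; not the repo's code.
import unittest
from packaging.version import Version

def requires_transformers_sketch(min_version: str):
    # Skip the decorated test unless transformers >= min_version is installed.
    try:
        import transformers
        ok = Version(transformers.__version__) >= Version(min_version)
        reason = f"requires transformers>={min_version}, found {transformers.__version__}"
    except ImportError:
        ok, reason = False, "transformers is not installed"
    return unittest.skipUnless(ok, reason)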
_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 2 additions & 2 deletions
@@ -74,7 +74,7 @@ def test_sdpa_mask_patched(self):
        patched_sdpa_mask = patch_transformers.patched_sdpa_mask
        kwargs = {
            "batch_size": 1,
-            "cache_position": torch.tensor([3], dtype=torch.int64),
+            "q_length": torch.tensor([3], dtype=torch.int64),
            "kv_length": 4,
            "kv_offset": 0,
            "mask_function": transformers.masking_utils.causal_mask_function,
@@ -89,7 +89,7 @@ def test_sdpa_mask_patched(self):
 
        kwargs = {
            "batch_size": 1,
-            "cache_position": torch.tensor([3], dtype=torch.int64),
+            "q_length": torch.tensor([3], dtype=torch.int64),
            "kv_length": 4,
            "kv_offset": 0,
            "mask_function": transformers.masking_utils.causal_mask_function,

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py

Lines changed: 89 additions & 45 deletions
@@ -248,9 +248,10 @@ def patched_sdpa_mask_recent_torch(
    return attention_mask
 
def patched_sdpa_mask(
-    batch_size: int,
-    cache_position: torch.Tensor,
-    kv_length: int,
+    batch_size: int = 0,
+    q_length: int = 0,
+    kv_length: int = 0,
+    q_offset: int = 0,
    kv_offset: int = 0,
    mask_function: Callable = causal_mask_function,
    attention_mask: torch.Tensor | None = None,
@@ -262,7 +263,79 @@ def patched_sdpa_mask(
262263
**kwargs,
263264
) -> torch.Tensor | None:
264265
"""manual patch for function ``transformers.masking_utils.sdpa_mask``."""
265-
q_length = cache_position.shape[0]
266+
if isinstance(q_length, torch.Tensor):
267+
# `cache_position` is deprecated as an arg,
268+
# and will be removed in Transformers v5.6. Please use `q_length` and "
269+
# `q_offset` instead, similarly to `kv_length` and `kv_offset`"
270+
cache_position = q_length
271+
device = q_length.device
272+
q_length = q_length.shape[0]
273+
274+
# Potentially pad the 2D mask
275+
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
276+
277+
# Under specific conditions, we can avoid materializing the mask
278+
# 1. Causal masks can rely on the `is_causal` argument
279+
# 2. Bidirectional do not need any further processing (no bias)
280+
if allow_is_causal_skip and _ignore_causal_mask_sdpa(
281+
padding_mask, q_length, kv_length, kv_offset, local_size
282+
):
283+
return None
284+
if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(
285+
padding_mask, kv_length, local_size
286+
):
287+
return None
288+
289+
# Potentially add the padding 2D mask
290+
if padding_mask is not None:
291+
mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
292+
293+
batch_arange = torch.arange(batch_size, device=device)
294+
head_arange = torch.arange(1, device=device)
295+
# Similar to `kv_arange = torch.arange(start=kv_offset,
296+
# end=kv_offset + kv_length, device=cache_position.device)`
297+
# but without data-dependent slicing (i.e. torch.compile friendly)
298+
kv_arange = torch.arange(kv_length, device=device) + kv_offset
299+
300+
# Actual mask creation
301+
# Option 1: Fast non-vmap mask creation (default)
302+
# PATCHED
303+
use_vmap = False
304+
if not use_vmap:
305+
# Apply mask function element-wise through broadcasting
306+
attention_mask = mask_function(
307+
*_non_vmap_expansion_sdpa(
308+
batch_arange, head_arange, cache_position, kv_arange
309+
)
310+
)
311+
# Expand the mask to match batch size
312+
# and query length if they weren't used in the mask function
313+
attention_mask = attention_mask.expand(batch_size, -1, q_length, kv_length)
314+
315+
# Option 2: Vmap mask creation (torch>=2.6 and custom patterns)
316+
# elif _is_torch_greater_or_equal_than_2_6:
317+
# This creates the 4D mask easily.
318+
# Note that we need this context manager as vmap
319+
# cannot handle slicing a tensor from
320+
# scalar tensor (it internally calls `.item()` which vmap does not allow,
321+
# but this context works around it
322+
# We don't need to add an offset to the mask_function either,
323+
# as we vmap directly the correct indices for k and kv indices
324+
# with TransformGetItemToIndex():
325+
# attention_mask = _vmap_expansion_sdpa(mask_function)(
326+
# batch_arange, head_arange, cache_position, kv_arange
327+
# )
328+
329+
# Option 3: Error out since it indicates that the user did something custom,
330+
# which they shouldn't have (torch<2.6)
331+
else:
332+
raise ValueError(
333+
"The vmap functionality for mask creation "
334+
"is only supported from torch>=2.6. "
335+
"Please update your torch version or use "
336+
"`use_vmap=False` with index-based masks."
337+
)
338+
return attention_mask
266339

267340
# Potentially pad the 2D mask
268341
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
@@ -283,46 +356,17 @@ def patched_sdpa_mask(
    if padding_mask is not None:
        mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
 
-    batch_arange = torch.arange(batch_size, device=cache_position.device)
-    head_arange = torch.arange(1, device=cache_position.device)
-    # Similar to `kv_arange = torch.arange(start=kv_offset,
-    # end=kv_offset + kv_length, device=cache_position.device)`
-    # but without data-dependent slicing (i.e. torch.compile friendly)
-    kv_arange = torch.arange(kv_length, device=cache_position.device) + kv_offset
+    batch_arange = torch.arange(batch_size, device=device)
+    head_arange = torch.arange(1, device=device)
+    q_arange = torch.arange(q_length, device=device) + q_offset
+    kv_arange = torch.arange(kv_length, device=device) + kv_offset
+
+    # Apply mask function element-wise through broadcasting
+    attention_mask = mask_function(
+        *_non_vmap_expansion_sdpa(batch_arange, head_arange, q_arange, kv_arange)
+    )
+    # Expand the mask to match batch size and query
+    # length if they weren't used in the mask function
+    attention_mask = attention_mask.expand(batch_size, -1, q_length, kv_length)
 
-    # Actual mask creation
-    # Option 1: Fast non-vmap mask creation (default)
-    # PATCHED
-    use_vmap = False
-    if not use_vmap:
-        # Apply mask function element-wise through broadcasting
-        attention_mask = mask_function(
-            *_non_vmap_expansion_sdpa(batch_arange, head_arange, cache_position, kv_arange)
-        )
-        # Expand the mask to match batch size
-        # and query length if they weren't used in the mask function
-        attention_mask = attention_mask.expand(batch_size, -1, q_length, kv_length)
-
-    # Option 2: Vmap mask creation (torch>=2.6 and custom patterns)
-    # elif _is_torch_greater_or_equal_than_2_6:
-    #     This creates the 4D mask easily. Note that we need this context
-    #     manager as vmap cannot handle slicing a tensor from a scalar
-    #     tensor (it internally calls `.item()`, which vmap does not
-    #     allow), but this context works around it. We don't need to add
-    #     an offset to the mask_function either, as we vmap directly the
-    #     correct indices for q and kv indices.
-    #     with TransformGetItemToIndex():
-    #         attention_mask = _vmap_expansion_sdpa(mask_function)(
-    #             batch_arange, head_arange, cache_position, kv_arange
-    #         )
-
-    # Option 3: Error out since it indicates that the user did something
-    # custom, which they shouldn't have (torch<2.6)
-    else:
-        raise ValueError(
-            "The vmap functionality for mask creation "
-            "is only supported from torch>=2.6. "
-            "Please update your torch version or use "
-            "`use_vmap=False` with index-based masks."
-        )
    return attention_mask
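The rewritten body drops the vmap option entirely and builds the mask by broadcasting four index ranges, with `q_arange = torch.arange(q_length) + q_offset` replacing the old `cache_position` tensor. `_non_vmap_expansion_sdpa` is a transformers internal; a minimal sketch of the broadcasting idea, assuming it simply places each range on its own axis of a (batch, head, q, kv) grid:

# Sketch of non-vmap mask expansion, under the assumption that
# _non_vmap_expansion_sdpa reshapes each index range onto its own axis,
# so an element-wise mask function broadcasts to a 4D boolean mask.
import torch

def causal(batch_idx, head_idx, q_idx, kv_idx):
    # a query position may attend to key positions at or before it
    return kv_idx <= q_idx

batch_size, q_length, q_offset, kv_length, kv_offset = 1, 1, 3, 4, 0
batch_arange = torch.arange(batch_size)
head_arange = torch.arange(1)
q_arange = torch.arange(q_length) + q_offset
kv_arange = torch.arange(kv_length) + kv_offset

mask = causal(
    batch_arange[:, None, None, None],  # batch axis
    head_arange[None, :, None, None],   # head axis
    q_arange[None, None, :, None],      # query axis
    kv_arange[None, None, None, :],     # key/value axis
).expand(batch_size, -1, q_length, kv_length)

print(mask)  # tensor([[[[True, True, True, True]]]]): the query at position 3 sees kv 0..3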
