Commit 37c5164

Commit message: bug

Parent: e93c73c

2 files changed: 88 additions & 42 deletions

_unittests/ut_torch_models/test_tiny_llms.py

Lines changed: 6 additions & 5 deletions
@@ -51,17 +51,18 @@ def test_tiny_llm_run_static(self):
     @requires_torch("2.8")
     def test_tiny_llm_export_static(self):
         data = get_tiny_llm(use_static_cache=True)
-        model, inputs = data["model"], data["inputs"]
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        if "cache_position" in inputs:
+            del inputs["cache_position"]
+            del ds["cache_position"]
         expected = model(**copy.deepcopy(inputs))
-        self.assertEqual(
-            {"attention_mask", "past_key_values", "input_ids", "cache_position"}, set(inputs)
-        )
+        self.assertEqual({"attention_mask", "past_key_values", "input_ids"}, set(inputs))
         with torch_export_patches(patch_transformers=True, stop_if_static=0):
             ep = torch.export.export(
                 model,
                 (),
                 kwargs=copy.deepcopy(inputs),
-                dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]),
+                dynamic_shapes=use_dyn_not_str(ds),
             )
             got = ep.module()(**inputs)
             self.assertEqualArrayAny(expected, got)
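
For context, `use_dyn_not_str` is an onnx_diagnostic helper that rewrites a `dynamic_shapes` specification before it is handed to `torch.export.export`. The sketch below is a plausible minimal reimplementation, not the library's code; it assumes the convention that every string axis name such as "batch" is replaced by `torch.export.Dim.DYNAMIC`:

import torch

def dyn_not_str(ds):
    """Recursively replace string axis names with Dim.DYNAMIC (sketch)."""
    if isinstance(ds, str):
        return torch.export.Dim.DYNAMIC
    if isinstance(ds, dict):
        return {k: dyn_not_str(v) for k, v in ds.items()}
    if isinstance(ds, (list, tuple)):
        return type(ds)(dyn_not_str(v) for v in ds)
    return ds

# e.g. {"input_ids": {0: "batch", 1: "seq"}} becomes
#      {"input_ids": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC}}

This also explains the test change: once `cache_position` is removed from `inputs`, the matching entry must be removed from `ds` as well, since `torch.export.export` rejects dynamic-shape specs for inputs that are not passed.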

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py

Lines changed: 82 additions & 37 deletions
@@ -146,61 +146,106 @@ def patched_eager_mask(
     return mask
 
 def patched_sdpa_mask_recent_torch(
-    batch_size: int,
-    cache_position: torch.Tensor,
-    kv_length: int,
+    batch_size: int = 0,
+    q_length: int = 0,
+    kv_length: int = 0,
+    q_offset: int = 0,
     kv_offset: int = 0,
     mask_function: Callable = causal_mask_function,
     attention_mask: Optional[torch.Tensor] = None,
     local_size: Optional[int] = None,
     allow_is_causal_skip: bool = True,
     allow_is_bidirectional_skip: bool = False,
+    use_vmap: bool = False,
+    device: torch.device | str = "cpu",
     **kwargs,
 ) -> Optional[torch.Tensor]:
     """manual patch for function ``transformers.masking_utils.sdpa_mask_recent_torch``."""
-    q_length = cache_position.shape[0]
-    padding_mask = prepare_padding_mask(
-        attention_mask, kv_length, kv_offset, **_prepare_padding_mask_kwargs
-    )
-    if allow_is_causal_skip and _ignore_causal_mask_sdpa(
-        padding_mask, q_length, kv_length, kv_offset, local_size
-    ):
-        return None
-    if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa:
-        # transformers<=5.0: 1 parameter, 3 for transformers>5.0
-        n_parameters = len(inspect.signature(_ignore_bidirectional_mask_sdpa).parameters)
-        if _ignore_bidirectional_mask_sdpa(
-            *[padding_mask, kv_length, kv_offset][:n_parameters]
+    if isinstance(q_length, torch.Tensor):
+        # `cache_position` is deprecated as an argument and will be removed in
+        # Transformers v5.6. Use `q_length`/`q_offset` instead, like
+        # `kv_length`/`kv_offset`. Keep the tensor: the vmap call below needs it.
+        cache_position, device = q_length, q_length.device
+        q_length, q_offset = cache_position.shape[0], cache_position[0]
+
+        padding_mask = prepare_padding_mask(
+            attention_mask, kv_length, kv_offset, **_prepare_padding_mask_kwargs
+        )
+        if allow_is_causal_skip and _ignore_causal_mask_sdpa(
+            padding_mask, q_length, kv_length, kv_offset, local_size
         ):
             return None
+        if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa:
+            # transformers<=5.0: 1 parameter, 3 for transformers>5.0
+            n_parameters = len(
+                inspect.signature(_ignore_bidirectional_mask_sdpa).parameters
+            )
+            if _ignore_bidirectional_mask_sdpa(
+                *[padding_mask, kv_length, kv_offset][:n_parameters]
+            ):
+                return None
 
-    if mask_function is bidirectional_mask_function:
-        if padding_mask is not None:
-            # used for slicing without data-dependent slicing
-            mask_indices = (
-                torch.arange(kv_length, device=cache_position.device) + kv_offset
+        if mask_function is bidirectional_mask_function:
+            if padding_mask is not None:
+                # used for slicing without data-dependent slicing
+                mask_indices = torch.arange(kv_length, device=device) + kv_offset
+                return padding_mask[:, None, None, mask_indices].expand(
+                    -1, -1, q_length, -1
+                )
+            return torch.ones(
+                batch_size,
+                1,
+                q_length,
+                kv_length,
+                dtype=torch.bool,
+                device=device,
             )
-        return padding_mask[:, None, None, mask_indices].expand(-1, -1, q_length, -1)
-        return torch.ones(
-            batch_size,
-            1,
-            q_length,
-            kv_length,
-            dtype=torch.bool,
-            device=cache_position.device,
+
+        kv_arange = torch.arange(kv_length, device=device)
+        kv_arange += kv_offset
+        if padding_mask is not None:
+            mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
+        batch_arange = torch.arange(batch_size, device=device)
+        head_arange = torch.arange(1, device=device)
+        # PATCHED: this line calls the patched version of vmap_for_bhqkv
+        causal_mask = patched__vmap_for_bhqkv(mask_function)(
+            batch_arange, head_arange, cache_position, kv_arange
         )
+        return causal_mask
+
+    padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
+
+    # Under specific conditions, we can avoid materializing the mask
+    # 1. Causal masks can rely on the `is_causal` argument
+    # 2. Bidirectional do not need any further processing (no bias)
+    if allow_is_causal_skip and _ignore_causal_mask_sdpa(
+        padding_mask, q_length, kv_length, kv_offset, local_size
+    ):
+        return None
+    if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(
+        padding_mask, kv_length, local_size
+    ):
+        return None
 
-    kv_arange = torch.arange(kv_length, device=cache_position.device)
-    kv_arange += kv_offset
+    # Potentially add the padding 2D mask
     if padding_mask is not None:
         mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
-    batch_arange = torch.arange(batch_size, device=cache_position.device)
-    head_arange = torch.arange(1, device=cache_position.device)
-    # PATCHED: this line calls the patched version of vmap_for_bhqkv
-    causal_mask = patched__vmap_for_bhqkv(mask_function)(
-        batch_arange, head_arange, cache_position, kv_arange
+
+    batch_arange = torch.arange(batch_size, device=device)
+    head_arange = torch.arange(1, device=device)
+    q_arange = torch.arange(q_length, device=device) + q_offset
+    kv_arange = torch.arange(kv_length, device=device) + kv_offset
+
+    # Actual mask creation
+    # Option 1: Fast non-vmap mask creation (default)
+    # Apply mask function element-wise through broadcasting
+    attention_mask = mask_function(
+        *_non_vmap_expansion_sdpa(batch_arange, head_arange, q_arange, kv_arange)
     )
-    return causal_mask
+    # Expand the mask to match batch size and
+    # query length if they weren't used in the mask function
+    attention_mask = attention_mask.expand(batch_size, -1, q_length, kv_length)
+    return attention_mask
 
 def patched_sdpa_mask(
     batch_size: int,