|
1 | 1 | import unittest |
2 | | -from onnx_diagnostic.ext_test_case import ( |
3 | | - ExtTestCase, |
4 | | - requires_torch, |
5 | | - requires_transformers, |
6 | | - skipif_ci_windows, |
7 | | - ignore_warnings, |
8 | | - hide_stdout, |
9 | | -) |
10 | | -from onnx_diagnostic.helpers import string_type |
11 | | -from onnx_diagnostic.torch_export_patches.onnx_export_errors import ( |
12 | | - torch_export_patches, |
13 | | -) |
| 2 | +from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings |
14 | 3 |
|
15 | 4 |
|
16 | 5 | class TestOnnxExportErrors(ExtTestCase): |
17 | | - @requires_transformers("4.49.999") |
18 | | - @skipif_ci_windows("not working on Windows") |
19 | | - @ignore_warnings(UserWarning) |
20 | | - @hide_stdout() |
21 | | - def test_pytree_flatten_mamba_cache(self): |
22 | | - import torch |
23 | | - import torch.utils._pytree as py_pytree |
24 | | - |
25 | | - try: |
26 | | - from transformers.models.mamba.modeling_mamba import MambaCache |
27 | | - except ImportError: |
28 | | - from transformers.cache_utils import MambaCache |
29 | | - |
30 | | - class _config: |
31 | | - def __init__(self): |
32 | | - self.intermediate_size = 8 |
33 | | - self.state_size = 16 |
34 | | - self.conv_kernel = 32 |
35 | | - self.num_hidden_layers = 64 |
36 | | - self.dtype = torch.float16 |
37 | | - |
38 | | - cache = MambaCache(_config(), max_batch_size=1, device="cpu") |
39 | | - |
40 | | - with torch_export_patches(verbose=1): |
41 | | - values, spec = py_pytree.tree_flatten(cache) |
42 | | - cache2 = py_pytree.tree_unflatten(values, spec) |
43 | | - self.assertEqual(cache.max_batch_size, cache2.max_batch_size) |
44 | | - self.assertEqual(cache.intermediate_size, cache2.intermediate_size) |
45 | | - self.assertEqual(cache.ssm_state_size, cache2.ssm_state_size) |
46 | | - self.assertEqual(cache.conv_kernel_size, cache2.conv_kernel_size) |
47 | | - self.assertEqualArrayAny(cache.conv_states, cache2.conv_states) |
48 | | - self.assertEqualArrayAny(cache.ssm_states, cache2.ssm_states) |
49 | | - |
50 | | - @requires_transformers("4.50") |
51 | | - @requires_torch("2.7") |
52 | | - @skipif_ci_windows("not working on Windows") |
53 | | - @ignore_warnings(UserWarning) |
54 | | - @hide_stdout() |
55 | | - def test_exportable_mamba_cache(self): |
56 | | - import torch |
57 | | - from transformers.models.mamba.modeling_mamba import MambaCache |
58 | | - |
59 | | - class _config: |
60 | | - def __init__(self): |
61 | | - self.intermediate_size = 8 |
62 | | - self.state_size = 16 |
63 | | - self.conv_kernel = 32 |
64 | | - self.num_hidden_layers = 64 |
65 | | - self.dtype = torch.float16 |
66 | | - |
67 | | - class Model(torch.nn.Module): |
68 | | - def forward(self, x: torch.Tensor, cache: MambaCache): |
69 | | - x1 = cache.ssm_states[0] + x |
70 | | - x2 = cache.conv_states[0][:, :, ::2] + x1 |
71 | | - return x2 |
72 | | - |
73 | | - cache = MambaCache(_config(), max_batch_size=1, device="cpu") |
74 | | - # MambaCache was updated in 4.50 |
75 | | - self.assertEqual( |
76 | | - "MambaCache(conv_states=#64[T10r3,...], ssm_states=#64[T10r3,...])", |
77 | | - string_type(cache), |
78 | | - ) |
79 | | - x = torch.ones(2, 8, 16).to(torch.float16) |
80 | | - model = Model() |
81 | | - model(x, cache) |
82 | | - |
83 | | - with torch_export_patches(verbose=1, patch_transformers=True): |
84 | | - cache = MambaCache(_config(), max_batch_size=1, device="cpu") |
85 | | - torch.export.export(Model(), (x, cache)) |
86 | | - |
87 | | - @requires_transformers("4.49.999") |
88 | | - @skipif_ci_windows("not working on Windows") |
89 | | - @ignore_warnings(UserWarning) |
90 | | - def test_exportable_mamba_cache_dynamic(self): |
91 | | - import torch |
92 | | - from transformers.models.mamba.modeling_mamba import MambaCache |
93 | | - |
94 | | - class _config: |
95 | | - def __init__(self): |
96 | | - self.intermediate_size = 8 |
97 | | - self.state_size = 16 |
98 | | - self.conv_kernel = 32 |
99 | | - self.num_hidden_layers = 2 |
100 | | - self.dtype = torch.float16 |
101 | | - |
102 | | - class Model(torch.nn.Module): |
103 | | - def forward(self, x: torch.Tensor, cache: MambaCache): |
104 | | - x1 = cache.ssm_states[0] + x |
105 | | - x2 = cache.conv_states[0][:, :, ::2] + x1 |
106 | | - return x2 |
107 | | - |
108 | | - cache = MambaCache(_config(), max_batch_size=1, device="cpu") |
109 | | - self.assertEqual( |
110 | | - string_type(cache), |
111 | | - "MambaCache(conv_states=#2[T10r3,T10r3], ssm_states=#2[T10r3,T10r3])", |
112 | | - ) |
113 | | - x = torch.ones(2, 8, 16).to(torch.float16) |
114 | | - model = Model() |
115 | | - model(x, cache) |
116 | | - DYN = torch.export.Dim.DYNAMIC |
117 | | - |
118 | | - with torch_export_patches(patch_transformers=True): |
119 | | - cache = MambaCache(_config(), max_batch_size=2, device="cpu") |
120 | | - torch.export.export( |
121 | | - Model(), |
122 | | - (x, cache), |
123 | | - dynamic_shapes=({0: DYN}, [[{0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}]]), |
124 | | - ) |
125 | | - |
126 | 6 | @ignore_warnings(UserWarning) |
127 | 7 | def test_exportable_dynamic_shapes_constraints(self): |
128 | 8 | import torch |
|
0 commit comments