Skip to content

Commit 69f87fe

Browse files
committed
remove mamba
1 parent 238dc2b commit 69f87fe

10 files changed

Lines changed: 4 additions & 475 deletions

File tree

_unittests/ut_helpers/test_cache_helper.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -149,23 +149,6 @@ def test_unflatten_flatten_encoder_decoder_cache(self):
149149
self.string_type(c2, with_shape=True),
150150
)
151151

152-
@requires_transformers("4.51") # the structure changes
153-
def test_make_mamba_cache(self):
154-
cache = make_mamba_cache(
155-
[
156-
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
157-
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
158-
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
159-
]
160-
)
161-
text = self.string_type(cache, with_shape=True)
162-
self.assertEqual(
163-
"MambaCache(conv_states=#3[T1s4x4x4,T1s4x4x4,T1s4x4x4], "
164-
"ssm_states=#3[T1s4x4x4,T1s4x4x4,T1s4x4x4])",
165-
text,
166-
)
167-
self.assertEqual(0, max_diff(cache, cache)["abs"])
168-
169152
@unittest.skipIf(
170153
not make_sliding_window_cache, "SlidingWindowCache removed in transformers>=5"
171154
)

_unittests/ut_helpers/test_torch_helper.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import onnx
55
import torch
66
import transformers
7-
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, requires_torch
7+
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
88
from onnx_diagnostic.helpers import max_diff, string_type
99
from onnx_diagnostic.helpers.torch_helper import (
1010
dummy_llm,
@@ -22,7 +22,6 @@
2222
from onnx_diagnostic.helpers.cache_helper import (
2323
make_dynamic_cache,
2424
make_encoder_decoder_cache,
25-
make_mamba_cache,
2625
make_sliding_window_cache,
2726
CacheKeyValue,
2827
)
@@ -313,24 +312,6 @@ def test_torch_deepcopy_cache_dce(self):
313312
self.assertEqual(hash1, hash2)
314313
self.assertGreater(torch_tensor_size(cc), 1)
315314

316-
@requires_torch("4.50")
317-
def test_torch_deepcopy_mamba_cache(self):
318-
cache = make_mamba_cache(
319-
[
320-
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
321-
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
322-
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
323-
]
324-
)
325-
at = torch_deepcopy(cache)
326-
self.assertEqual(type(cache), type(at))
327-
self.assertEqual(max_diff(cache, at)["abs"], 0)
328-
hash1 = string_type(at, with_shape=True, with_min_max=True)
329-
cache.conv_states[0] += 1000
330-
hash2 = string_type(at, with_shape=True, with_min_max=True)
331-
self.assertEqual(hash1, hash2)
332-
self.assertGreater(torch_tensor_size(cache), 1)
333-
334315
def test_torch_deepcopy_base_model_outputs(self):
335316
bo = transformers.modeling_outputs.BaseModelOutput(
336317
last_hidden_state=torch.rand((4, 4, 4))

_unittests/ut_tasks/test_tasks.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import unittest
33
import torch
4-
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, has_transformers
4+
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
55
from onnx_diagnostic.helpers.torch_helper import to_any, torch_deepcopy
66
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
77
from onnx_diagnostic.torch_export_patches import torch_export_patches
@@ -257,22 +257,6 @@ def test_sentence_similary(self):
257257
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
258258
)
259259

260-
@hide_stdout()
261-
def test_falcon_mamba_dev(self):
262-
mid = "tiiuae/falcon-mamba-tiny-dev"
263-
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
264-
self.assertEqual(data["task"], "text-generation")
265-
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
266-
model(**inputs)
267-
model(**data["inputs2"])
268-
self.assertIn((data["size"], data["n_weights"]), [(274958336, 68739584)])
269-
if not has_transformers("5.3.99"):
270-
raise unittest.SkipTest("The model has control flow.")
271-
with torch_export_patches(patch_transformers=True, verbose=10, stop_if_static=1):
272-
torch.export.export(
273-
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
274-
)
275-
276260

277261
if __name__ == "__main__":
278262
unittest.main(verbosity=2)

_unittests/ut_tasks/try_tasks.py

Lines changed: 0 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -683,80 +683,6 @@ def mean_pooling(model_output, attention_mask):
683683
print("Sentence embeddings:")
684684
print(sentence_embeddings)
685685

686-
@never_test()
687-
def test_falcon_mamba_dev(self):
688-
# clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k falcon_mamba_dev
689-
# https://huggingface.co/tiiuae/falcon-mamba-tiny-dev
690-
691-
from transformers import AutoTokenizer
692-
import transformers
693-
import torch
694-
695-
model = "tiiuae/falcon-mamba-tiny-dev"
696-
697-
tokenizer = AutoTokenizer.from_pretrained(model)
698-
pipeline = transformers.pipeline(
699-
"text-generation",
700-
model=model,
701-
tokenizer=tokenizer,
702-
dtype=torch.bfloat16,
703-
trust_remote_code=True,
704-
device_map="auto",
705-
)
706-
print()
707-
with steal_forward(pipeline.model):
708-
sequences = pipeline(
709-
"Girafatron is obsessed with giraffes, "
710-
"the most glorious animal on the face of this Earth. "
711-
"Giraftron believes all other animals are irrelevant "
712-
"when compared to the glorious majesty of the giraffe."
713-
"\nDaniel: Hello, Girafatron!\nGirafatron:",
714-
max_length=200,
715-
do_sample=True,
716-
top_k=10,
717-
num_return_sequences=1,
718-
eos_token_id=tokenizer.eos_token_id,
719-
)
720-
for seq in sequences:
721-
print(f"Result: {seq['generated_text']}")
722-
723-
@never_test()
724-
def test_falcon_mamba_7b(self):
725-
# clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k falcon_mamba_7b
726-
# https://huggingface.co/tiiuae/falcon-mamba-7b
727-
728-
from transformers import AutoTokenizer
729-
import transformers
730-
import torch
731-
732-
model = "tiiuae/falcon-mamba-7b"
733-
734-
tokenizer = AutoTokenizer.from_pretrained(model)
735-
pipeline = transformers.pipeline(
736-
"text-generation",
737-
model=model,
738-
tokenizer=tokenizer,
739-
dtype=torch.bfloat16,
740-
trust_remote_code=True,
741-
device_map="auto",
742-
)
743-
print()
744-
with steal_forward(pipeline.model):
745-
sequences = pipeline(
746-
"Girafatron is obsessed with giraffes, "
747-
"the most glorious animal on the face of this Earth. "
748-
"Giraftron believes all other animals are irrelevant "
749-
"when compared to the glorious majesty of the giraffe."
750-
"\nDaniel: Hello, Girafatron!\nGirafatron:",
751-
max_length=200,
752-
do_sample=True,
753-
top_k=10,
754-
num_return_sequences=1,
755-
eos_token_id=tokenizer.eos_token_id,
756-
)
757-
for seq in sequences:
758-
print(f"Result: {seq['generated_text']}")
759-
760686
@never_test()
761687
def test_object_detection(self):
762688
# clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k object_

_unittests/ut_torch_export_patches/test_onnx_export_errors.py

Lines changed: 1 addition & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -1,128 +1,8 @@
11
import unittest
2-
from onnx_diagnostic.ext_test_case import (
3-
ExtTestCase,
4-
requires_torch,
5-
requires_transformers,
6-
skipif_ci_windows,
7-
ignore_warnings,
8-
hide_stdout,
9-
)
10-
from onnx_diagnostic.helpers import string_type
11-
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
12-
torch_export_patches,
13-
)
2+
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
143

154

165
class TestOnnxExportErrors(ExtTestCase):
17-
@requires_transformers("4.49.999")
18-
@skipif_ci_windows("not working on Windows")
19-
@ignore_warnings(UserWarning)
20-
@hide_stdout()
21-
def test_pytree_flatten_mamba_cache(self):
22-
import torch
23-
import torch.utils._pytree as py_pytree
24-
25-
try:
26-
from transformers.models.mamba.modeling_mamba import MambaCache
27-
except ImportError:
28-
from transformers.cache_utils import MambaCache
29-
30-
class _config:
31-
def __init__(self):
32-
self.intermediate_size = 8
33-
self.state_size = 16
34-
self.conv_kernel = 32
35-
self.num_hidden_layers = 64
36-
self.dtype = torch.float16
37-
38-
cache = MambaCache(_config(), max_batch_size=1, device="cpu")
39-
40-
with torch_export_patches(verbose=1):
41-
values, spec = py_pytree.tree_flatten(cache)
42-
cache2 = py_pytree.tree_unflatten(values, spec)
43-
self.assertEqual(cache.max_batch_size, cache2.max_batch_size)
44-
self.assertEqual(cache.intermediate_size, cache2.intermediate_size)
45-
self.assertEqual(cache.ssm_state_size, cache2.ssm_state_size)
46-
self.assertEqual(cache.conv_kernel_size, cache2.conv_kernel_size)
47-
self.assertEqualArrayAny(cache.conv_states, cache2.conv_states)
48-
self.assertEqualArrayAny(cache.ssm_states, cache2.ssm_states)
49-
50-
@requires_transformers("4.50")
51-
@requires_torch("2.7")
52-
@skipif_ci_windows("not working on Windows")
53-
@ignore_warnings(UserWarning)
54-
@hide_stdout()
55-
def test_exportable_mamba_cache(self):
56-
import torch
57-
from transformers.models.mamba.modeling_mamba import MambaCache
58-
59-
class _config:
60-
def __init__(self):
61-
self.intermediate_size = 8
62-
self.state_size = 16
63-
self.conv_kernel = 32
64-
self.num_hidden_layers = 64
65-
self.dtype = torch.float16
66-
67-
class Model(torch.nn.Module):
68-
def forward(self, x: torch.Tensor, cache: MambaCache):
69-
x1 = cache.ssm_states[0] + x
70-
x2 = cache.conv_states[0][:, :, ::2] + x1
71-
return x2
72-
73-
cache = MambaCache(_config(), max_batch_size=1, device="cpu")
74-
# MambaCache was updated in 4.50
75-
self.assertEqual(
76-
"MambaCache(conv_states=#64[T10r3,...], ssm_states=#64[T10r3,...])",
77-
string_type(cache),
78-
)
79-
x = torch.ones(2, 8, 16).to(torch.float16)
80-
model = Model()
81-
model(x, cache)
82-
83-
with torch_export_patches(verbose=1, patch_transformers=True):
84-
cache = MambaCache(_config(), max_batch_size=1, device="cpu")
85-
torch.export.export(Model(), (x, cache))
86-
87-
@requires_transformers("4.49.999")
88-
@skipif_ci_windows("not working on Windows")
89-
@ignore_warnings(UserWarning)
90-
def test_exportable_mamba_cache_dynamic(self):
91-
import torch
92-
from transformers.models.mamba.modeling_mamba import MambaCache
93-
94-
class _config:
95-
def __init__(self):
96-
self.intermediate_size = 8
97-
self.state_size = 16
98-
self.conv_kernel = 32
99-
self.num_hidden_layers = 2
100-
self.dtype = torch.float16
101-
102-
class Model(torch.nn.Module):
103-
def forward(self, x: torch.Tensor, cache: MambaCache):
104-
x1 = cache.ssm_states[0] + x
105-
x2 = cache.conv_states[0][:, :, ::2] + x1
106-
return x2
107-
108-
cache = MambaCache(_config(), max_batch_size=1, device="cpu")
109-
self.assertEqual(
110-
string_type(cache),
111-
"MambaCache(conv_states=#2[T10r3,T10r3], ssm_states=#2[T10r3,T10r3])",
112-
)
113-
x = torch.ones(2, 8, 16).to(torch.float16)
114-
model = Model()
115-
model(x, cache)
116-
DYN = torch.export.Dim.DYNAMIC
117-
118-
with torch_export_patches(patch_transformers=True):
119-
cache = MambaCache(_config(), max_batch_size=2, device="cpu")
120-
torch.export.export(
121-
Model(),
122-
(x, cache),
123-
dynamic_shapes=({0: DYN}, [[{0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}]]),
124-
)
125-
1266
@ignore_warnings(UserWarning)
1277
def test_exportable_dynamic_shapes_constraints(self):
1288
import torch

onnx_diagnostic/export/shape_helper.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ def all_dynamic_shapes_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
4646
from onnx_diagnostic.helpers.cache_helper import (
4747
make_dynamic_cache,
4848
make_encoder_decoder_cache,
49-
make_mamba_cache,
5049
make_static_cache,
5150
)
5251
from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
@@ -84,13 +83,6 @@ def all_dynamic_shapes_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
8483
],
8584
max_cache_len=15,
8685
),
87-
make_mamba_cache(
88-
[
89-
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
90-
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
91-
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
92-
]
93-
),
9486
]
9587
9688
with torch_export_patches(patch_transformers=True):

0 commit comments

Comments
 (0)