
Commit 238dc2b

Merge branch 'main' of https://github.com/sdpython/onnx-diagnostic into mb4

2 parents: ca60329 + fa7591b

2 files changed: 9 additions & 7 deletions

onnx_diagnostic/tasks/text_generation.py (6 additions & 6 deletions)
@@ -13,7 +13,7 @@
 
 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
-    # FalconMambaConfig: use_mambapy
+    # Mamba models (e.g. FalconMambaConfig) use use_mambapy instead of num_attention_heads
     if hasattr(config, "text_config"):
         # The model is probably of mixture of models used only for text.
         config = config.text_config
@@ -25,7 +25,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
         "hidden_size",
         "vocab_size",
     )
-    if config.__class__.__name__ == "FalconMambaConfig":
+    if hasattr(config, "use_mambapy"):
         check_hasattr(config, "conv_kernel", "state_size", "intermediate_size") # 4 and 8
         kwargs = dict(
             num_hidden_layers=min(config.num_hidden_layers, nhl()),
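
Reviewer note: replacing the exact class-name test with a hasattr check is the core of this change. Any config exposing the use_mambapy attribute now takes the Mamba branch, not just FalconMambaConfig. A minimal sketch of the difference, using stand-in config classes (MambaConfig here is my illustration, not code from the diff):

class FalconMambaConfig:
    use_mambapy = False  # marker attribute shared by Mamba-style configs

class MambaConfig:  # hypothetical second Mamba variant
    use_mambapy = False

for cfg in (FalconMambaConfig(), MambaConfig()):
    by_name = cfg.__class__.__name__ == "FalconMambaConfig"  # old check
    by_attr = hasattr(cfg, "use_mambapy")                    # new check
    print(type(cfg).__name__, by_name, by_attr)
# FalconMambaConfig True True
# MambaConfig False True
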
@@ -54,7 +54,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     return kwargs
 
 
-def _get_input_falcon_mamba(
+def _get_input_mamba(
     model: torch.nn.Module,
     config: Optional[Any],
     dummy_max_token_id: int,
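
Reviewer note: the renamed helper now reads as covering Mamba-style models generally. Given the config fields checked above (conv_kernel, state_size, intermediate_size), its dummy inputs presumably include Mamba state tensors. A rough sketch under the usual MambaCache layout; the shapes and the helper name dummy_mamba_states are my assumptions, not code from this commit:

import torch

def dummy_mamba_states(config, batch_size: int = 2):
    # Assumed MambaCache layout (my assumption), per layer:
    #   conv_states: (batch, intermediate_size, conv_kernel)
    #   ssm_states:  (batch, intermediate_size, state_size)
    conv = torch.zeros(batch_size, config.intermediate_size, config.conv_kernel)
    ssm = torch.zeros(batch_size, config.intermediate_size, config.state_size)
    return [(conv.clone(), ssm.clone()) for _ in range(config.num_hidden_layers)]
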
@@ -157,8 +157,8 @@ def get_inputs(
     seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
     cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096)
 
-    if config is not None and config.__class__.__name__ == "FalconMambaConfig":
-        res = _get_input_falcon_mamba(
+    if config is not None and hasattr(config, "use_mambapy"):
+        res = _get_input_mamba(
             model=model,
             config=config,
             dummy_max_token_id=dummy_max_token_id,
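
Reviewer note: the strings "seq_length" and "cache_length" name dynamic dimensions, and the trailing comments show the explicit torch.export.Dim form they stand in for. A toy sketch of how such a Dim constrains an exported dimension (the Toy model is mine, not the repository's):

import torch

class Toy(torch.nn.Module):
    def forward(self, input_ids):
        return input_ids * 2

# Declare a dynamic sequence-length dimension with the same bounds as the comment.
seq_length = torch.export.Dim("seq_length", min=1, max=4096)
ep = torch.export.export(
    Toy(),
    (torch.arange(16).reshape(2, 8),),
    dynamic_shapes={"input_ids": {1: seq_length}},  # dim 1 is dynamic
)
print(ep)
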
@@ -343,7 +343,7 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         ("num_key_value_heads", "num_attention_heads", "use_mambapy"),
         "hidden_size",
     )
-    if config.__class__.__name__ == "FalconMambaConfig":
+    if hasattr(config, "use_mambapy"):
         check_hasattr(config, "conv_kernel", "state_size", "intermediate_size") # 4 and 8
         kwargs = dict(
             batch_size=2,
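
Reviewer note: the tuple ("num_key_value_heads", "num_attention_heads", "use_mambapy") passed to check_hasattr suggests the helper treats a tuple as "at least one of these attributes must exist", which is what lets an attention-free Mamba config pass. A guess at that contract, as a sketch; this is not the actual onnx-diagnostic implementation:

from typing import Any

def check_hasattr(config: Any, *names) -> None:
    # Assumed contract: a plain string must exist on the config;
    # a tuple means at least one of its entries must exist.
    for name in names:
        if isinstance(name, tuple):
            assert any(hasattr(config, n) for n in name), (
                f"none of {name} found on {type(config).__name__}"
            )
        else:
            assert hasattr(config, name), (
                f"{name!r} missing on {type(config).__name__}"
            )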

onnx_diagnostic/torch_models/validate.py (3 additions & 1 deletion)
@@ -2332,7 +2332,9 @@ def call_torch_export_custom(
         "custom-fake",
         "custom-tracing",
     }
-    assert exporter in available, f"Unexpected value for exporter={exporter!r} in {available}"
+    assert (
+        exporter in available
+    ), f"Unexpected value for exporter={exporter!r} in {sorted(available)}" # type: ignore[type-var]
     assert "model" in data, f"model is missing from data: {sorted(data)}"
     assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}"
     assert ("-strict" not in exporter) or ("strict" not in exporter_options), (
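
Reviewer note: beyond the black-style reflow of the assert, the message now formats sorted(available) rather than the raw set, so the error text is stable across runs (Python randomizes string hashes per process, which can reorder a set's repr); the # type: ignore[type-var] presumably quiets mypy about sorting the set's elements. A tiny illustration, with the set abbreviated to the two entries visible in the diff:

available = {"custom-fake", "custom-tracing"}  # abbreviated
# f"{available}"          -> brace repr whose element order can vary across runs
# f"{sorted(available)}"  -> always "['custom-fake', 'custom-tracing']"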
