
 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
-    # FalconMambaConfig: use_mambapy
+    # Mamba models (e.g. FalconMambaConfig) use use_mambapy instead of num_attention_heads
     if hasattr(config, "text_config"):
         # The model is probably of mixture of models used only for text.
         config = config.text_config
@@ -25,7 +25,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
         "hidden_size",
         "vocab_size",
     )
-    if config.__class__.__name__ == "FalconMambaConfig":
+    if hasattr(config, "use_mambapy"):
         check_hasattr(config, "conv_kernel", "state_size", "intermediate_size")  # 4 and 8
         kwargs = dict(
             num_hidden_layers=min(config.num_hidden_layers, nhl()),
@@ -54,7 +54,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     return kwargs


-def _get_input_falcon_mamba(
+def _get_input_mamba(
     model: torch.nn.Module,
     config: Optional[Any],
     dummy_max_token_id: int,
@@ -157,8 +157,8 @@ def get_inputs(
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
     cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)

-    if config is not None and config.__class__.__name__ == "FalconMambaConfig":
-        res = _get_input_falcon_mamba(
+    if config is not None and hasattr(config, "use_mambapy"):
+        res = _get_input_mamba(
             model=model,
             config=config,
             dummy_max_token_id=dummy_max_token_id,
@@ -343,7 +343,7 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         ("num_key_value_heads", "num_attention_heads", "use_mambapy"),
         "hidden_size",
     )
-    if config.__class__.__name__ == "FalconMambaConfig":
+    if hasattr(config, "use_mambapy"):
         check_hasattr(config, "conv_kernel", "state_size", "intermediate_size")  # 4 and 8
         kwargs = dict(
             batch_size=2,
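
Every hunk above swaps a check on the config class name for a capability check on the config object itself, so any Mamba-style config that exposes use_mambapy is covered, not only FalconMambaConfig. A minimal sketch of that duck-typing pattern, assuming a transformers-style config object; the Dummy* classes below are hypothetical stand-ins for illustration only, not classes from this repository:

from typing import Any

def is_mamba_config(config: Any) -> bool:
    # Mamba-style configs expose use_mambapy instead of num_attention_heads,
    # so one attribute check covers every Mamba variant without listing class names.
    return hasattr(config, "use_mambapy")

class DummyMambaConfig:  # hypothetical stand-in for a Mamba config such as FalconMambaConfig
    use_mambapy = False
    conv_kernel = 4
    state_size = 8

class DummyAttentionConfig:  # hypothetical stand-in for an attention-based config
    num_attention_heads = 32

assert is_mamba_config(DummyMambaConfig())
assert not is_mamba_config(DummyAttentionConfig())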