Skip to content

Commit bd998d3

Browse files
committed
include params relöated to [BUG] Missing Configuration Attributes in DefaultMambularConfig #209
1 parent 161f6de commit bd998d3

1 file changed

Lines changed: 9 additions & 0 deletions

File tree

mambular/configs/mambular_config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@ class DefaultMambularConfig(BaseConfig):
6464
Whether to use PSCAN for the state-space model.
6565
mamba_version : str, default="mamba-torch"
6666
Version of the Mamba model to use ('mamba-torch', 'mamba1', 'mamba2').
67+
conv_bias : bool, default=False
68+
Whether to use a bias in the 1D convolution before each mamba block
69+
AD_weight_decay: bool = True
70+
Whether to use weight decay als for the A and D matrices in Mamba
71+
BC_layer_norm: bool = False
72+
Whether to use layer norm on the B and C matrices
6773
"""
6874

6975
# Architecture Parameters
@@ -82,6 +88,9 @@ class DefaultMambularConfig(BaseConfig):
8288
dt_init_floor: float = 1e-04
8389
norm: str = "RMSNorm"
8490
activation: Callable = nn.SiLU() # noqa: RUF009
91+
conv_bias: bool = False
92+
AD_weight_decay: bool = True
93+
BC_layer_norm: bool = False
8594

8695
# Embedding Parameters
8796
shuffle_embeddings: bool = False

0 commit comments

Comments
 (0)