Skip to content

Commit 8933cf7

Browse files
Merge pull request #98 from basf/AB_layer (authored)
Commit message: adjust names of matrices
2 parents 0d4442a + 7c2d343 commit 8933cf7

3 files changed

Lines changed: 21 additions & 19 deletions

File tree

mambular/arch_utils/mamba_arch.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ def __init__(
4444
bidirectional=False,
4545
use_learnable_interaction=False,
4646
layer_norm_eps=1e-05,
47-
AB_weight_decay=False,
48-
AB_layer_norm=True,
47+
AD_weight_decay=False,
48+
BC_layer_norm=True,
4949
):
5050
super().__init__()
5151

@@ -70,8 +70,8 @@ def __init__(
7070
bidirectional,
7171
use_learnable_interaction,
7272
layer_norm_eps,
73-
AB_weight_decay,
74-
AB_layer_norm,
73+
AD_weight_decay,
74+
BC_layer_norm,
7575
)
7676
for _ in range(n_layers)
7777
]
@@ -112,8 +112,8 @@ def __init__(
112112
bidirectional=False,
113113
use_learnable_interaction=False,
114114
layer_norm_eps=1e-05,
115-
AB_weight_decay=False,
116-
AB_layer_norm=False,
115+
AD_weight_decay=False,
116+
BC_layer_norm=False,
117117
):
118118
super().__init__()
119119

@@ -159,8 +159,8 @@ def __init__(
159159
bidirectional=bidirectional,
160160
use_learnable_interaction=use_learnable_interaction,
161161
layer_norm_eps=layer_norm_eps,
162-
AB_weight_decay=AB_weight_decay,
163-
AB_layer_norm=AB_layer_norm,
162+
AD_weight_decay=AD_weight_decay,
163+
BC_layer_norm=BC_layer_norm,
164164
)
165165
self.norm = norm(d_model, eps=layer_norm_eps)
166166

@@ -202,8 +202,8 @@ def __init__(
202202
bidirectional=False,
203203
use_learnable_interaction=False,
204204
layer_norm_eps=1e-05,
205-
AB_weight_decay=False,
206-
AB_layer_norm=False,
205+
AD_weight_decay=False,
206+
BC_layer_norm=False,
207207
):
208208
super().__init__()
209209
self.d_inner = d_model * expand_factor
@@ -284,21 +284,21 @@ def __init__(
284284
self.A_log_bwd = nn.Parameter(torch.log(A))
285285
self.D_bwd = nn.Parameter(torch.ones(self.d_inner))
286286

287-
if not AB_weight_decay:
287+
if not AD_weight_decay:
288288
self.A_log_fwd._no_weight_decay = True
289289
self.D_fwd._no_weight_decay = True
290290

291291
if self.bidirectional:
292292

293-
if not AB_weight_decay:
293+
if not AD_weight_decay:
294294
self.A_log_bwd._no_weight_decay = True
295295
self.D_bwd._no_weight_decay = True
296296

297297
self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias)
298298
self.dt_rank = dt_rank
299299
self.d_state = d_state
300300

301-
if AB_layer_norm:
301+
if BC_layer_norm:
302302
self.dt_layernorm = RMSNorm(self.dt_rank, eps=layer_norm_eps)
303303
self.B_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps)
304304
self.C_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps)

mambular/base_models/mambular.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ def __init__(
109109
use_learnable_interaction=self.hparams.get(
110110
"use_learnable_interactions", config.use_learnable_interaction
111111
),
112-
AB_weight_decay=self.hparams.get("AB_weight_decay", config.AB_weight_decay),
113-
AB_layer_norm=self.hparams.get("AB_layer_norm", config.AB_layer_norm),
112+
AD_weight_decay=self.hparams.get("AB_weight_decay", config.AD_weight_decay),
113+
BC_layer_norm=self.hparams.get("AB_layer_norm", config.BC_layer_norm),
114114
layer_norm_eps=self.hparams.get("layer_norm_eps", config.layer_norm_eps),
115115
)
116116
norm_layer = self.hparams.get("norm", config.norm)

mambular/configs/mambular_config.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,10 @@ class DefaultMambularConfig:
7575
Whether to shuffle the embeddings before being passed to the Mamba layers.
7676
layer_norm_eps : float, default=1e-05
7777
Epsilon value for layer normalization.
78-
AB_weight_decay : bool, default=False
79-
wether weight decay is also applied to A-B matrices
78+
AD_weight_decay : bool, default=False
79+
whether weight decay is also applied to A-D matrices
80+
BC_layer_norm: bool, default=True
81+
whether to apply layer normalization to B-C matrices
8082
"""
8183

8284
lr: float = 1e-04
@@ -112,5 +114,5 @@ class DefaultMambularConfig:
112114
use_cls: bool = False
113115
shuffle_embeddings: bool = False
114116
layer_norm_eps: float = 1e-05
115-
AB_weight_decay: bool = False
116-
AB_layer_norm: bool = True
117+
AD_weight_decay: bool = False
118+
BC_layer_norm: bool = True

0 commit comments

Comments (0)