
Commit 0d4442a

Merge pull request #97 from basf/AB_layer
adding AB layernorm and weight decay to Mamba
2 parents: 56801dd + 22dad68

3 files changed (75 additions, 9 deletions)

mambular/arch_utils/mamba_arch.py

Lines changed: 48 additions & 3 deletions
@@ -43,6 +43,9 @@ def __init__(
         activation=F.silu,
         bidirectional=False,
         use_learnable_interaction=False,
+        layer_norm_eps=1e-05,
+        AB_weight_decay=False,
+        AB_layer_norm=True,
     ):
         super().__init__()

@@ -66,6 +69,9 @@ def __init__(
                     activation,
                     bidirectional,
                     use_learnable_interaction,
+                    layer_norm_eps,
+                    AB_weight_decay,
+                    AB_layer_norm,
                 )
                 for _ in range(n_layers)
             ]
@@ -105,6 +111,9 @@ def __init__(
         activation=F.silu,
         bidirectional=False,
         use_learnable_interaction=False,
+        layer_norm_eps=1e-05,
+        AB_weight_decay=False,
+        AB_layer_norm=False,
     ):
         super().__init__()

@@ -149,8 +158,11 @@ def __init__(
             activation=activation,
             bidirectional=bidirectional,
             use_learnable_interaction=use_learnable_interaction,
+            layer_norm_eps=layer_norm_eps,
+            AB_weight_decay=AB_weight_decay,
+            AB_layer_norm=AB_layer_norm,
         )
-        self.norm = norm(d_model)
+        self.norm = norm(d_model, eps=layer_norm_eps)

     def forward(self, x):
         output = self.layers(self.norm(x)) + x
@@ -189,6 +201,9 @@ def __init__(
         activation=F.silu,
         bidirectional=False,
         use_learnable_interaction=False,
+        layer_norm_eps=1e-05,
+        AB_weight_decay=False,
+        AB_layer_norm=False,
     ):
         super().__init__()
         self.d_inner = d_model * expand_factor
@@ -239,6 +254,7 @@ def __init__(
         elif dt_init == "random":
             nn.init.uniform_(self.dt_proj_fwd.weight, -dt_init_std, dt_init_std)
             if self.bidirectional:
+
                 nn.init.uniform_(self.dt_proj_bwd.weight, -dt_init_std, dt_init_std)
         else:
             raise NotImplementedError
@@ -262,17 +278,35 @@ def __init__(

         A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
         self.A_log_fwd = nn.Parameter(torch.log(A))
+        self.D_fwd = nn.Parameter(torch.ones(self.d_inner))
+
         if self.bidirectional:
             self.A_log_bwd = nn.Parameter(torch.log(A))
+            self.D_bwd = nn.Parameter(torch.ones(self.d_inner))
+
+        if not AB_weight_decay:
+            self.A_log_fwd._no_weight_decay = True
+            self.D_fwd._no_weight_decay = True

-        self.D_fwd = nn.Parameter(torch.ones(self.d_inner))
         if self.bidirectional:
-            self.D_bwd = nn.Parameter(torch.ones(self.d_inner))
+
+            if not AB_weight_decay:
+                self.A_log_bwd._no_weight_decay = True
+                self.D_bwd._no_weight_decay = True

         self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias)
         self.dt_rank = dt_rank
         self.d_state = d_state

+        if AB_layer_norm:
+            self.dt_layernorm = RMSNorm(self.dt_rank, eps=layer_norm_eps)
+            self.B_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps)
+            self.C_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps)
+        else:
+            self.dt_layernorm = None
+            self.B_layernorm = None
+            self.C_layernorm = None
+
     def forward(self, x):
         _, L, _ = x.shape

@@ -316,6 +350,15 @@ def forward(self, x):

         return output

+    def _apply_layernorms(self, dt, B, C):
+        if self.dt_layernorm is not None:
+            dt = self.dt_layernorm(dt)
+        if self.B_layernorm is not None:
+            B = self.B_layernorm(B)
+        if self.C_layernorm is not None:
+            C = self.C_layernorm(C)
+        return dt, B, C
+
     def ssm(self, x, forward=True):
         if forward:
             A = -torch.exp(self.A_log_fwd.float())
@@ -324,6 +367,7 @@ def ssm(self, x, forward=True):
             delta, B, C = torch.split(
                 deltaBC, [self.dt_rank, self.d_state, self.d_state], dim=-1
             )
+            delta, B, C = self._apply_layernorms(delta, B, C)
             delta = F.softplus(self.dt_proj_fwd(delta))
         else:
             A = -torch.exp(self.A_log_bwd.float())
@@ -332,6 +376,7 @@ def ssm(self, x, forward=True):
             delta, B, C = torch.split(
                 deltaBC, [self.dt_rank, self.d_state, self.d_state], dim=-1
             )
+            delta, B, C = self._apply_layernorms(delta, B, C)
             delta = F.softplus(self.dt_proj_bwd(delta))

         y = self.selective_scan_seq(x, delta, A, B, C, D)
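
The AB_layer_norm option added above wraps dt, B, and C in RMSNorm before the selective scan. mambular imports its own RMSNorm in mamba_arch.py; the snippet below is only a minimal sketch (not the library's implementation) of what such a layer typically computes, with eps playing the role of layer_norm_eps:

    # Illustrative RMSNorm sketch; mambular's own RMSNorm may differ in details.
    import torch
    import torch.nn as nn

    class RMSNormSketch(nn.Module):
        def __init__(self, dim: int, eps: float = 1e-05):
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Rescale by the reciprocal root-mean-square over the last dimension.
            rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
            return self.weight * x * rms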

mambular/base_models/mambular.py

Lines changed: 20 additions & 6 deletions
@@ -109,19 +109,33 @@ def __init__(
             use_learnable_interaction=self.hparams.get(
                 "use_learnable_interactions", config.use_learnable_interaction
             ),
+            AB_weight_decay=self.hparams.get("AB_weight_decay", config.AB_weight_decay),
+            AB_layer_norm=self.hparams.get("AB_layer_norm", config.AB_layer_norm),
+            layer_norm_eps=self.hparams.get("layer_norm_eps", config.layer_norm_eps),
         )
-
         norm_layer = self.hparams.get("norm", config.norm)
         if norm_layer == "RMSNorm":
-            self.norm_f = RMSNorm(self.hparams.get("d_model", config.d_model))
+            self.norm_f = RMSNorm(
+                self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
+            )
         elif norm_layer == "LayerNorm":
-            self.norm_f = LayerNorm(self.hparams.get("d_model", config.d_model))
+            self.norm_f = LayerNorm(
+                self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
+            )
         elif norm_layer == "BatchNorm":
-            self.norm_f = BatchNorm(self.hparams.get("d_model", config.d_model))
+            self.norm_f = BatchNorm(
+                self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
+            )
         elif norm_layer == "InstanceNorm":
-            self.norm_f = InstanceNorm(self.hparams.get("d_model", config.d_model))
+            self.norm_f = InstanceNorm(
+                self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
+            )
         elif norm_layer == "GroupNorm":
-            self.norm_f = GroupNorm(1, self.hparams.get("d_model", config.d_model))
+            self.norm_f = GroupNorm(
+                1,
+                self.hparams.get("d_model", config.d_model),
+                eps=config.layer_norm_eps,
+            )
         elif norm_layer == "LearnableLayerScaling":
             self.norm_f = LearnableLayerScaling(
                 self.hparams.get("d_model", config.d_model)

mambular/configs/mambular_config.py

Lines changed: 7 additions & 0 deletions
@@ -73,6 +73,10 @@ class DefaultMambularConfig:
         Whether to append a cls to the end of each 'sequence'.
     shuffle_embeddings : bool, default=False.
         Whether to shuffle the embeddings before being passed to the Mamba layers.
+    layer_norm_eps : float, default=1e-05
+        Epsilon value for layer normalization.
+    AB_weight_decay : bool, default=False
+        Whether weight decay is also applied to the A and B matrices.
     """

     lr: float = 1e-04
@@ -107,3 +111,6 @@ class DefaultMambularConfig:
     use_learnable_interaction: bool = False
     use_cls: bool = False
     shuffle_embeddings: bool = False
+    layer_norm_eps: float = 1e-05
+    AB_weight_decay: bool = False
+    AB_layer_norm: bool = True
