
Commit 7813141

Merge pull request #47 from basf/restructure
Restructure
2 parents aa6b3be + 893b69c, commit 7813141

9 files changed: 186 additions & 52 deletions


mambular/arch_utils/mamba_arch.py

Lines changed: 122 additions & 36 deletions

@@ -41,6 +41,8 @@ def __init__(
         dt_init_floor=1e-04,
         norm=RMSNorm,
         activation=F.silu,
+        bidirectional=False,
+        use_learnable_interaction=False,
     ):
         super().__init__()

@@ -62,6 +64,8 @@ def __init__(
                     dt_init_floor,
                     norm,
                     activation,
+                    bidirectional,
+                    use_learnable_interaction,
                 )
                 for _ in range(n_layers)
             ]

@@ -99,6 +103,8 @@ def __init__(
         dt_init_floor=1e-04,
         norm=RMSNorm,
         activation=F.silu,
+        bidirectional=False,
+        use_learnable_interaction=False,
     ):
         super().__init__()

@@ -141,6 +147,8 @@ def __init__(
             dt_min=dt_min,
             dt_init_floor=dt_init_floor,
             activation=activation,
+            bidirectional=bidirectional,
+            use_learnable_interaction=use_learnable_interaction,
         )
         self.norm = norm(d_model)

@@ -153,14 +161,14 @@ class MambaBlock(nn.Module):
     """MambaBlock module containing the main computational components.

     Attributes:
-        config (MambularConfig): Configuration object for the MambaBlock.
         in_proj (nn.Linear): Linear projection for input.
         conv1d (nn.Conv1d): 1D convolutional layer.
         x_proj (nn.Linear): Linear projection for input-dependent tensors.
         dt_proj (nn.Linear): Linear projection for dynamical time.
         A_log (nn.Parameter): Logarithmically stored A tensor.
         D (nn.Parameter): Tensor for D component.
         out_proj (nn.Linear): Linear projection for output.
+        learnable_interaction (LearnableFeatureInteraction): Learnable feature interaction layer.
     """

     def __init__(

@@ -179,88 +187,154 @@ def __init__(
         dt_min=1e-03,
         dt_init_floor=1e-04,
         activation=F.silu,
+        bidirectional=False,
+        use_learnable_interaction=False,
     ):
         super().__init__()
         self.d_inner = d_model * expand_factor
+        self.bidirectional = bidirectional
+        self.use_learnable_interaction = use_learnable_interaction

-        self.in_proj = nn.Linear(d_model, 2 * self.d_inner, bias=bias)
+        self.in_proj_fwd = nn.Linear(d_model, 2 * self.d_inner, bias=bias)
+        if self.bidirectional:
+            self.in_proj_bwd = nn.Linear(d_model, 2 * self.d_inner, bias=bias)

-        self.conv1d = nn.Conv1d(
+        self.conv1d_fwd = nn.Conv1d(
             in_channels=self.d_inner,
             out_channels=self.d_inner,
             kernel_size=d_conv,
             bias=conv_bias,
             groups=self.d_inner,
             padding=d_conv - 1,
         )
+        if self.bidirectional:
+            self.conv1d_bwd = nn.Conv1d(
+                in_channels=self.d_inner,
+                out_channels=self.d_inner,
+                kernel_size=d_conv,
+                bias=conv_bias,
+                groups=self.d_inner,
+                padding=d_conv - 1,
+            )

         self.dropout = nn.Dropout(dropout)
         self.activation = activation

-        self.x_proj = nn.Linear(self.d_inner, dt_rank + 2 * d_state, bias=False)
+        if self.use_learnable_interaction:
+            self.learnable_interaction = LearnableFeatureInteraction(self.d_inner)
+
+        self.x_proj_fwd = nn.Linear(self.d_inner, dt_rank + 2 * d_state, bias=False)
+        if self.bidirectional:
+            self.x_proj_bwd = nn.Linear(self.d_inner, dt_rank + 2 * d_state, bias=False)

-        self.dt_proj = nn.Linear(dt_rank, self.d_inner, bias=True)
+        self.dt_proj_fwd = nn.Linear(dt_rank, self.d_inner, bias=True)
+        if self.bidirectional:
+            self.dt_proj_bwd = nn.Linear(dt_rank, self.d_inner, bias=True)

         dt_init_std = dt_rank**-0.5 * dt_scale
         if dt_init == "constant":
-            nn.init.constant_(self.dt_proj.weight, dt_init_std)
+            nn.init.constant_(self.dt_proj_fwd.weight, dt_init_std)
+            if self.bidirectional:
+                nn.init.constant_(self.dt_proj_bwd.weight, dt_init_std)
         elif dt_init == "random":
-            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
+            nn.init.uniform_(self.dt_proj_fwd.weight, -dt_init_std, dt_init_std)
+            if self.bidirectional:
+                nn.init.uniform_(self.dt_proj_bwd.weight, -dt_init_std, dt_init_std)
         else:
             raise NotImplementedError

-        dt = torch.exp(
+        dt_fwd = torch.exp(
             torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min))
             + math.log(dt_min)
         ).clamp(min=dt_init_floor)
-        inv_dt = dt + torch.log(-torch.expm1(-dt))
+        inv_dt_fwd = dt_fwd + torch.log(-torch.expm1(-dt_fwd))
         with torch.no_grad():
-            self.dt_proj.bias.copy_(inv_dt)
+            self.dt_proj_fwd.bias.copy_(inv_dt_fwd)
+
+        if self.bidirectional:
+            dt_bwd = torch.exp(
+                torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min))
+                + math.log(dt_min)
+            ).clamp(min=dt_init_floor)
+            inv_dt_bwd = dt_bwd + torch.log(-torch.expm1(-dt_bwd))
+            with torch.no_grad():
+                self.dt_proj_bwd.bias.copy_(inv_dt_bwd)

         A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
-        self.A_log = nn.Parameter(torch.log(A))
-        self.D = nn.Parameter(torch.ones(self.d_inner))
+        self.A_log_fwd = nn.Parameter(torch.log(A))
+        if self.bidirectional:
+            self.A_log_bwd = nn.Parameter(torch.log(A))
+
+        self.D_fwd = nn.Parameter(torch.ones(self.d_inner))
+        if self.bidirectional:
+            self.D_bwd = nn.Parameter(torch.ones(self.d_inner))
+
         self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias)
         self.dt_rank = dt_rank
         self.d_state = d_state

     def forward(self, x):
         _, L, _ = x.shape

-        xz = self.in_proj(x)
-        x, z = xz.chunk(2, dim=-1)
+        xz_fwd = self.in_proj_fwd(x)
+        x_fwd, z_fwd = xz_fwd.chunk(2, dim=-1)

-        x = x.transpose(1, 2)
-        x = self.conv1d(x)[:, :, :L]
-        x = x.transpose(1, 2)
+        x_fwd = x_fwd.transpose(1, 2)
+        x_fwd = self.conv1d_fwd(x_fwd)[:, :, :L]
+        x_fwd = x_fwd.transpose(1, 2)

-        x = self.activation(x)
-        x = self.dropout(x)
-        y = self.ssm(x)
+        if self.bidirectional:
+            xz_bwd = self.in_proj_bwd(x)
+            x_bwd, z_bwd = xz_bwd.chunk(2, dim=-1)

-        z = self.activation(z)
-        z = self.dropout(z)
+            x_bwd = x_bwd.transpose(1, 2)
+            x_bwd = self.conv1d_bwd(x_bwd)[:, :, :L]
+            x_bwd = x_bwd.transpose(1, 2)

-        output = y * z
-        output = self.out_proj(output)
+        if self.use_learnable_interaction:
+            x_fwd = self.learnable_interaction(x_fwd)
+            if self.bidirectional:
+                x_bwd = self.learnable_interaction(x_bwd)

-        return output
+        x_fwd = self.activation(x_fwd)
+        x_fwd = self.dropout(x_fwd)
+        y_fwd = self.ssm(x_fwd, forward=True)

-    def ssm(self, x):
-        A = -torch.exp(self.A_log.float())
-        D = self.D.float()
+        if self.bidirectional:
+            x_bwd = self.activation(x_bwd)
+            x_bwd = self.dropout(x_bwd)
+            y_bwd = self.ssm(torch.flip(x_bwd, [1]), forward=False)
+            y = y_fwd + torch.flip(y_bwd, [1])
+        else:
+            y = y_fwd

-        deltaBC = self.x_proj(x)
+        z_fwd = self.activation(z_fwd)
+        z_fwd = self.dropout(z_fwd)

-        delta, B, C = torch.split(
-            deltaBC,
-            [self.dt_rank, self.d_state, self.d_state],
-            dim=-1,
-        )
-        delta = F.softplus(self.dt_proj(delta))
+        output = y * z_fwd
+        output = self.out_proj(output)

-        y = self.selective_scan_seq(x, delta, A, B, C, D)
+        return output
+
+    def ssm(self, x, forward=True):
+        if forward:
+            A = -torch.exp(self.A_log_fwd.float())
+            D = self.D_fwd.float()
+            deltaBC = self.x_proj_fwd(x)
+            delta, B, C = torch.split(
+                deltaBC, [self.dt_rank, self.d_state, self.d_state], dim=-1
+            )
+            delta = F.softplus(self.dt_proj_fwd(delta))
+        else:
+            A = -torch.exp(self.A_log_bwd.float())
+            D = self.D_bwd.float()
+            deltaBC = self.x_proj_bwd(x)
+            delta, B, C = torch.split(
+                deltaBC, [self.dt_rank, self.d_state, self.d_state], dim=-1
+            )
+            delta = F.softplus(self.dt_proj_bwd(delta))

+        y = self.selective_scan_seq(x, delta, A, B, C, D)
         return y

     def selective_scan_seq(self, x, delta, A, B, C, D):
@@ -285,3 +359,15 @@ def selective_scan_seq(self, x, delta, A, B, C, D):
         y = y + D * x

         return y
+
+
+class LearnableFeatureInteraction(nn.Module):
+    def __init__(self, n_vars):
+        super().__init__()
+        self.interaction_weights = nn.Parameter(torch.Tensor(n_vars, n_vars))
+        nn.init.xavier_uniform_(self.interaction_weights)
+
+    def forward(self, x):
+        batch_size, n_vars, d_model = x.size()
+        interactions = torch.matmul(x, self.interaction_weights)
+        return interactions.view(batch_size, n_vars, d_model)
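Inside MambaBlock the new LearnableFeatureInteraction is constructed with n_vars = d_inner, so it learns a (d_inner x d_inner) matrix that mixes the channels of every position before activation and the scan. A quick shape check, assuming the class added in this diff is importable from the module path shown above:

import torch
from mambular.arch_utils.mamba_arch import LearnableFeatureInteraction  # assumed import path

layer = LearnableFeatureInteraction(n_vars=16)  # in MambaBlock, n_vars is d_inner
x = torch.randn(4, 10, 16)                      # (batch, seq_len, d_inner)
out = layer(x)                                  # x @ interaction_weights mixes the last dimension
print(out.shape)                                # torch.Size([4, 10, 16])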

mambular/base_models/mambular.py

Lines changed: 4 additions & 0 deletions

@@ -105,6 +105,10 @@ def __init__(
             dt_init_floor=self.hparams.get("dt_init_floor", config.dt_init_floor),
             norm=globals()[self.hparams.get("norm", config.norm)],
             activation=self.hparams.get("activation", config.activation),
+            bidirectional=self.hparams.get("bidiretional", config.bidirectional),
+            use_learnable_interaction=self.hparams.get(
+                "use_learnable_interactions", config.use_learnable_interaction
+            ),
         )

         norm_layer = self.hparams.get("norm", config.norm)
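As elsewhere in this constructor, each new argument is looked up in self.hparams first and falls back to the config default when the key is absent. A small illustration of that resolution pattern (a plain dict stands in for self.hparams; names are illustrative only):

# Illustrative only: the self.hparams.get(key, default) pattern used above.
hparams = {"bidirectional": True}   # value supplied by the user
config_default = False              # DefaultMambularConfig.bidirectional
resolved = hparams.get("bidirectional", config_default)
print(resolved)  # True; with no user value, the config default would be used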

mambular/configs/fttransformer_config.py

Lines changed: 1 addition & 1 deletion

@@ -27,4 +27,4 @@ class DefaultFTTransformerConfig:
     bias: bool = True
     transformer_activation: callable = nn.SELU()
     layer_norm_eps: float = 1e-05
-    transformer_dim_feedforward: int = 2048
+    transformer_dim_feedforward: int = 512

mambular/configs/mambular_config.py

Lines changed: 2 additions & 0 deletions

@@ -32,3 +32,5 @@ class DefaultMambularConfig:
     head_use_batch_norm: bool = False
     layer_norm_after_embedding: bool = False
     pooling_method: str = "avg"
+    bidirectional: bool = False
+    use_learnable_interaction: bool = False
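With the two new fields in place, bidirectional processing and the learnable interaction layer can be switched on through the Mambular config. A minimal sketch, assuming DefaultMambularConfig is a dataclass importable from this module:

from mambular.configs.mambular_config import DefaultMambularConfig  # assumed import path

# Override only the new fields; everything else keeps its default.
cfg = DefaultMambularConfig(bidirectional=True, use_learnable_interaction=True)
print(cfg.bidirectional, cfg.use_learnable_interaction)  # True True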

mambular/configs/tabtransformer_config.py

Lines changed: 1 addition & 1 deletion

@@ -27,4 +27,4 @@ class DefaultTabTransformerConfig:
     bias: bool = True
     transformer_activation: callable = nn.SELU()
     layer_norm_eps: float = 1e-05
-    transformer_dim_feedforward: int = 2048
+    transformer_dim_feedforward: int = 512
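Both transformer configs now default transformer_dim_feedforward to 512 instead of 2048. For a standard two-linear-layer feed-forward block (d_model -> dim_feedforward -> d_model) the weight count is roughly 2 * d_model * dim_feedforward, so the new default cuts per-layer FFN weights by about 4x. A back-of-the-envelope check with an assumed d_model of 128 (the actual d_model is not part of this diff):

# Rough FFN weight count per transformer layer, ignoring biases.
d_model = 128               # illustrative embedding size
old = 2 * d_model * 2048    # 524288 weights
new = 2 * d_model * 512     # 131072 weights
print(old, new, old / new)  # 524288 131072 4.0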

mambular/models/sklearn_base_classifier.py

Lines changed: 2 additions & 0 deletions

@@ -21,6 +21,8 @@ def __init__(self, model, config, **kwargs):
             "task",
             "cat_cutoff",
             "treat_all_integers_as_numerical",
+            "knots",
+            "degree",
         ]

         self.config_kwargs = {

mambular/models/sklearn_base_lss.py

Lines changed: 2 additions & 0 deletions

@@ -43,6 +43,8 @@ def __init__(self, model, config, **kwargs):
             "task",
             "cat_cutoff",
             "treat_all_integers_as_numerical",
+            "knots",
+            "degree",
         ]

         self.config_kwargs = {

mambular/models/sklearn_base_regressor.py

Lines changed: 2 additions & 0 deletions

@@ -20,6 +20,8 @@ def __init__(self, model, config, **kwargs):
             "task",
             "cat_cutoff",
             "treat_all_integers_as_numerical",
+            "knots",
+            "degree",
         ]

         self.config_kwargs = {
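All three sklearn-style wrappers (classifier, LSS, regressor) now recognize knots and degree as preprocessor arguments rather than routing them into config_kwargs. A hedged usage sketch; MambularRegressor is assumed to be the public wrapper built on sklearn_base_regressor, the exact class name is not part of this diff:

from mambular.models import MambularRegressor  # assumed public class name

# knots and degree are now picked up as preprocessor arguments;
# any remaining keyword arguments still land in config_kwargs.
reg = MambularRegressor(knots=20, degree=3)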
