
Commit e16ba03

Merge pull request #75 from basf/models
Models
2 parents 302e739 + 5f0608c commit e16ba03

10 files changed: 133 additions & 42 deletions

mambular/arch_utils/transformer_utils.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def reglu(x):
+    a, b = x.chunk(2, dim=-1)
+    return a * F.relu(b)
+
+
+class ReGLU(nn.Module):
+    def forward(self, x):
+        return reglu(x)
+
+
+class GLU(nn.Module):
+    def __init__(self):
+        super(GLU, self).__init__()
+
+    def forward(self, x):
+        assert x.size(-1) % 2 == 0, "Input dimension must be even"
+        split_dim = x.size(-1) // 2
+        return x[..., :split_dim] * torch.sigmoid(x[..., split_dim:])
+
+
+class CustomTransformerEncoderLayer(nn.TransformerEncoderLayer):
+    def __init__(self, *args, activation=F.relu, **kwargs):
+        super(CustomTransformerEncoderLayer, self).__init__(
+            *args, activation=activation, **kwargs
+        )
+        self.custom_activation = activation
+
+        # Check if the activation function is an instance of a GLU variant
+        if activation in [ReGLU, GLU] or isinstance(activation, (ReGLU, GLU)):
+            self.linear1 = nn.Linear(
+                self.linear1.in_features,
+                self.linear1.out_features * 2,
+                bias=kwargs.get("bias", True),
+            )
+            self.linear2 = nn.Linear(
+                self.linear2.in_features,
+                self.linear2.out_features,
+                bias=kwargs.get("bias", True),
+            )
+
+    def forward(self, src, src_mask=None, src_key_padding_mask=None, is_causal=False):
+        src2 = self.self_attn(
+            src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+        )[0]
+        src = src + self.dropout1(src2)
+        src = self.norm1(src)
+
+        # Use the provided activation function
+        if self.custom_activation in [ReGLU, GLU] or isinstance(
+            self.custom_activation, (ReGLU, GLU)
+        ):
+            src2 = self.linear2(self.custom_activation(self.linear1(src)))
+        else:
+            src2 = self.linear2(self.custom_activation(self.linear1(src)))
+
+        src = src + self.dropout2(src2)
+        src = self.norm2(src)
+        return src
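
For context, a minimal usage sketch of the new layer (not part of the commit; the import path is assumed from the imports in the files below, and the shapes are illustrative). With a GLU-style activation, linear1 is rebuilt with twice the feed-forward width, and reglu halves the last dimension again, so the block still returns d_model features per position.

# Minimal usage sketch (not part of the commit); import path assumed.
import torch
from mambular.arch_utils.transformer_utils import ReGLU, CustomTransformerEncoderLayer

layer = CustomTransformerEncoderLayer(
    d_model=128,
    nhead=8,
    dim_feedforward=256,
    batch_first=True,
    activation=ReGLU(),
)

x = torch.randn(4, 10, 128)   # (batch, sequence, d_model)
print(layer(x).shape)         # torch.Size([4, 10, 128])

# ReGLU halves the last dimension: linear1 now emits 2 * dim_feedforward
# features, which chunk() splits into a value half and a gated half.
print(ReGLU()(torch.randn(4, 10, 512)).shape)  # torch.Size([4, 10, 256])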

mambular/base_models/ft_transformer.py

Lines changed: 2 additions & 1 deletion
@@ -9,6 +9,7 @@
     InstanceNorm,
     GroupNorm,
 )
+from ..arch_utils.transformer_utils import CustomTransformerEncoderLayer
 from ..configs.fttransformer_config import DefaultFTTransformerConfig
 from .basemodel import BaseModel

@@ -87,7 +88,7 @@ def __init__(
             "num_embedding_activation", config.num_embedding_activation
         )

-        encoder_layer = nn.TransformerEncoderLayer(
+        encoder_layer = CustomTransformerEncoderLayer(
             d_model=self.hparams.get("d_model", config.d_model),
             nhead=self.hparams.get("n_heads", config.n_heads),
             batch_first=True,

mambular/base_models/mambular.py

Lines changed: 27 additions & 8 deletions
@@ -174,6 +174,11 @@ def __init__(
             torch.zeros(1, 1, self.hparams.get("d_model", config.d_model))
         )

+        if self.pooling_method == "cls":
+            self.use_cls = True
+        else:
+            self.use_cls = self.hparams.get("use_cls", config.use_cls)
+
         if self.hparams.get("layer_norm_after_embedding"):
             self.embedding_norm = nn.LayerNorm(
                 self.hparams.get("d_model", config.d_model)
@@ -198,10 +203,13 @@ def forward(self, num_features, cat_features):
         Tensor
             The output predictions of the model.
         """
-        batch_size = (
-            cat_features[0].size(0) if cat_features != [] else num_features[0].size(0)
-        )
-        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+        if self.use_cls:
+            batch_size = (
+                cat_features[0].size(0)
+                if cat_features != []
+                else num_features[0].size(0)
+            )
+            cls_tokens = self.cls_token.expand(batch_size, -1, -1)

         if len(self.cat_embeddings) > 0 and cat_features:
             cat_embeddings = [
@@ -225,11 +233,20 @@ def forward(self, num_features, cat_features):
             num_embeddings = None

         if cat_embeddings is not None and num_embeddings is not None:
-            x = torch.cat([cls_tokens, cat_embeddings, num_embeddings], dim=1)
+            if self.use_cls:
+                x = torch.cat([cat_embeddings, num_embeddings, cls_tokens], dim=1)
+            else:
+                x = torch.cat([cat_embeddings, num_embeddings], dim=1)
         elif cat_embeddings is not None:
-            x = torch.cat([cls_tokens, cat_embeddings], dim=1)
+            if self.use_cls:
+                x = torch.cat([cat_embeddings, cls_tokens], dim=1)
+            else:
+                x = cat_embeddings
         elif num_embeddings is not None:
-            x = torch.cat([cls_tokens, num_embeddings], dim=1)
+            if self.use_cls:
+                x = torch.cat([num_embeddings, cls_tokens], dim=1)
+            else:
+                x = num_embeddings
         else:
             raise ValueError("No features provided to the model.")

@@ -242,7 +259,9 @@ def forward(self, num_features, cat_features):
         elif self.pooling_method == "sum":
             x = torch.sum(x, dim=1)
         elif self.pooling_method == "cls_token":
-            x = x[:, 0]
+            x = x[:, -1]
+        elif self.pooling_method == "last":
+            x = x[:, -1]
         else:
             raise ValueError(f"Invalid pooling method: {self.pooling_method}")
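
The behavioural change: when a cls token is used, it is now appended after the feature embeddings rather than prepended, so cls-based pooling reads the last sequence position. A standalone sketch of that layout (illustrative tensors only, not the model's API):

# Standalone sketch of the new cls placement (illustrative only).
import torch

batch_size, n_features, d_model = 4, 6, 64
embeddings = torch.randn(batch_size, n_features, d_model)
cls_token = torch.zeros(1, 1, d_model)  # a learnable nn.Parameter in the model

# cls token appended at the end of the feature "sequence"
x = torch.cat([embeddings, cls_token.expand(batch_size, -1, -1)], dim=1)

pooled = x[:, -1]  # what pooling_method="cls_token" (or "last") now selects
print(x.shape, pooled.shape)  # torch.Size([4, 7, 64]) torch.Size([4, 64])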

mambular/base_models/tabtransformer.py

Lines changed: 2 additions & 1 deletion
@@ -11,6 +11,7 @@
 )
 from ..configs.tabtransformer_config import DefaultTabTransformerConfig
 from .basemodel import BaseModel
+from ..arch_utils.transformer_utils import CustomTransformerEncoderLayer


 class TabTransformer(BaseModel):
@@ -91,7 +92,7 @@ def __init__(
             "num_embedding_activation", config.num_embedding_activation
         )

-        encoder_layer = nn.TransformerEncoderLayer(
+        encoder_layer = CustomTransformerEncoderLayer(
             d_model=self.hparams.get("d_model", config.d_model),
             nhead=self.hparams.get("n_heads", config.n_heads),
             batch_first=True,

mambular/configs/fttransformer_config.py

Lines changed: 11 additions & 9 deletions
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
 import torch.nn as nn
+from ..arch_utils.transformer_utils import ReGLU


 @dataclass
@@ -63,15 +64,15 @@ class DefaultFTTransformerConfig:
     lr_patience: int = 10
     weight_decay: float = 1e-06
     lr_factor: float = 0.1
-    d_model: int = 64
-    n_layers: int = 8
-    n_heads: int = 4
-    attn_dropout: float = 0.3
-    ff_dropout: float = 0.3
-    norm: str = "RMSNorm"
+    d_model: int = 128
+    n_layers: int = 4
+    n_heads: int = 8
+    attn_dropout: float = 0.2
+    ff_dropout: float = 0.1
+    norm: str = "LayerNorm"
     activation: callable = nn.SELU()
     num_embedding_activation: callable = nn.Identity()
-    head_layer_sizes: list = (128, 64, 32)
+    head_layer_sizes: list = ()
     head_dropout: float = 0.5
     head_skip_layers: bool = False
     head_activation: callable = nn.SELU()
@@ -80,6 +81,7 @@ class DefaultFTTransformerConfig:
     pooling_method: str = "cls"
     norm_first: bool = False
     bias: bool = True
-    transformer_activation: callable = nn.SELU()
+    transformer_activation: callable = ReGLU()
     layer_norm_eps: float = 1e-05
-    transformer_dim_feedforward: int = 512
+    transformer_dim_feedforward: int = 256
+    numerical_embedding: str = "ple"
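
A hedged example of consuming the updated defaults (import path assumed from the repository layout; the config is a plain dataclass, so fields can be overridden at construction):

# Hedged example; import path assumed.
from mambular.configs.fttransformer_config import DefaultFTTransformerConfig

cfg = DefaultFTTransformerConfig()
print(cfg.d_model, cfg.n_layers, cfg.norm)        # 128 4 LayerNorm
print(type(cfg.transformer_activation).__name__)  # ReGLU

# Individual defaults can still be overridden per experiment.
small = DefaultFTTransformerConfig(d_model=64, transformer_dim_feedforward=128)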

mambular/configs/mambular_config.py

Lines changed: 10 additions & 7 deletions
@@ -69,30 +69,32 @@ class DefaultMambularConfig:
         Whether to use bidirectional processing of the input sequences.
     use_learnable_interaction : bool, default=False
         Whether to use learnable feature interactions before passing through mamba blocks.
+    use_cls : bool, default=True
+        Whether to append a cls to the beginning of each 'sequence'.
     """

     lr: float = 1e-04
     lr_patience: int = 10
     weight_decay: float = 1e-06
     lr_factor: float = 0.1
     d_model: int = 64
-    n_layers: int = 8
+    n_layers: int = 4
     expand_factor: int = 2
     bias: bool = False
-    d_conv: int = 16
+    d_conv: int = 4
     conv_bias: bool = True
-    dropout: float = 0.05
+    dropout: float = 0.0
     dt_rank: str = "auto"
-    d_state: int = 32
+    d_state: int = 128
     dt_scale: float = 1.0
     dt_init: str = "random"
     dt_max: float = 0.1
     dt_min: float = 1e-04
     dt_init_floor: float = 1e-04
-    norm: str = "RMSNorm"
-    activation: callable = nn.SELU()
+    norm: str = "LayerNorm"
+    activation: callable = nn.SiLU()
     num_embedding_activation: callable = nn.Identity()
-    head_layer_sizes: list = (128, 64, 32)
+    head_layer_sizes: list = ()
     head_dropout: float = 0.5
     head_skip_layers: bool = False
     head_activation: callable = nn.SELU()
@@ -101,3 +103,4 @@ class DefaultMambularConfig:
     pooling_method: str = "avg"
     bidirectional: bool = False
     use_learnable_interaction: bool = False
+    use_cls: bool = False
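
Note that the docstring added above states default=True while the field added here defaults to False; per the base-model change, cls-based pooling enables the cls token regardless of this flag. A short hedged example (import path assumed):

# Hedged example; import path assumed.
from mambular.configs.mambular_config import DefaultMambularConfig

cfg = DefaultMambularConfig()
print(cfg.use_cls, cfg.pooling_method, cfg.d_state)  # False avg 128

# Append a cls token and pool on it explicitly.
cfg_cls = DefaultMambularConfig(use_cls=True, pooling_method="cls_token")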

mambular/configs/mlp_config.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ class DefaultMLPConfig:
     lr_patience: int = 10
     weight_decay: float = 1e-06
     lr_factor: float = 0.1
-    layer_sizes: list = (128, 128, 32)
+    layer_sizes: list = (256, 128, 32)
     activation: callable = nn.SELU()
     skip_layers: bool = False
     dropout: float = 0.5

mambular/configs/resnet_config.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ class DefaultResNetConfig:
     lr_patience: int = 10
     weight_decay: float = 1e-06
     lr_factor: float = 0.1
-    layer_sizes: list = (128, 128, 32)
+    layer_sizes: list = (256, 128, 32)
     activation: callable = nn.SELU()
     skip_layers: bool = False
     dropout: float = 0.5

mambular/configs/tabtransformer_config.py

Lines changed: 9 additions & 8 deletions
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
 import torch.nn as nn
+from ..arch_utils.transformer_utils import ReGLU


 @dataclass
@@ -63,15 +64,15 @@ class DefaultTabTransformerConfig:
     lr_patience: int = 10
     weight_decay: float = 1e-06
     lr_factor: float = 0.1
-    d_model: int = 64
-    n_layers: int = 8
-    n_heads: int = 4
-    attn_dropout: float = 0.3
-    ff_dropout: float = 0.3
-    norm: str = "RMSNorm"
+    d_model: int = 128
+    n_layers: int = 4
+    n_heads: int = 8
+    attn_dropout: float = 0.2
+    ff_dropout: float = 0.1
+    norm: str = "LayerNorm"
     activation: callable = nn.SELU()
     num_embedding_activation: callable = nn.Identity()
-    head_layer_sizes: list = (128, 64, 32)
+    head_layer_sizes: list = ()
     head_dropout: float = 0.5
     head_skip_layers: bool = False
     head_activation: callable = nn.SELU()
@@ -80,6 +81,6 @@ class DefaultTabTransformerConfig:
     pooling_method: str = "avg"
     norm_first: bool = True
     bias: bool = True
-    transformer_activation: callable = nn.SELU()
+    transformer_activation: callable = ReGLU()
     layer_norm_eps: float = 1e-05
     transformer_dim_feedforward: int = 512

mambular/preprocessing/preprocessor.py

Lines changed: 7 additions & 6 deletions
@@ -227,7 +227,9 @@ def fit(self, X, y=None):
                 numeric_transformer_steps.append(("scaler", StandardScaler()))

             elif self.numerical_preprocessing == "normalization":
-                numeric_transformer_steps.append(("normalizer", MinMaxScaler()))
+                numeric_transformer_steps.append(
+                    ("normalizer", MinMaxScaler(feature_range=(-1, 1)))
+                )

             elif self.numerical_preprocessing == "quantile":
                 numeric_transformer_steps.append(
@@ -240,12 +242,15 @@
                 )

             elif self.numerical_preprocessing == "polynomial":
+                numeric_transformer_steps.append(("scaler", StandardScaler()))
                 numeric_transformer_steps.append(
                     (
                         "polynomial",
                         PolynomialFeatures(self.degree, include_bias=False),
                     )
                 )
+                # if self.degree > 10:
+                #     numeric_transformer_steps.append(("normalizer", MinMaxScaler()))

             elif self.numerical_preprocessing == "splines":
                 numeric_transformer_steps.append(
@@ -260,13 +265,9 @@
                 )

             elif self.numerical_preprocessing == "ple":
-                numeric_transformer_steps.append(("normalizer", MinMaxScaler()))
                 numeric_transformer_steps.append(
-                    ("ple", PLE(n_bins=self.n_bins, task=self.task))
+                    ("normalizer", MinMaxScaler(feature_range=(-1, 1)))
                 )
-
-            elif self.numerical_preprocessing == "ple":
-                numeric_transformer_steps.append(("normalizer", MinMaxScaler()))
                 numeric_transformer_steps.append(
                     ("ple", PLE(n_bins=self.n_bins, task=self.task))
                 )
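
The numeric pipelines now scale features to [-1, 1] before PLE (and for plain "normalization"), and the "polynomial" branch standardizes first. A minimal sketch of the new scaling step in isolation (not the library's API; the PLE step itself is omitted):

# Minimal sketch of the new scaling behaviour (not the library API).
import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler

X = np.random.rand(100, 1) * 50.0

# The "ple" branch now prepends this scaler; a ("ple", PLE(...)) step follows it.
pipe = Pipeline([("normalizer", MinMaxScaler(feature_range=(-1, 1)))])
Xt = pipe.fit_transform(X)
print(Xt.min(), Xt.max())  # -1.0 1.0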
