
Commit 628182f

Merge pull request #237 from basf/refactorization
Refactorization and introduction of AutoInt and Trompt
2 parents b6408bf + d4fbeb5 commit 628182f

55 files changed

Lines changed: 2051 additions & 894 deletions

Note: some content is hidden in this view; large commits have some file contents collapsed by default. Two of the new files below therefore appear without their file names.

README.md

Lines changed: 8 additions & 5 deletions
@@ -76,7 +76,10 @@ Mambular is a Python package that brings the power of advanced deep learning arc
 | `TabulaRNN` | A Recurrent Neural Network for Tabular data, introduced [here](https://arxiv.org/pdf/2411.17207). |
 | `MambAttention` | A combination between Mamba and Transformers, also introduced [here](https://arxiv.org/pdf/2411.17207). |
 | `NDTF` | A neural decision forest using soft decision trees. See [Kontschieder et al.](https://openaccess.thecvf.com/content_iccv_2015/html/Kontschieder_Deep_Neural_Decision_ICCV_2015_paper.html) for inspiration. |
-| `SAINT` | Improves neural networks via Row Attention and Contrastive Pre-Training, introduced [here](https://arxiv.org/pdf/2106.01342). |
+| `SAINT` | Improves neural networks via Row Attention and Contrastive Pre-Training, introduced [here](https://arxiv.org/pdf/2106.01342). |
+| `AutoInt` | Automatic Feature Interaction Learning via Self-Attentive Neural Networks, introduced [here](https://arxiv.org/abs/1810.11921). |
+| `Trompt` | Trompt: Towards a Better Deep Neural Network for Tabular Data, introduced [here](https://arxiv.org/abs/2305.18446). |
+

@@ -211,13 +214,13 @@ random_search.fit(X, y, **fit_params)
 print("Best Parameters:", random_search.best_params_)
 print("Best Score:", random_search.best_score_)
 ```
-Note, that using this, you can also optimize the preprocessing. Just use the prefix ``prepro__`` when specifying the preprocessor arguments you want to optimize:
+Note that using this you can also optimize the preprocessing. Just specify the preprocessor arguments you want to optimize directly, without any prefix:
 ```python
 param_dist = {
     'd_model': randint(32, 128),
     'n_layers': randint(2, 10),
     'lr': uniform(1e-5, 1e-3),
-    "prepro__numerical_preprocessing": ["ple", "standardization", "box-cox"]
+    "numerical_preprocessing": ["ple", "standardization", "box-cox"]
 }

 ```
@@ -321,7 +324,7 @@ Here's how you can implement a custom model with Mambular:
 Define your custom model just as you would for an `nn.Module`. The main difference is that you will inherit from `BaseModel` and use the provided feature information to construct your layers. To integrate your model into the existing API, you only need to define the architecture and the forward pass.

 ```python
-from mambular.base_models import BaseModel
+from mambular.base_models.utils import BaseModel
 from mambular.utils.get_feature_dimensions import get_feature_dimensions
 import torch
 import torch.nn
@@ -365,7 +368,7 @@ Here's how you can implement a custom model with Mambular:
 You can build a regression, classification, or distributional regression model that can leverage all of Mambular's built-in methods by using the following:

 ```python
-from mambular.models import SklearnBaseRegressor
+from mambular.models.utils import SklearnBaseRegressor

 class MyRegressor(SklearnBaseRegressor):
     def __init__(self, **kwargs):
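Beyond the README excerpts above, the two newly added architectures are meant to be used through the same scikit-learn-style wrappers as the existing models. A minimal usage sketch, assuming the wrappers follow the package's usual `<Model>Regressor` / `<Model>Classifier` naming (the wrapper class names themselves do not appear in this commit excerpt):

```python
# Hedged sketch: AutoIntRegressor / TromptClassifier are assumed names following the
# package's naming convention; only the base-model classes are visible in this diff.
import numpy as np
from mambular.models import AutoIntRegressor, TromptClassifier  # assumed names

X = np.random.rand(200, 6)
y_reg = np.random.rand(200)
y_cls = np.random.randint(0, 2, size=200)

reg = AutoIntRegressor(d_model=64)   # config options follow the other models
reg.fit(X, y_reg)

clf = TromptClassifier(d_model=64)
clf.fit(X, y_cls)
preds = clf.predict(X)
```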

mambular/arch_utils/layer_utils/embedding_layer.py

Lines changed: 5 additions & 1 deletion
@@ -156,7 +156,11 @@ def forward(self, num_features, cat_features, emb_features):
         # Process categorical embeddings
         if self.cat_embeddings and cat_features is not None:
             cat_embeddings = [
-                emb(cat_features[i]) if emb(cat_features[i]).ndim == 3 else emb(cat_features[i]).unsqueeze(1)
+                (
+                    emb(cat_features[i])
+                    if emb(cat_features[i]).ndim == 3
+                    else emb(cat_features[i]).unsqueeze(1)
+                )
                 for i, emb in enumerate(self.cat_embeddings)
             ]
Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+import torch.nn as nn
+import torch
+
+
+class ImportanceGetter(nn.Module):  # Figure 3 part 1
+    def __init__(self, P, C, d):
+        super().__init__()
+        self.colemb = nn.Parameter(torch.empty(C, d))
+        self.pemb = nn.Parameter(torch.empty(P, d))
+        torch.nn.init.normal_(self.colemb, std=0.01)
+        torch.nn.init.normal_(self.pemb, std=0.01)
+        self.C = C
+        self.P = P
+        self.d = d
+        self.dense = nn.Linear(2 * self.d, self.d)
+        self.laynorm1 = nn.LayerNorm(self.d)
+        self.laynorm2 = nn.LayerNorm(self.d)
+
+    def forward(self, O):
+        eprompt = self.pemb.unsqueeze(0).repeat(O.shape[0], 1, 1)
+
+        dense_out = self.dense(torch.cat((self.laynorm1(eprompt), O), dim=-1))
+
+        dense_out = dense_out + eprompt + O
+
+        ecolumn = self.laynorm2(self.colemb.unsqueeze(0).repeat(O.shape[0], 1, 1))
+
+        return torch.softmax(dense_out @ ecolumn.transpose(1, 2), dim=-1)
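A quick shape check of the new module, as a hedged sketch: the file's path is hidden in this view, so the import below is an assumption based on the relative import `.layer_utils.importance` used by the Trompt utilities further down.

```python
# Hedged usage sketch; the import path is assumed, not shown in this commit view.
import torch
from mambular.arch_utils.layer_utils.importance import ImportanceGetter  # assumed path

P, C, d, batch = 128, 10, 32, 4        # prompts, feature columns, embedding dim, batch
getter = ImportanceGetter(P, C, d)
O = torch.zeros(batch, P, d)           # prompt output of the previous cycle (zeros at start)
M = getter(O)                          # per-prompt importance weights over the columns
print(M.shape)                         # torch.Size([4, 128, 10])
assert torch.allclose(M.sum(dim=-1), torch.ones(batch, P), atol=1e-5)
```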

mambular/arch_utils/mamba_utils/mamba_arch.py

Lines changed: 18 additions & 6 deletions
@@ -43,11 +43,14 @@ def __init__(
                 norm=get_normalization_layer(config),  # type: ignore
                 activation=getattr(config, "activation", nn.SiLU()),
                 bidirectional=getattr(config, "bidirectional", False),
-                use_learnable_interaction=getattr(config, "use_learnable_interaction", False),
+                use_learnable_interaction=getattr(
+                    config, "use_learnable_interaction", False
+                ),
                 layer_norm_eps=getattr(config, "layer_norm_eps", 1e-5),
                 AD_weight_decay=getattr(config, "AD_weight_decay", True),
                 BC_layer_norm=getattr(config, "BC_layer_norm", False),
                 use_pscan=getattr(config, "use_pscan", False),
+                dilation=getattr(config, "dilation", 1),
             )
             for _ in range(getattr(config, "n_layers", 6))
         ]
@@ -149,6 +152,7 @@ def __init__(
         AD_weight_decay=False,
         BC_layer_norm=False,
         use_pscan=False,
+        dilation=1,
     ):
         super().__init__()

@@ -194,6 +198,7 @@ def __init__(
             AD_weight_decay=AD_weight_decay,
             BC_layer_norm=BC_layer_norm,
             use_pscan=use_pscan,
+            dilation=dilation,
         )
         self.norm = norm

@@ -307,6 +312,7 @@ def __init__(
         AD_weight_decay=False,
         BC_layer_norm=False,
         use_pscan=False,
+        dilation=1,
     ):
         super().__init__()

@@ -319,7 +325,10 @@ def __init__(
             self.pscan = pscan  # Store the imported pscan function
         except ImportError:
             self.pscan = None  # Set to None if pscan is not available
-            print("The 'mambapy' package is not installed. Please install it by running:\n" "pip install mambapy")
+            print(
+                "The 'mambapy' package is not installed. Please install it by running:\n"
+                "pip install mambapy"
+            )
         else:
             self.pscan = None

@@ -347,6 +356,7 @@ def __init__(
             bias=conv_bias,
             groups=self.d_inner,
             padding=d_conv - 1,
+            dilation=dilation,
         )

         self.dropout = nn.Dropout(dropout)
@@ -375,16 +385,18 @@ def __init__(
         else:
             raise NotImplementedError

-        dt_fwd = torch.exp(torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)).clamp(
-            min=dt_init_floor
-        )
+        dt_fwd = torch.exp(
+            torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min))
+            + math.log(dt_min)
+        ).clamp(min=dt_init_floor)
         inv_dt_fwd = dt_fwd + torch.log(-torch.expm1(-dt_fwd))
         with torch.no_grad():
             self.dt_proj_fwd.bias.copy_(inv_dt_fwd)

         if self.bidirectional:
             dt_bwd = torch.exp(
-                torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
+                torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min))
+                + math.log(dt_min)
             ).clamp(min=dt_init_floor)
             inv_dt_bwd = dt_bwd + torch.log(-torch.expm1(-dt_bwd))
             with torch.no_grad():
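The new `dilation` option threaded through these hunks (and through `rnn_utils.py` below) widens the depthwise Conv1d without adding parameters: the effective kernel spans dilation * (d_conv - 1) + 1 input positions. A small self-contained illustration in plain PyTorch, with arbitrary example values rather than Mambular defaults:

```python
# Illustration only: how the dilation kwarg changes a depthwise Conv1d like the one
# configured above. d_inner, d_conv, and seq_len are arbitrary example values.
import torch
import torch.nn as nn

d_inner, d_conv, seq_len = 8, 4, 16
x = torch.randn(2, d_inner, seq_len)

for dilation in (1, 2):
    conv = nn.Conv1d(
        d_inner,
        d_inner,
        kernel_size=d_conv,
        groups=d_inner,          # depthwise: one filter per channel
        padding=d_conv - 1,      # same padding scheme as in the diff
        dilation=dilation,
    )
    # effective receptive field: dilation * (d_conv - 1) + 1 input positions
    print(dilation, conv(x).shape)   # output length changes with dilation
```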

mambular/arch_utils/rnn_utils.py

Lines changed: 11 additions & 2 deletions
@@ -21,6 +21,7 @@ def __init__(self, config):
         self.rnn_activation = getattr(config, "rnn_activation", "relu")
         self.d_conv = getattr(config, "d_conv", 4)
         self.residuals = getattr(config, "residuals", False)
+        self.dilation = getattr(config, "dilation", 1)

         # Choose RNN layer based on model_type
         rnn_layer = {
@@ -37,7 +38,10 @@ def __init__(self, config):

         if self.residuals:
             self.residual_matrix = nn.ParameterList(
-                [nn.Parameter(torch.randn(self.hidden_size, self.hidden_size)) for _ in range(self.num_layers)]
+                [
+                    nn.Parameter(torch.randn(self.hidden_size, self.hidden_size))
+                    for _ in range(self.num_layers)
+                ]
             )

         # First Conv1d layer uses input_size
@@ -49,6 +53,7 @@ def __init__(self, config):
                 padding=self.d_conv - 1,
                 bias=self.conv_bias,
                 groups=self.input_size,
+                dilation=self.dilation,
             )
         )
         self.layernorms_conv.append(nn.LayerNorm(self.input_size))
@@ -63,6 +68,7 @@ def __init__(self, config):
                 padding=self.d_conv - 1,
                 bias=self.conv_bias,
                 groups=self.hidden_size,
+                dilation=self.dilation,
             )
         )
         self.layernorms_conv.append(nn.LayerNorm(self.hidden_size))
@@ -159,7 +165,10 @@ def __init__(

         if self.residuals:
             self.residual_matrix = nn.ParameterList(
-                [nn.Parameter(torch.randn(self.hidden_size, self.hidden_size)) for _ in range(self.num_layers)]
+                [
+                    nn.Parameter(torch.randn(self.hidden_size, self.hidden_size))
+                    for _ in range(self.num_layers)
+                ]
             )

         # First Conv1d layer uses input_size
Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+import torch.nn as nn
+import torch
+from .layer_utils.embedding_layer import EmbeddingLayer
+from .layer_utils.importance import ImportanceGetter
+import numpy as np
+
+
+class Expander(nn.Module):  # Figure 3 part 3
+    def __init__(self, P):
+        super().__init__()
+        self.lin = nn.Linear(1, P)
+        self.relu = nn.ReLU()
+        self.gn = nn.GroupNorm(2, P)
+
+    def forward(self, x):
+        res = self.relu(self.lin(x.unsqueeze(-1)))
+
+        return x.unsqueeze(1) + self.gn(torch.permute(res, (0, 3, 1, 2)))
+
+
+class TromptCell(nn.Module):
+    def __init__(self, feature_information, config):
+        super().__init__()
+        C = np.sum([len(info) for info in feature_information])
+        self.enc = EmbeddingLayer(
+            *feature_information,
+            config=config,
+        )
+        self.fe = ImportanceGetter(config.P, C, config.d_model)
+        self.ex = Expander(config.P)
+
+    def forward(self, *data, O=None):
+        x_res = self.ex(self.enc(*data))
+
+        M = self.fe(O)
+
+        return (M.unsqueeze(-1) * x_res).sum(dim=2)
+
+
+class TromptDecoder(nn.Module):
+    def __init__(self, d, d_out):
+        super().__init__()
+        self.l1 = nn.Linear(d, 1)
+        self.l2 = nn.Linear(d, d)
+        self.relu = nn.ReLU()
+        self.laynorm1 = nn.LayerNorm(d)
+        self.lf = nn.Linear(d, d_out)
+
+    def forward(self, x):
+        pw = torch.softmax(self.l1(x).squeeze(-1), dim=-1)
+
+        xnew = (pw.unsqueeze(-1) * x).sum(dim=-2)
+
+        return self.lf(self.laynorm1(self.relu(self.l2(xnew))))
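To make the data flow of the new blocks concrete, here is a shape-level sketch of one Trompt cycle using plain tensors in place of the learned layers, so it runs standalone. The random placeholders only mimic the shapes produced by `EmbeddingLayer`, `ImportanceGetter`, and the decoder's attention weights; they are not the actual computations.

```python
# Shape-level walkthrough of TromptCell / TromptDecoder above, with random tensors
# standing in for the learned layers. Dimensions: P prompts, C feature columns, d embedding size.
import torch

batch, P, C, d = 4, 128, 10, 32
E = torch.randn(batch, C, d)      # EmbeddingLayer output: one embedding per column
O = torch.zeros(batch, P, d)      # prompt output of the previous cycle (zeros at start)

# Expander: one copy of the column embeddings per prompt -> (batch, P, C, d)
x_res = E.unsqueeze(1).expand(batch, P, C, d)

# ImportanceGetter: per-prompt weights over the columns -> (batch, P, C)
M = torch.softmax(torch.randn(batch, P, C), dim=-1)

# TromptCell output: importance-weighted sum over columns -> (batch, P, d);
# this becomes O for the next cycle
O_next = (M.unsqueeze(-1) * x_res).sum(dim=2)

# TromptDecoder: softmax-pool the P prompts, then project to the output dimension
pw = torch.softmax(torch.randn(batch, P), dim=-1)     # stands in for the learned l1 scores
pooled = (pw.unsqueeze(-1) * O_next).sum(dim=-2)      # (batch, d)
print(O_next.shape, pooled.shape)
```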

mambular/base_models/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -1,6 +1,4 @@
-from .basemodel import BaseModel
 from .ft_transformer import FTTransformer
-from .lightning_wrapper import TaskModel
 from .mambatab import MambaTab
 from .mambattn import MambAttention
 from .mambular import Mambular
@@ -12,13 +10,16 @@
 from .tabm import TabM
 from .tabtransformer import TabTransformer
 from .tabularnn import TabulaRNN
+from .autoint import AutoInt
+from .trompt import Trompt

 __all__ = [
+    "Trompt",
+    "AutoInt",
     "MLP",
     "NDTF",
     "NODE",
     "SAINT",
-    "BaseModel",
     "FTTransformer",
     "MambAttention",
     "MambaTab",
@@ -27,5 +28,4 @@
     "TabM",
     "TabTransformer",
     "TabulaRNN",
-    "TaskModel",
 ]
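For downstream code, the refactorization moves the base classes out of the package roots: the updated README imports `BaseModel` from `mambular.base_models.utils` and `SklearnBaseRegressor` from `mambular.models.utils`, and this `__init__.py` no longer re-exports `BaseModel` or `TaskModel`. A migration sketch based only on the README hunks above (`TaskModel`'s new location is not shown in this commit):

```python
# Before this commit:
# from mambular.base_models import BaseModel
# from mambular.models import SklearnBaseRegressor

# After this commit (paths taken from the updated README):
from mambular.base_models.utils import BaseModel
from mambular.models.utils import SklearnBaseRegressor
```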
