Skip to content

Commit 79718f7

Browse files
committed
include trompt
1 parent 2bba259 commit 79718f7

9 files changed

Lines changed: 238 additions & 0 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ Mambular is a Python package that brings the power of advanced deep learning arc
7878
| `NDTF` | A neural decision forest using soft decision trees. See [Kontschieder et al.](https://openaccess.thecvf.com/content_iccv_2015/html/Kontschieder_Deep_Neural_Decision_ICCV_2015_paper.html) for inspiration. |
7979
| `SAINT` | Improved neural networks via Row Attention and Contrastive Pre-Training, introduced [here](https://arxiv.org/pdf/2106.01342). |
8080
| `AutoInt` | Automatic Feature Interaction Learning via Self-Attentive Neural Networks introduced [here](https://arxiv.org/abs/1810.11921). |
81+
| `Trompt` | Trompt: Towards a Better Deep Neural Network for Tabular Data introduced [here](https://arxiv.org/abs/2305.18446). |
8182

8283

8384

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import torch.nn as nn
2+
import torch
3+
4+
5+
class ImportanceGetter(nn.Module):  # Figure 3 part 1
    """Compute per-prompt feature-importance weights (Trompt, Fig. 3, part 1).

    Holds learnable column embeddings (C x d) and prompt embeddings (P x d).
    Given the previous-cycle prompt state ``O`` of shape (batch, P, d), it
    returns a softmax attention map of shape (batch, P, C) over the columns.
    """

    def __init__(self, P, C, d):
        """P: number of prompts, C: number of columns, d: embedding dimension."""
        super().__init__()
        self.colemb = nn.Parameter(torch.empty(C, d))
        self.pemb = nn.Parameter(torch.empty(P, d))
        torch.nn.init.normal_(self.colemb, std=0.01)
        torch.nn.init.normal_(self.pemb, std=0.01)
        self.C = C
        self.P = P
        self.d = d
        self.dense = nn.Linear(2 * self.d, self.d)
        self.laynorm1 = nn.LayerNorm(self.d)
        self.laynorm2 = nn.LayerNorm(self.d)

    def forward(self, O):
        batch = O.shape[0]
        # Broadcast the learned prompt embeddings across the batch.
        prompts = self.pemb.unsqueeze(0).repeat(batch, 1, 1)
        # Mix normalized prompts with the incoming state, then add both back
        # residual-style.
        mixed = self.dense(torch.cat((self.laynorm1(prompts), O), dim=-1))
        mixed = mixed + prompts + O
        # Normalized column embeddings, broadcast across the batch.
        cols = self.laynorm2(self.colemb.unsqueeze(0).repeat(batch, 1, 1))
        # (batch, P, d) @ (batch, d, C) -> softmax over the column axis.
        return torch.softmax(mixed @ cols.transpose(1, 2), dim=-1)
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import torch.nn as nn
2+
import torch
3+
from .layer_utils.embedding_layer import EmbeddingLayer
4+
from .layer_utils.importance import ImportanceGetter
5+
import numpy as np
6+
7+
8+
class Expander(nn.Module):  # Figure 3 part 3
    """Expand column embeddings to one copy per prompt (Trompt, Fig. 3, part 3).

    Input ``x`` has shape (batch, C, d); output has shape (batch, P, C, d):
    the shared input plus a prompt-specific, group-normalized offset.
    """

    def __init__(self, P):
        super().__init__()
        self.lin = nn.Linear(1, P)
        self.relu = nn.ReLU()
        # NOTE(review): GroupNorm(2, P) requires P to be divisible by 2.
        self.gn = nn.GroupNorm(2, P)

    def forward(self, x):
        # (batch, C, d) -> (batch, C, d, P): one scalar expansion per prompt.
        expanded = self.relu(self.lin(x.unsqueeze(-1)))
        # Move the prompt axis in front of (C, d) so GroupNorm sees P channels.
        offset = self.gn(torch.permute(expanded, (0, 3, 1, 2)))
        # Residual: share x across all P prompts, then add the offset.
        return x.unsqueeze(1) + offset
19+
20+
21+
class TromptCell(nn.Module):
    """One Trompt cycle: embed features, expand per prompt, weight and pool.

    ``feature_information`` is the (numerical, categorical, embedding)
    feature-info tuple; the total column count C is the number of features
    across all three.
    """

    def __init__(self, feature_information, config):
        super().__init__()
        # Total number of columns across all feature-info groups.
        C = np.sum([len(info) for info in feature_information])
        self.enc = EmbeddingLayer(
            *feature_information,
            config=config,
        )
        self.fe = ImportanceGetter(config.P, C, config.d_model)
        self.ex = Expander(config.P)

    def forward(self, *data, O=None):
        # (batch, P, C, d): per-prompt copies of the embedded columns.
        expanded = self.ex(self.enc(*data))
        # (batch, P, C): column weights from the previous prompt state O.
        importance = self.fe(O)
        # Importance-weighted sum over the column axis -> (batch, P, d).
        return (importance.unsqueeze(-1) * expanded).sum(dim=2)
38+
39+
40+
class TromptDecoder(nn.Module):
    """Collapse the per-prompt outputs of a cycle into one prediction.

    Input shape (batch, P, d): a learned softmax over the prompt axis pools
    the prompts, then an MLP head maps the pooled vector to ``d_out``.
    """

    def __init__(self, d, d_out):
        super().__init__()
        self.l1 = nn.Linear(d, 1)
        self.l2 = nn.Linear(d, d)
        self.relu = nn.ReLU()
        self.laynorm1 = nn.LayerNorm(d)
        self.lf = nn.Linear(d, d_out)

    def forward(self, x):
        # Learned attention weights over the prompt axis: (batch, P).
        weights = torch.softmax(self.l1(x).squeeze(-1), dim=-1)
        # Attention-weighted pooling of the prompts -> (batch, d).
        pooled = (weights.unsqueeze(-1) * x).sum(dim=-2)
        hidden = self.laynorm1(self.relu(self.l2(pooled)))
        return self.lf(hidden)

mambular/base_models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
from .tabtransformer import TabTransformer
1212
from .tabularnn import TabulaRNN
1313
from .autoint import AutoInt
14+
from .trompt import Trompt
1415

1516
__all__ = [
17+
"Trompt",
1618
"AutoInt",
1719
"MLP",
1820
"NDTF",

mambular/base_models/trompt.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import torch.nn as nn
2+
import torch
3+
from ..arch_utils.get_norm_fn import get_normalization_layer
4+
from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
5+
from ..configs.trompt_config import DefaultTromptConfig
6+
from .utils.basemodel import BaseModel
7+
from ..arch_utils.trompt_utils import TromptCell, TromptDecoder
8+
import numpy as np
9+
10+
11+
class Trompt(BaseModel):
    """Trompt model for tabular data (https://arxiv.org/abs/2305.18446).

    Runs ``config.n_cycles`` TromptCell/TromptDecoder rounds, threading the
    prompt state ``O`` through the cycles, and returns the stacked per-cycle
    predictions as an ensemble (``returns_ensemble = True``).
    """

    def __init__(
        self,
        feature_information: tuple,  # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
        num_classes=1,
        config: DefaultTromptConfig = DefaultTromptConfig(),  # noqa: B008
        **kwargs,
    ):
        super().__init__(config=config, **kwargs)
        self.save_hyperparameters(ignore=["feature_information"])
        # One prediction per cycle is returned; downstream losses expect this.
        self.returns_ensemble = True

        # One TromptCell per cycle; each embeds the raw features itself.
        self.cells = nn.ModuleList(
            TromptCell(feature_information, config) for _ in range(config.n_cycles)
        )
        self.decoder = TromptDecoder(config.d_model, num_classes)
        # Learned initial prompt state for the first cycle.
        # BUG FIX: torch.empty() left this parameter uninitialized (arbitrary
        # memory, possibly NaN/inf); initialize it like the other embeddings
        # in this commit (cf. ImportanceGetter, std=0.01).
        self.init_rec = nn.Parameter(torch.empty(config.P, config.d_model))
        nn.init.normal_(self.init_rec, std=0.01)
        self.n_cycles = config.n_cycles

    def forward(self, *data):
        """Defines the forward pass of the model.

        Parameters
        ----------
        data : tuple
            Input tuple of tensors of num_features, cat_features, embeddings.

        Returns
        -------
        Tensor
            The stacked per-cycle predictions, shape (batch, n_cycles) after
            the trailing singleton is squeezed (or (batch, n_cycles, num_classes)
            for multi-output tasks).
        """
        # Broadcast the learned initial prompt state across the batch.
        O = self.init_rec.unsqueeze(0).repeat(data[0][0].shape[0], 1, 1)
        outputs = []
        for cell in self.cells:
            # Each cycle refines the prompt state and emits a prediction.
            O = cell(*data, O=O)
            outputs.append(self.decoder(O))
        return torch.stack(outputs, dim=1).squeeze(-1)

mambular/configs/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111
from .tabtransformer_config import DefaultTabTransformerConfig
1212
from .tabularnn_config import DefaultTabulaRNNConfig
1313
from .autoint_config import DefaultAutoIntConfig
14+
from .trompt_config import DefaultTromptConfig
1415
from .base_config import BaseConfig
1516

1617
__all__ = [
18+
"DefaultTromptConfig",
1719
"DefaultAutoIntConfig",
1820
"DefaultFTTransformerConfig",
1921
"DefaultMLPConfig",

mambular/configs/trompt_config.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from collections.abc import Callable
2+
from dataclasses import dataclass, field
3+
import torch.nn as nn
4+
from ..arch_utils.transformer_utils import ReGLU
5+
from .base_config import BaseConfig
6+
7+
8+
@dataclass
class DefaultTromptConfig(BaseConfig):
    """Configuration class for the Trompt model with predefined hyperparameters.

    Parameters
    ----------
    d_model : int, default=128
        Dimensionality of the embeddings and prompt representations.
    n_cycles : int, default=6
        Number of Trompt cycles; each cycle contributes one prediction to the
        returned ensemble.
    n_cells : int, default=4
        Number of cells in each cycle.
        NOTE(review): not read by the visible Trompt implementation, which
        builds exactly one cell per cycle — confirm intended usage.
    P : int, default=128
        Number of learnable prompt embeddings (the "prompt" dimension of the
        Trompt paper).
    """

    d_model: int = 128
    n_cycles: int = 6
    n_cells: int = 4
    P: int = 128

mambular/models/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,12 @@
2626
)
2727
from .tabularnn import TabulaRNNClassifier, TabulaRNNLSS, TabulaRNNRegressor
2828
from .autoint import AutoIntClassifier, AutoIntLSS, AutoIntRegressor
29+
from .trompt import TromptClassifier, TromptLSS, TromptRegressor
2930

3031
__all__ = [
32+
"TromptClassifier",
33+
"TromptLSS",
34+
"TromptRegressor",
3135
"AutoIntClassifier",
3236
"AutoIntLSS",
3337
"AutoIntRegressor",

mambular/models/trompt.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from ..base_models.trompt import Trompt
2+
from ..configs.trompt_config import DefaultTromptConfig
3+
from ..utils.docstring_generator import generate_docstring
4+
from .utils.sklearn_base_classifier import SklearnBaseClassifier
5+
from .utils.sklearn_base_lss import SklearnBaseLSS
6+
from .utils.sklearn_base_regressor import SklearnBaseRegressor
7+
8+
9+
class TromptRegressor(SklearnBaseRegressor):
    # Docstring is generated from the config; the example strings below are
    # user-facing documentation.
    # BUG FIX: the example passed n_layers=8, which DefaultTromptConfig does
    # not define — its depth parameter is n_cycles.
    __doc__ = generate_docstring(
        DefaultTromptConfig,
        model_description="""
        Trompt regressor. This class extends the SklearnBaseRegressor
        class and uses the Trompt model with the default Trompt
        configuration.
        """,
        examples="""
        >>> from mambular.models import TromptRegressor
        >>> model = TromptRegressor(d_model=64, n_cycles=8)
        >>> model.fit(X_train, y_train)
        >>> preds = model.predict(X_test)
        >>> model.evaluate(X_test, y_test)
        """,
    )

    def __init__(self, **kwargs):
        super().__init__(model=Trompt, config=DefaultTromptConfig, **kwargs)
28+
29+
30+
class TromptClassifier(SklearnBaseClassifier):
    # Docstring is generated from the config; the example strings below are
    # user-facing documentation.
    # BUG FIX: the example passed n_layers=8, which DefaultTromptConfig does
    # not define — its depth parameter is n_cycles.
    __doc__ = generate_docstring(
        DefaultTromptConfig,
        """Trompt Classifier. This class extends the SklearnBaseClassifier class
        and uses the Trompt model with the default Trompt configuration.""",
        examples="""
        >>> from mambular.models import TromptClassifier
        >>> model = TromptClassifier(d_model=64, n_cycles=8)
        >>> model.fit(X_train, y_train)
        >>> preds = model.predict(X_test)
        >>> model.evaluate(X_test, y_test)
        """,
    )

    def __init__(self, **kwargs):
        super().__init__(model=Trompt, config=DefaultTromptConfig, **kwargs)
46+
47+
48+
class TromptLSS(SklearnBaseLSS):
    # Docstring is generated from the config; the example strings below are
    # user-facing documentation.
    # BUG FIX: the example passed n_layers=8, which DefaultTromptConfig does
    # not define — its depth parameter is n_cycles.
    __doc__ = generate_docstring(
        DefaultTromptConfig,
        """Trompt for distributional regression.
        This class extends the SklearnBaseLSS class and uses the
        Trompt model with the default Trompt configuration.""",
        examples="""
        >>> from mambular.models import TromptLSS
        >>> model = TromptLSS(d_model=64, n_cycles=8)
        >>> model.fit(X_train, y_train, family="normal")
        >>> preds = model.predict(X_test)
        >>> model.evaluate(X_test, y_test)
        """,
    )

    def __init__(self, **kwargs):
        super().__init__(model=Trompt, config=DefaultTromptConfig, **kwargs)

0 commit comments

Comments
 (0)