Commit 2bba259: add AutoInt model class
1 parent d096518

7 files changed: 312 additions and 2 deletions

README.md

Lines changed: 3 additions & 1 deletion
@@ -76,7 +76,9 @@ Mambular is a Python package that brings the power of advanced deep learning arc
 | `TabulaRNN` | A Recurrent Neural Network for Tabular data, introduced [here](https://arxiv.org/pdf/2411.17207). |
 | `MambAttention` | A combination between Mamba and Transformers, also introduced [here](https://arxiv.org/pdf/2411.17207). |
 | `NDTF` | A neural decision forest using soft decision trees. See [Kontschieder et al.](https://openaccess.thecvf.com/content_iccv_2015/html/Kontschieder_Deep_Neural_Decision_ICCV_2015_paper.html) for inspiration. |
-| `SAINT` | Improve neural networs via Row Attention and Contrastive Pre-Training, introduced [here](https://arxiv.org/pdf/2106.01342). |
+| `SAINT` | Improves neural networks via Row Attention and Contrastive Pre-Training, introduced [here](https://arxiv.org/pdf/2106.01342). |
+| `AutoInt` | Automatic Feature Interaction Learning via Self-Attentive Neural Networks, introduced [here](https://arxiv.org/abs/1810.11921). |
+

mambular/base_models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -10,8 +10,10 @@
 from .tabm import TabM
 from .tabtransformer import TabTransformer
 from .tabularnn import TabulaRNN
+from .autoint import AutoInt

 __all__ = [
+    "AutoInt",
     "MLP",
     "NDTF",
     "NODE",

mambular/base_models/autoint.py

Lines changed: 187 additions & 0 deletions
@@ -0,0 +1,187 @@ (new file)
import numpy as np
import torch.nn as nn
import torch.nn.init as nn_init

from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
from ..configs.autoint_config import DefaultAutoIntConfig
from .utils.basemodel import BaseModel


class AutoInt(BaseModel):
    """AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks.

    This model uses multi-head self-attention layers to learn feature interactions for tabular data.
    It supports key-value compression for memory efficiency and is compatible with embedding-based
    feature encodings.

    Parameters
    ----------
    feature_information : tuple
        A tuple containing information about numerical features, categorical features,
        and any additional embeddings. Expected format: `(num_feature_info, cat_feature_info, embedding_feature_info)`.
    num_classes : int, default=1
        Number of output classes. For regression, this should be set to `1`.
    config : DefaultAutoIntConfig, optional
        Configuration object containing hyperparameters such as `d_model`, `n_heads`, `n_layers`,
        dropout rates, and compression settings.
    **kwargs : dict
        Additional arguments passed to the `BaseModel`.

    Attributes
    ----------
    embedding_layer : EmbeddingLayer
        Module that processes numerical and categorical features into embeddings.
    kv_compression : float or None
        The proportion of key-value compression. If `None`, no compression is applied.
    kv_compression_sharing : str or None
        Defines how key-value compression modules are shared. Options:
        - `"layerwise"`: one compression module shared across all layers (and between keys and values).
        - `"headwise"`: each layer gets separate compression modules for keys and values.
        - `"key-value"`: each layer gets one compression module shared between keys and values.
    shared_kv_compression : nn.Linear or None
        Shared key-value compression layer, used when `kv_compression_sharing="layerwise"`.
    layers : nn.ModuleList
        A list of transformer-based attention layers, each consisting of:
        - `attention`: multi-head self-attention module.
        - `linear`: fully connected projection layer.
        - `norm0`: layer normalization.
    last_norm : nn.LayerNorm or None
        Final normalization layer applied before the output head if `prenorm` is enabled.
    head : nn.Linear
        Output layer mapping from the processed feature representation to the final predictions.
    """

    def __init__(
        self,
        feature_information: tuple,  # (num_feature_info, cat_feature_info, embedding_feature_info)
        num_classes=1,
        config: DefaultAutoIntConfig = DefaultAutoIntConfig(),  # noqa: B008
        **kwargs,
    ):
        super().__init__(config=config, **kwargs)
        self.save_hyperparameters(ignore=["feature_information"])
        self.returns_ensemble = False

        # Embedding layer
        self.embedding_layer = EmbeddingLayer(*feature_information, config=config)
        n_inputs = np.sum([len(info) for info in feature_information])

        # Key-value compression
        self.kv_compression = config.kv_compression
        self.kv_compression_sharing = config.kv_compression_sharing

        def make_kv_compression():
            # Compress along the token axis: n_inputs tokens -> int(n_inputs * ratio) tokens
            compression = nn.Linear(
                n_inputs,
                int(n_inputs * config.kv_compression),
                bias=False,
            )
            nn_init.xavier_uniform_(compression.weight)
            return compression

        self.shared_kv_compression = (
            make_kv_compression()
            if self.kv_compression and self.kv_compression_sharing == "layerwise"
            else None
        )

        # Transformer-based interaction layers
        self.layers = nn.ModuleList()
        for _ in range(config.n_layers):
            layer = nn.ModuleDict(
                {
                    "attention": nn.MultiheadAttention(
                        embed_dim=config.d_model,
                        num_heads=config.n_heads,
                        dropout=config.attn_dropout,
                        batch_first=True,
                    ),
                    "linear": nn.Linear(config.d_model, config.d_model, bias=False),
                    "norm0": nn.LayerNorm(config.d_model),
                }
            )

            if self.kv_compression and self.shared_kv_compression is None:
                layer["key_compression"] = make_kv_compression()
                if self.kv_compression_sharing == "headwise":
                    layer["value_compression"] = make_kv_compression()
                else:
                    assert self.kv_compression_sharing == "key-value"

            self.layers.append(layer)

        # Final normalization and output head
        self.last_norm = (
            nn.LayerNorm(config.d_model) if getattr(config, "prenorm", False) else None
        )
        self.head = nn.Linear(config.d_model * n_inputs, num_classes)

    def _get_kv_compressions(self, layer):
        """Return the key and value compression layers for `layer` based on the sharing strategy.

        Parameters
        ----------
        layer : nn.ModuleDict
            The transformer layer containing possible key-value compression modules.

        Returns
        -------
        tuple of (nn.Linear or None, nn.Linear or None)
            The key and value compression layers, or `(None, None)` if no compression is applied.
        """
        if self.shared_kv_compression is not None:
            # "layerwise": one module shared across layers and between keys and values
            return self.shared_kv_compression, self.shared_kv_compression
        if "key_compression" in layer and "value_compression" in layer:
            # "headwise": separate modules for keys and values
            return layer["key_compression"], layer["value_compression"]
        if "key_compression" in layer:
            # "key-value": one module shared between keys and values
            return layer["key_compression"], layer["key_compression"]
        return None, None

    def forward(self, *data):
        """Forward pass of the AutoInt model.

        Parameters
        ----------
        *data : tuple
            Input tuple of tensors containing numerical features, categorical features, and embeddings.

        Returns
        -------
        Tensor
            The output predictions of the model.
        """
        x = self.embedding_layer(*data)  # Shape: (N, J, d_model)

        for layer in self.layers:
            # Pre-attention normalization (applied unconditionally in this implementation)
            x_residual = layer["norm0"](x)

            # Retrieve the key-value compression layers for this layer
            key_compression, value_compression = self._get_kv_compressions(layer)

            # Compress keys and values along the token axis if configured
            if key_compression is not None:
                k = key_compression(x_residual.transpose(1, 2)).transpose(1, 2)
                v = value_compression(x_residual.transpose(1, 2)).transpose(1, 2)
            else:
                k = v = x_residual

            # Multi-head self-attention with residual connection
            x_residual, _ = layer["attention"](x_residual, k, v)
            x = x + x_residual

            # Position-wise linear projection with a second residual connection
            x = x + layer["linear"](x)

        if self.last_norm is not None:
            x = self.last_norm(x)  # Final normalization when prenorm is enabled

        x = x.flatten(1)  # (N, J, d_model) -> (N, J * d_model)
        return self.head(x)  # Final prediction
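
A note on the compression mechanism above: the `nn.Linear(n_inputs, int(n_inputs * ratio), bias=False)` modules act along the token axis, not the embedding axis, so keys and values shrink from `J` tokens to `int(J * ratio)` tokens before attention. A minimal standalone sketch of that shape manipulation (batch size, token count, and ratio here are illustrative, not taken from the commit):

import torch
import torch.nn as nn

N, J, d_model, ratio = 32, 10, 128, 0.5  # hypothetical batch, tokens, embed dim, compression ratio
x = torch.randn(N, J, d_model)           # token embeddings, shape (N, J, d_model)
compress = nn.Linear(J, int(J * ratio), bias=False)

# Move the token axis last, compress it, then move it back
kv = compress(x.transpose(1, 2)).transpose(1, 2)
print(kv.shape)  # torch.Size([32, 5, 128]): attention now sees 5 keys/values per query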

mambular/configs/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -10,9 +10,11 @@
 from .tabm_config import DefaultTabMConfig
 from .tabtransformer_config import DefaultTabTransformerConfig
 from .tabularnn_config import DefaultTabulaRNNConfig
+from .autoint_config import DefaultAutoIntConfig
 from .base_config import BaseConfig

 __all__ = [
+    "DefaultAutoIntConfig",
     "DefaultFTTransformerConfig",
     "DefaultMLPConfig",
     "DefaultMambAttentionConfig",
@@ -25,5 +27,5 @@
     "DefaultTabMConfig",
     "DefaultTabTransformerConfig",
     "DefaultTabulaRNNConfig",
-    "BaseConfig"
+    "BaseConfig",
 ]

mambular/configs/autoint_config.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@ (new file)
from dataclasses import dataclass

from .base_config import BaseConfig


@dataclass
class DefaultAutoIntConfig(BaseConfig):
    """Configuration class for the AutoInt model with predefined hyperparameters.

    Parameters
    ----------
    d_model : int, default=128
        Dimensionality of the transformer model.
    n_layers : int, default=4
        Number of transformer layers.
    n_heads : int, default=8
        Number of attention heads in the transformer.
    attn_dropout : float, default=0.2
        Dropout rate for the attention mechanism.
    prenorm : bool, default=False
        Whether to apply a final LayerNorm before the output head (pre-normalization).
    transformer_dim_feedforward : int, default=256
        Dimensionality of the feed-forward layers in the transformer.
    bias : bool, default=True
        Whether to use bias in linear layers.
    use_cls : bool, default=False
        Whether to add a CLS token to the token sequence.
    cat_encoding : str, default="int"
        Method for encoding categorical features ('int', 'one-hot', or 'linear').
    kv_compression : float, default=0.5
        Compression ratio applied to keys and values along the token axis.
    kv_compression_sharing : str, default='key-value'
        Sharing strategy for key-value compression ('layerwise', 'headwise', or 'key-value').
    """

    # Architecture parameters
    d_model: int = 128
    n_layers: int = 4
    n_heads: int = 8
    attn_dropout: float = 0.2
    prenorm: bool = False
    transformer_dim_feedforward: int = 256
    bias: bool = True

    use_cls: bool = False
    cat_encoding: str = "int"

    kv_compression: float = 0.5
    kv_compression_sharing: str = "key-value"
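
Since the config is a dataclass inheriting from `BaseConfig`, hyperparameters can be overridden per instance at construction time; a minimal sketch using only the fields defined above:

from mambular.configs import DefaultAutoIntConfig

# Override a few defaults; every other field keeps its dataclass default
config = DefaultAutoIntConfig(d_model=64, n_layers=2, kv_compression=0.25)
assert config.n_heads == 8  # untouched default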

mambular/models/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -25,8 +25,12 @@
     TabTransformerRegressor,
 )
 from .tabularnn import TabulaRNNClassifier, TabulaRNNLSS, TabulaRNNRegressor
+from .autoint import AutoIntClassifier, AutoIntLSS, AutoIntRegressor

 __all__ = [
+    "AutoIntClassifier",
+    "AutoIntLSS",
+    "AutoIntRegressor",
     "MLPLSS",
     "NDTFLSS",
     "NODELSS",

mambular/models/autoint.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@ (new file)
from ..base_models.autoint import AutoInt
from ..configs.autoint_config import DefaultAutoIntConfig
from ..utils.docstring_generator import generate_docstring
from .utils.sklearn_base_classifier import SklearnBaseClassifier
from .utils.sklearn_base_lss import SklearnBaseLSS
from .utils.sklearn_base_regressor import SklearnBaseRegressor


class AutoIntRegressor(SklearnBaseRegressor):
    __doc__ = generate_docstring(
        DefaultAutoIntConfig,
        model_description="""
        AutoInt regressor. This class extends the SklearnBaseRegressor
        class and uses the AutoInt model with the default AutoInt
        configuration.
        """,
        examples="""
        >>> from mambular.models import AutoIntRegressor
        >>> model = AutoIntRegressor(d_model=64, n_layers=8)
        >>> model.fit(X_train, y_train)
        >>> preds = model.predict(X_test)
        >>> model.evaluate(X_test, y_test)
        """,
    )

    def __init__(self, **kwargs):
        super().__init__(model=AutoInt, config=DefaultAutoIntConfig, **kwargs)


class AutoIntClassifier(SklearnBaseClassifier):
    __doc__ = generate_docstring(
        DefaultAutoIntConfig,
        model_description="""
        AutoInt classifier. This class extends the SklearnBaseClassifier
        class and uses the AutoInt model with the default AutoInt
        configuration.
        """,
        examples="""
        >>> from mambular.models import AutoIntClassifier
        >>> model = AutoIntClassifier(d_model=64, n_layers=8)
        >>> model.fit(X_train, y_train)
        >>> preds = model.predict(X_test)
        >>> model.evaluate(X_test, y_test)
        """,
    )

    def __init__(self, **kwargs):
        super().__init__(model=AutoInt, config=DefaultAutoIntConfig, **kwargs)


class AutoIntLSS(SklearnBaseLSS):
    __doc__ = generate_docstring(
        DefaultAutoIntConfig,
        model_description="""
        AutoInt for distributional regression.
        This class extends the SklearnBaseLSS class and uses the
        AutoInt model with the default AutoInt configuration.
        """,
        examples="""
        >>> from mambular.models import AutoIntLSS
        >>> model = AutoIntLSS(d_model=64, n_layers=8)
        >>> model.fit(X_train, y_train, family="normal")
        >>> preds = model.predict(X_test)
        >>> model.evaluate(X_test, y_test)
        """,
    )

    def __init__(self, **kwargs):
        super().__init__(model=AutoInt, config=DefaultAutoIntConfig, **kwargs)
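
Putting the pieces together, the new wrappers follow the sklearn-style interface shown in the docstrings above. A sketch with synthetic data (the pandas/numpy input format is an assumption based on the SklearnBase* wrappers, which are not part of this commit):

import numpy as np
import pandas as pd
from mambular.models import AutoIntRegressor

# Hypothetical toy data: two numerical features and one categorical feature
X = pd.DataFrame({
    "num_a": np.random.randn(256),
    "num_b": np.random.randn(256),
    "cat_a": np.random.choice(["x", "y", "z"], size=256),
})
y = np.random.randn(256)

model = AutoIntRegressor(d_model=64, n_layers=2)
model.fit(X, y)           # fit/predict provided by SklearnBaseRegressor
preds = model.predict(X)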
