Commit ffd1f9c

include new model: Tangos
1 parent ec989f1 commit ffd1f9c

2 files changed

Lines changed: 230 additions & 0 deletions

File tree

mambular/base_models/__init__.py
mambular/base_models/tangos.py

mambular/base_models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -13,8 +13,10 @@
 from .autoint import AutoInt
 from .trompt import Trompt
 from .enode import ENODE
+from .tangos import Tangos

 __all__ = [
+    "Tangos",
     "ENODE",
     "Trompt",
     "AutoInt",

mambular/base_models/tangos.py

Lines changed: 228 additions & 0 deletions
@@ -0,0 +1,228 @@
import torch
import torch.nn as nn
import numpy as np

from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
from ..configs.tangos_config import DefaultTangosConfig
from ..utils.get_feature_dimensions import get_feature_dimensions
from .utils.basemodel import BaseModel


class Tangos(BaseModel):
    """
    A Multi-Layer Perceptron (MLP) model with optional GLU activation, batch
    normalization, layer normalization, and dropout. It adds a penalty term
    that encourages neuron specialization and orthogonality.

    Parameters
    ----------
    feature_information : tuple
        A tuple containing feature information for numerical and categorical features.
    num_classes : int, optional (default=1)
        The number of output classes.
    config : DefaultTangosConfig, optional (default=DefaultTangosConfig())
        Configuration object defining model hyperparameters.
    **kwargs : dict
        Additional arguments for the base model.

    Attributes
    ----------
    returns_ensemble : bool
        Whether the model returns an ensemble of predictions.
    lamda1 : float
        Regularization weight for the specialization loss.
    lamda2 : float
        Regularization weight for the orthogonality loss.
    subsample : float
        Proportion of neuron pairs to use for orthogonality loss calculation.
    embedding_layer : EmbeddingLayer or None
        Optional embedding layer for categorical features.
    layers : nn.ModuleList
        The main MLP layers including linear, normalization, and activation layers.
    head : nn.Linear
        The final output layer.
    """

    def __init__(
        self,
        feature_information: tuple,
        num_classes=1,
        config: DefaultTangosConfig = DefaultTangosConfig(),
        **kwargs,
    ):
        super().__init__(config=config, **kwargs)
        self.save_hyperparameters(ignore=["feature_information"])
        self.returns_ensemble = False

        self.lamda1 = config.lamda1
        self.lamda2 = config.lamda2
        self.subsample = config.subsample

        input_dim = get_feature_dimensions(*feature_information)

        # Initialize layers
        self.layers = nn.ModuleList()

        # Input layer
        self.layers.append(nn.Linear(input_dim, self.hparams.layer_sizes[0]))
        if self.hparams.batch_norm:
            self.layers.append(nn.BatchNorm1d(self.hparams.layer_sizes[0]))

        if self.hparams.use_glu:
            # Note: nn.GLU halves the last feature dimension of its input.
            self.layers.append(nn.GLU())
        else:
            self.layers.append(self.hparams.activation)
        if self.hparams.dropout > 0.0:
            self.layers.append(nn.Dropout(self.hparams.dropout))

        # Hidden layers
        for i in range(1, len(self.hparams.layer_sizes)):
            self.layers.append(
                nn.Linear(self.hparams.layer_sizes[i - 1], self.hparams.layer_sizes[i])
            )
            if self.hparams.batch_norm:
                self.layers.append(nn.BatchNorm1d(self.hparams.layer_sizes[i]))
            if self.hparams.layer_norm:
                self.layers.append(nn.LayerNorm(self.hparams.layer_sizes[i]))
            if self.hparams.use_glu:
                self.layers.append(nn.GLU())
            else:
                self.layers.append(self.hparams.activation)
            if self.hparams.dropout > 0.0:
                self.layers.append(nn.Dropout(self.hparams.dropout))

        # Output layer
        self.head = nn.Linear(self.hparams.layer_sizes[-1], num_classes)

    def repr_forward(self, x) -> torch.Tensor:
        """
        Computes the forward pass for feature representations.

        This method processes a single sample through the MLP layers, optionally
        using skip connections. It is called per sample via ``torch.func.vmap``
        in ``penalty_forward``.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (feature_dim,) for a single sample.

        Returns
        -------
        torch.Tensor
            Output tensor after passing through the representation layers.
        """
        # Add a batch dimension of size 1 so Linear/BatchNorm1d layers apply.
        x = x.unsqueeze(0)

        for i in range(len(self.layers)):
            if isinstance(self.layers[i], nn.Linear):
                out = self.layers[i](x)
                if self.hparams.skip_connections and x.shape == out.shape:
                    x = x + out
                else:
                    x = out
            else:
                x = self.layers[i](x)

        return x

    def forward(self, *data) -> torch.Tensor:
        """
        Performs a forward pass of the MLP model.

        This method concatenates all input tensors before applying the MLP layers.

        Parameters
        ----------
        data : tuple
            A tuple containing lists of numerical, categorical, and embedded feature tensors.

        Returns
        -------
        torch.Tensor
            The output tensor of shape (batch_size, num_classes).
        """
        x = torch.cat([t for tensors in data for t in tensors], dim=1)

        for i in range(len(self.layers)):
            if isinstance(self.layers[i], nn.Linear):
                out = self.layers[i](x)
                if self.hparams.skip_connections and x.shape == out.shape:
                    x = x + out
                else:
                    x = out
            else:
                x = self.layers[i](x)
        x = self.head(x)
        return x

    def penalty_forward(self, *data):
        """
        Computes both the model predictions and a penalty term.

        The penalty term includes:
        - **Specialization loss**: Measures feature-importance concentration.
        - **Orthogonality loss**: Encourages diversity among learned features.

        The method uses `jacrev` to compute the Jacobian of the representation function.

        Parameters
        ----------
        data : tuple
            A tuple containing lists of numerical, categorical, and embedded feature tensors.

        Returns
        -------
        tuple
            - predictions : torch.Tensor
                Model predictions of shape (batch_size, num_classes).
            - penalty : torch.Tensor
                The computed penalty term for regularization.
        """
        x = torch.cat([t for tensors in data for t in tensors], dim=1)
        batch_size = x.shape[0]
        subsample = int(self.subsample * batch_size)

        # Per-sample Jacobian of the representation w.r.t. the flattened inputs:
        # jacrev differentiates repr_forward for one sample, vmap maps it over the batch.
        jacobian = torch.func.vmap(
            torch.func.jacrev(self.repr_forward), randomness="different"
        )(x)
        jacobian = jacobian.squeeze()

        # neuron_attr: h_dim x batch_size x features (after the optional flatten below)
        neuron_attr = jacobian.swapaxes(0, 1)
        h_dim = neuron_attr.shape[0]
        if len(neuron_attr.shape) > 3:
            neuron_attr = neuron_attr.flatten(start_dim=2)

        # Specialization loss: mean absolute attribution over batch, neurons, and features.
        spec_loss = torch.norm(neuron_attr, p=1) / (batch_size * h_dim * neuron_attr.shape[2])

        cos = nn.CosineSimilarity(dim=1, eps=1e-6)
        orth_loss = torch.tensor(0.0, requires_grad=True).to(x.device)
        # Orthogonality loss: penalize cosine similarity between the attribution
        # vectors of neuron pairs, subsampling pairs when there are too many.
        if self.subsample > 0 and self.subsample < h_dim * (h_dim - 1) / 2:
            tensor_pairs = [
                list(np.random.choice(h_dim, size=2, replace=False))
                for i in range(subsample)
            ]
            for tensor_pair in tensor_pairs:
                pairwise_corr = cos(
                    neuron_attr[tensor_pair[0], :, :], neuron_attr[tensor_pair[1], :, :]
                ).norm(p=1)
                orth_loss = orth_loss + pairwise_corr

            orth_loss = orth_loss / (batch_size * self.subsample)
        else:
            # Fall back to an exhaustive loop over all neuron pairs.
            for neuron_i in range(1, h_dim):
                for neuron_j in range(0, neuron_i):
                    pairwise_corr = cos(
                        neuron_attr[neuron_i, :, :], neuron_attr[neuron_j, :, :]
                    ).norm(p=1)
                    orth_loss = orth_loss + pairwise_corr
            num_pairs = h_dim * (h_dim - 1) / 2
            orth_loss = orth_loss / (batch_size * num_pairs)

        penalty = self.lamda1 * spec_loss + self.lamda2 * orth_loss
        predictions = self.forward(*data)

        return predictions, penalty
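
For context, a minimal sketch of how penalty_forward is meant to be used in a training step. Only Tangos and penalty_forward come from this commit; the optimizer, loss function, and the concrete contents of `data` are illustrative assumptions:

import torch.nn.functional as F

def training_step(model, optimizer, data, targets):
    # `model` is a Tangos instance; `data` follows the (numerical, categorical,
    # embedded) tuple-of-lists layout that forward()/penalty_forward() expect.
    optimizer.zero_grad()
    # predictions: (batch_size, num_classes); penalty: scalar tensor that
    # already carries the lamda1/lamda2 weights.
    predictions, penalty = model.penalty_forward(*data)
    loss = F.mse_loss(predictions, targets) + penalty
    loss.backward()
    optimizer.step()
    return loss.detach()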
