
Commit 4ec70f8

adapting all basemodels to new dataset __getitem__ method

1 parent 743c214 commit 4ec70f8

12 files changed: 158 additions & 159 deletions

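Across the files below the change is the same: `__init__` takes one `feature_information` tuple instead of separate `cat_feature_info` / `num_feature_info` arguments, and `forward` accepts `*data` as yielded by the dataset's new `__getitem__`. A minimal sketch of the assumed calling convention (the dataset class itself is not part of this diff, so the names and shapes below are illustrative only):

```python
import torch

# Assumed shape of one batch under the new __getitem__ contract: one list of
# per-feature tensors for each of the three feature groups.
num_features = [torch.randn(32, 1), torch.randn(32, 1)]  # two numerical columns
cat_features = [torch.randint(0, 5, (32, 1))]            # one categorical column
embeddings = []                                          # no precomputed embeddings

data = (num_features, cat_features, embeddings)

# Every adapted model is then invoked as model(*data), so forward() receives
# the three groups positionally and unpacks them as it sees fit.
```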

mambular/base_models/ft_transformer.py
Lines changed: 9 additions & 14 deletions

@@ -6,6 +6,7 @@
 from ..arch_utils.transformer_utils import CustomTransformerEncoderLayer
 from ..configs.fttransformer_config import DefaultFTTransformerConfig
 from .basemodel import BaseModel
+import numpy as np
 
 
 class FTTransformer(BaseModel):
@@ -52,22 +53,18 @@ class FTTransformer(BaseModel):
 
     def __init__(
         self,
-        cat_feature_info,
-        num_feature_info,
+        feature_information: tuple,  # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
         num_classes=1,
         config: DefaultFTTransformerConfig = DefaultFTTransformerConfig(),  # noqa: B008
         **kwargs,
     ):
         super().__init__(config=config, **kwargs)
-        self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
+        self.save_hyperparameters(ignore=["feature_information"])
         self.returns_ensemble = False
-        self.cat_feature_info = cat_feature_info
-        self.num_feature_info = num_feature_info
 
         # embedding layer
         self.embedding_layer = EmbeddingLayer(
-            num_feature_info=num_feature_info,
-            cat_feature_info=cat_feature_info,
+            *feature_information,
             config=config,
         )
 
@@ -87,25 +84,23 @@ def __init__(
         )
 
         # pooling
-        n_inputs = len(num_feature_info) + len(cat_feature_info)
+        n_inputs = np.sum([len(info) for info in feature_information])
         self.initialize_pooling_layers(config=config, n_inputs=n_inputs)
 
-    def forward(self, num_features, cat_features):
+    def forward(self, *data):
         """Defines the forward pass of the model.
 
         Parameters
         ----------
-        num_features : Tensor
-            Tensor containing the numerical features.
-        cat_features : Tensor
-            Tensor containing the categorical features.
+        data : tuple
+            Input tuple of tensors of num_features, cat_features, embeddings.
 
         Returns
         -------
         Tensor
             The output predictions of the model.
         """
-        x = self.embedding_layer(num_features, cat_features)
+        x = self.embedding_layer(*data)
 
         x = self.encoder(x)
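One detail worth calling out in the pooling hunk: `n_inputs` is now summed over all three info dicts, so precomputed embedding features count toward the token count as well. A toy check (feature names and dict values here are made up):

```python
import numpy as np

# Made-up feature-info dicts; in mambular these come from the preprocessor.
num_feature_info = {"age": 1, "income": 1}
cat_feature_info = {"city": 5}
embedding_feature_info = {}

feature_information = (num_feature_info, cat_feature_info, embedding_feature_info)

# Old: len(num_feature_info) + len(cat_feature_info) ignored embeddings.
# New: every group is counted.
n_inputs = np.sum([len(info) for info in feature_information])
print(n_inputs)  # 3
```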

mambular/base_models/mambatab.py
Lines changed: 18 additions & 13 deletions

@@ -5,6 +5,7 @@
 from ..arch_utils.mamba_utils.mamba_arch import Mamba
 from ..arch_utils.mamba_utils.mamba_original import MambaOriginal
 from ..arch_utils.mlp_utils import MLPhead
+from ..utils.get_feature_dimensions import get_feature_dimensions
 from ..configs.mambatab_config import DefaultMambaTabConfig
 from .basemodel import BaseModel
 
@@ -56,23 +57,16 @@ class MambaTab(BaseModel):
 
     def __init__(
         self,
-        cat_feature_info,
-        num_feature_info,
+        feature_information: tuple,  # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
         num_classes=1,
         config: DefaultMambaTabConfig = DefaultMambaTabConfig(),  # noqa: B008
         **kwargs,
     ):
         super().__init__(config=config, **kwargs)
-        self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
+        self.save_hyperparameters(ignore=["feature_information"])
 
-        input_dim = 0
-        for feature_name, input_shape in num_feature_info.items():
-            input_dim += 1
-        for feature_name, input_shape in cat_feature_info.items():
-            input_dim += 1
+        input_dim = get_feature_dimensions(*feature_information)
 
-        self.cat_feature_info = cat_feature_info
-        self.num_feature_info = num_feature_info
         self.returns_ensemble = False
 
         self.initial_layer = nn.Linear(input_dim, config.d_model)
@@ -93,9 +87,20 @@ def __init__(
         else:
             self.mamba = MambaOriginal(config)
 
-    def forward(self, num_features, cat_features):
-        x = num_features + cat_features
-        x = torch.cat(x, dim=1)
+    def forward(self, *data):
+        """Forward pass of the Mambatab model
+
+        Parameters
+        ----------
+        data : tuple
+            Input tuple of tensors of num_features, cat_features, embeddings.
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor.
+        """
+        x = torch.cat([t for tensors in data for t in tensors], dim=1)
 
         x = self.initial_layer(x)
         if self.axis == 1:
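The rewritten `forward` no longer assumes exactly two feature groups: the nested comprehension flattens whatever groups arrive, then concatenates along the feature axis. A self-contained illustration with toy shapes (all tensors here are stand-ins):

```python
import torch

# Stand-in for what *data receives: one list of per-feature tensors per group.
num_features = [torch.randn(4, 1), torch.randn(4, 1)]
cat_features = [torch.randn(4, 3)]        # e.g. an already-encoded categorical
data = (num_features, cat_features, [])   # empty embedding group

flat = [t for tensors in data for t in tensors]  # three tensors in total
x = torch.cat(flat, dim=1)
print(x.shape)  # torch.Size([4, 5]): 1 + 1 + 3 feature columns
```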

mambular/base_models/mambattn.py
Lines changed: 13 additions & 15 deletions

@@ -1,5 +1,5 @@
 import torch
-
+import numpy as np
 from ..arch_utils.get_norm_fn import get_normalization_layer
 from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
 from ..arch_utils.mamba_utils.mambattn_arch import MambAttn
@@ -52,14 +52,15 @@ class MambAttention(BaseModel):
 
     def __init__(
         self,
-        cat_feature_info,
-        num_feature_info,
+        feature_information: tuple,  # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
         num_classes=1,
         config: DefaultMambAttentionConfig = DefaultMambAttentionConfig(),  # noqa: B008
         **kwargs,
     ):
         super().__init__(config=config, **kwargs)
-        self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
+        self.save_hyperparameters(ignore=["feature_information"])
+
+        self.returns_ensemble = False
 
         try:
             self.pooling_method = self.hparams.pooling_method
@@ -76,8 +77,7 @@ def __init__(
 
         # embedding layer
         self.embedding_layer = EmbeddingLayer(
-            num_feature_info=num_feature_info,
-            cat_feature_info=cat_feature_info,
+            *feature_information,
             config=config,
         )
 
@@ -101,25 +101,23 @@ def __init__(
             self.perm = torch.randperm(self.embedding_layer.seq_len)
 
         # pooling
-        n_inputs = len(num_feature_info) + len(cat_feature_info)
+        n_inputs = np.sum([len(info) for info in feature_information])
         self.initialize_pooling_layers(config=config, n_inputs=n_inputs)
 
-    def forward(self, num_features, cat_features):
+    def forward(self, *data):
         """Defines the forward pass of the model.
 
         Parameters
         ----------
-        num_features : Tensor
-            Tensor containing the numerical features.
-        cat_features : Tensor
-            Tensor containing the categorical features.
+        data : tuple
+            Input tuple of tensors of num_features, cat_features, embeddings.
 
         Returns
         -------
-        Tensor
-            The output predictions of the model.
+        torch.Tensor
+            Output tensor.
         """
-        x = self.embedding_layer(num_features, cat_features)
+        x = self.embedding_layer(*data)
 
         if self.shuffle_embeddings:
             x = x[:, self.perm, :]
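For context on the `shuffle_embeddings` lines kept in this hunk: `self.perm` is drawn once at init via `torch.randperm(self.embedding_layer.seq_len)`, so the same token permutation is reapplied on every forward pass. A standalone sketch of the indexing:

```python
import torch

seq_len, d_model = 5, 8
perm = torch.randperm(seq_len)        # fixed once, e.g. tensor([3, 0, 4, 1, 2])
x = torch.randn(2, seq_len, d_model)  # (batch, tokens, embedding dim)
x_shuffled = x[:, perm, :]            # reorders tokens; each vector stays intact
```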

mambular/base_models/mambular.py
Lines changed: 9 additions & 12 deletions

@@ -6,6 +6,7 @@
 from ..arch_utils.mlp_utils import MLPhead
 from ..configs.mambular_config import DefaultMambularConfig
 from .basemodel import BaseModel
+import numpy as np
 
 
 class Mambular(BaseModel):
@@ -52,21 +53,19 @@ class Mambular(BaseModel):
 
     def __init__(
         self,
-        cat_feature_info,
-        num_feature_info,
+        feature_information: tuple,  # Expecting (cat_feature_info, num_feature_info, embedding_feature_info)
        num_classes=1,
         config: DefaultMambularConfig = DefaultMambularConfig(),  # noqa: B008
         **kwargs,
     ):
         super().__init__(config=config, **kwargs)
-        self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
+        self.save_hyperparameters(ignore=["feature_information"])
 
         self.returns_ensemble = False
 
         # embedding layer
         self.embedding_layer = EmbeddingLayer(
-            num_feature_info=num_feature_info,
-            cat_feature_info=cat_feature_info,
+            *feature_information,
             config=config,
         )
 
@@ -85,25 +84,23 @@ def __init__(
         self.perm = torch.randperm(self.embedding_layer.seq_len)
 
         # pooling
-        n_inputs = len(num_feature_info) + len(cat_feature_info)
+        n_inputs = np.sum([len(info) for info in feature_information])
         self.initialize_pooling_layers(config=config, n_inputs=n_inputs)
 
-    def forward(self, num_features, cat_features):
+    def forward(self, *data):
         """Defines the forward pass of the model.
 
         Parameters
         ----------
-        num_features : Tensor
-            Tensor containing the numerical features.
-        cat_features : Tensor
-            Tensor containing the categorical features.
+        data : tuple
+            Input tuple of tensors of num_features, cat_features, embeddings.
 
         Returns
         -------
         Tensor
             The output predictions of the model.
         """
-        x = self.embedding_layer(num_features, cat_features)
+        x = self.embedding_layer(*data)
 
         if self.hparams.shuffle_embeddings:
             x = x[:, self.perm, :]

mambular/base_models/mlp.py
Lines changed: 4 additions & 5 deletions

@@ -1,11 +1,10 @@
 import torch
 import torch.nn as nn
-
+import numpy as np
 from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
 from ..configs.mlp_config import DefaultMLPConfig
 from ..utils.get_feature_dimensions import get_feature_dimensions
 from .basemodel import BaseModel
-import numpy as np
 
 
 class MLP(BaseModel):
@@ -58,7 +57,7 @@ class MLP(BaseModel):
 
     def __init__(
         self,
-        feature_information: tuple,  # Expecting (cat_feature_info, num_feature_info, embedding_feature_info)
+        feature_information: tuple,  # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
         num_classes: int = 1,
         config: DefaultMLPConfig = DefaultMLPConfig(),  # noqa: B008
         **kwargs,
@@ -71,8 +70,6 @@ def __init__(
         # Initialize layers
         self.layers = nn.ModuleList()
 
-        input_dim = get_feature_dimensions(*feature_information)
-
         if self.hparams.use_embeddings:
             self.embedding_layer = EmbeddingLayer(
                 *feature_information,
@@ -81,6 +78,8 @@ def __init__(
             input_dim = np.sum(
                 [len(info) * self.hparams.d_model for info in feature_information]
             )
+        else:
+            input_dim = get_feature_dimensions(*feature_information)
 
         # Input layer
         self.layers.append(nn.Linear(input_dim, self.hparams.layer_sizes[0]))
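The actual fix in this file: `input_dim` was previously computed unconditionally via `get_feature_dimensions` and then overwritten in the embeddings branch; it now lives in an `else`. With embeddings, every feature contributes `d_model` columns. A back-of-envelope check under assumed sizes (the helper's internals are not shown in this diff):

```python
import numpy as np

# Assumed toy setup: two numerical features, one categorical, no embeddings.
feature_information = ({"f1": 1, "f2": 1}, {"c1": 5}, {})
d_model = 64

# use_embeddings=True: one d_model-wide vector per feature.
input_dim = np.sum([len(info) * d_model for info in feature_information])
print(input_dim)  # 192 = (2 + 1 + 0) * 64

# use_embeddings=False: input_dim = get_feature_dimensions(*feature_information),
# which presumably sums the raw preprocessed feature widths instead.
```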

mambular/base_models/ndtf.py
Lines changed: 16 additions & 18 deletions

@@ -54,20 +54,17 @@ class NDTF(BaseModel):
 
     def __init__(
         self,
-        cat_feature_info,
-        num_feature_info,
+        feature_information: tuple,  # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
         num_classes: int = 1,
         config: DefaultNDTFConfig = DefaultNDTFConfig(),  # noqa: B008
         **kwargs,
     ):
         super().__init__(config=config, **kwargs)
-        self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
+        self.save_hyperparameters(ignore=["feature_information"])
 
-        self.cat_feature_info = cat_feature_info
-        self.num_feature_info = num_feature_info
         self.returns_ensemble = False
 
-        input_dim = get_feature_dimensions(num_feature_info, cat_feature_info)
+        input_dim = get_feature_dimensions(*feature_information)
 
         self.input_dimensions = [input_dim]
 
@@ -78,10 +75,13 @@ def __init__(
             [
                 NeuralDecisionTree(
                     input_dim=self.input_dimensions[idx],
-                    depth=np.random.randint(self.hparams.min_depth, self.hparams.max_depth),
+                    depth=np.random.randint(
+                        self.hparams.min_depth, self.hparams.max_depth
+                    ),
                     output_dim=num_classes,
                     lamda=self.hparams.lamda,
-                    temperature=self.hparams.temperature + np.abs(np.random.normal(0, 0.1)),
+                    temperature=self.hparams.temperature
+                    + np.abs(np.random.normal(0, 0.1)),
                     node_sampling=self.hparams.node_sampling,
                 )
                 for idx in range(self.hparams.n_ensembles)
@@ -103,21 +103,20 @@ def __init__(
             requires_grad=True,
         )
 
-    def forward(self, num_features, cat_features) -> torch.Tensor:
+    def forward(self, *data) -> torch.Tensor:
         """Forward pass of the NDTF model.
 
         Parameters
         ----------
-        x : torch.Tensor
-            Input tensor.
+        data : tuple
+            Input tuple of tensors of num_features, cat_features, embeddings.
 
         Returns
         -------
         torch.Tensor
             Output tensor.
         """
-        x = num_features + cat_features
-        x = torch.cat(x, dim=1)
+        x = torch.cat([t for tensors in data for t in tensors], dim=1)
         x = self.conv_layer(x.unsqueeze(2))
         x = x.transpose(1, 2).squeeze(-1)
 
@@ -131,21 +130,20 @@ def forward(self, num_features, cat_features) -> torch.Tensor:
 
         return preds @ self.tree_weights
 
-    def penalty_forward(self, num_features, cat_features) -> torch.Tensor:
+    def penalty_forward(self, *data) -> torch.Tensor:
         """Forward pass of the NDTF model.
 
         Parameters
         ----------
-        x : torch.Tensor
-            Input tensor.
+        data : tuple
+            Input tuple of tensors of num_features, cat_features, embeddings.
 
         Returns
         -------
         torch.Tensor
             Output tensor.
         """
-        x = num_features + cat_features
-        x = torch.cat(x, dim=1)
+        x = torch.cat([t for tensors in data for t in tensors], dim=1)
         x = self.conv_layer(x.unsqueeze(2))
         x = x.transpose(1, 2).squeeze(-1)
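Beyond the `*data` signature change, the NDTF hunks only re-wrap two long lines; the logic they carry is the per-tree randomization that diversifies the ensemble. Shown in isolation with toy hyperparameters:

```python
import numpy as np

min_depth, max_depth, temperature = 3, 6, 0.5

for _ in range(4):  # e.g. n_ensembles = 4
    depth = np.random.randint(min_depth, max_depth)        # sampled in [3, 6)
    temp = temperature + np.abs(np.random.normal(0, 0.1))  # jittered, >= 0.5
    print(depth, round(temp, 3))
```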
