Commit e6b90dc

Merge pull request #99 from basf/minor_fix
Minor fix
2 parents 8933cf7 + 71f35e6 commit e6b90dc

14 files changed: 133 additions & 90 deletions

mambular/arch_utils/embedding_layer.py

Lines changed: 26 additions & 8 deletions
@@ -12,6 +12,7 @@ def __init__(
         layer_norm_after_embedding=False,
         use_cls=False,
         cls_position=0,
+        cat_encoding="int",
     ):
         """
         Embedding layer that handles numerical and categorical embeddings.
@@ -56,15 +57,23 @@ def __init__(
             ]
         )

-        self.cat_embeddings = nn.ModuleList(
-            [
-                nn.Sequential(
-                    nn.Embedding(num_categories + 1, d_model),
-                    self.embedding_activation,
+        self.cat_embeddings = nn.ModuleList()
+        for feature_name, num_categories in cat_feature_info.items():
+            if cat_encoding == "int":
+                self.cat_embeddings.append(
+                    nn.Sequential(
+                        nn.Embedding(num_categories + 1, d_model),
+                        self.embedding_activation,
+                    )
+                )
+            elif cat_encoding == "one-hot":
+                self.cat_embeddings.append(
+                    nn.Sequential(
+                        OneHotEncoding(num_categories),
+                        nn.Linear(num_categories, d_model, bias=False),
+                        self.embedding_activation,
+                    )
                 )
-                for feature_name, num_categories in cat_feature_info.items()
-            ]
-        )

         if self.use_cls:
             self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
@@ -143,3 +152,12 @@ def forward(self, num_features=None, cat_features=None):
             )

         return x
+
+
+class OneHotEncoding(nn.Module):
+    def __init__(self, num_categories):
+        super(OneHotEncoding, self).__init__()
+        self.num_categories = num_categories
+
+    def forward(self, x):
+        return torch.nn.functional.one_hot(x, num_classes=self.num_categories).float()
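
Note: the new "one-hot" branch is functionally equivalent to an embedding lookup. A one-hot vector multiplied by a bias-free nn.Linear selects one column of the weight matrix, exactly as nn.Embedding selects one row, so the two encodings differ in parametrization rather than expressiveness. A minimal sketch verifying this (sizes are made up for illustration):

import torch
import torch.nn as nn

# Hypothetical sizes, for illustration only.
num_categories, d_model, batch = 5, 8, 4
x = torch.randint(0, num_categories, (batch,))

# The "one-hot" path from this commit: one-hot encode, then project without bias.
one_hot = torch.nn.functional.one_hot(x, num_classes=num_categories).float()
proj = nn.Linear(num_categories, d_model, bias=False)

# An nn.Embedding carrying the transposed weights produces identical outputs.
emb = nn.Embedding(num_categories, d_model)
with torch.no_grad():
    emb.weight.copy_(proj.weight.T)
assert torch.allclose(proj(one_hot), emb(x))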

mambular/base_models/ft_transformer.py

Lines changed: 4 additions & 1 deletion
@@ -132,9 +132,12 @@ def __init__(
             embedding_activation=self.hparams.get(
                 "embedding_activation", config.embedding_activation
             ),
-            layer_norm_after_embedding=self.hparams.get("layer_norm_after_embedding"),
+            layer_norm_after_embedding=self.hparams.get(
+                "layer_norm_after_embedding", config.layer_norm_after_embedding
+            ),
             use_cls=True,
             cls_position=0,
+            cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding),
         )

         head_activation = self.hparams.get("head_activation", config.head_activation)
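
Note: the switch to the two-argument form of self.hparams.get matters because, like dict.get, the one-argument call returns None when the key was never set, silently bypassing the config default. A small illustration with a plain dict standing in for Lightning's hparams (the class below is a stand-in, not the real config):

hparams = {}  # user did not override anything

# Old pattern: a missing key silently becomes None,
# which matches neither "int" nor "one-hot" downstream.
assert hparams.get("cat_encoding") is None

# New pattern: fall back to the config's documented default.
class StubConfig:  # stand-in for DefaultFTTransformerConfig
    cat_encoding = "int"

assert hparams.get("cat_encoding", StubConfig.cat_encoding) == "int"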

mambular/base_models/lightning_wrapper.py

Lines changed: 6 additions & 36 deletions
@@ -82,7 +82,7 @@ def __init__(
         else:
             output_dim = num_classes

-        self.model = model_class(
+        self.base_model = model_class(
             config=config,
             num_feature_info=num_feature_info,
             cat_feature_info=cat_feature_info,
@@ -107,7 +107,7 @@ def forward(self, num_features, cat_features):
             Model output.
         """

-        return self.model.forward(num_features, cat_features)
+        return self.base_model.forward(num_features, cat_features)

     def compute_loss(self, predictions, y_true):
         """
@@ -168,16 +168,6 @@ def training_step(self, batch, batch_idx):
                 prog_bar=True,
                 logger=True,
             )
-        elif isinstance(self.loss_fct, nn.MSELoss):
-            rmse = torch.sqrt(loss)
-            self.log(
-                "train_rmse",
-                rmse,
-                on_step=True,
-                on_epoch=True,
-                prog_bar=True,
-                logger=True,
-            )

         return loss

@@ -205,7 +195,7 @@ def validation_step(self, batch, batch_idx):
         self.log(
             "val_loss",
             val_loss,
-            on_step=True,
+            on_step=False,
             on_epoch=True,
             prog_bar=True,
             logger=True,
@@ -218,17 +208,7 @@ def validation_step(self, batch, batch_idx):
             self.log(
                 "val_acc",
                 acc,
-                on_step=True,
-                on_epoch=True,
-                prog_bar=True,
-                logger=True,
-            )
-        elif isinstance(self.loss_fct, nn.MSELoss):
-            rmse = torch.sqrt(val_loss)
-            self.log(
-                "val_rmse",
-                rmse,
-                on_step=True,
+                on_step=False,
                 on_epoch=True,
                 prog_bar=True,
                 logger=True,
@@ -272,17 +252,7 @@ def test_step(self, batch, batch_idx):
             self.log(
                 "test_acc",
                 acc,
-                on_step=True,
-                on_epoch=True,
-                prog_bar=True,
-                logger=True,
-            )
-        elif isinstance(self.loss_fct, nn.MSELoss):
-            rmse = torch.sqrt(test_loss)
-            self.log(
-                "test_rmse",
-                rmse,
-                on_step=True,
+                on_step=False,
                 on_epoch=True,
                 prog_bar=True,
                 logger=True,
@@ -300,7 +270,7 @@ def configure_optimizers(self):
            A dictionary containing the optimizer and lr_scheduler configurations.
        """
        optimizer = torch.optim.Adam(
-            self.model.parameters(),
+            self.base_model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
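
Note: two behavioral consequences of this file's changes, sketched under the assumption that the training objective is nn.MSELoss. First, the dropped *_rmse metrics remain recoverable from the logged loss. Second, switching to on_step=False means Lightning aggregates validation and test metrics over all batches and emits one value per epoch instead of a noisy per-batch series.

import math

# With an MSELoss objective, the logged "val_loss" is the MSE itself,
# so the removed "val_rmse" metric can be recomputed after training.
logged_val_mse = 0.25  # hypothetical value read back from the logger
val_rmse = math.sqrt(logged_val_mse)  # -> 0.5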

mambular/base_models/mambular.py

Lines changed: 6 additions & 3 deletions
@@ -150,9 +150,12 @@ def __init__(
             embedding_activation=self.hparams.get(
                 "embedding_activation", config.embedding_activation
             ),
-            layer_norm_after_embedding=self.hparams.get("layer_norm_after_embedding"),
-            use_cls=True,
-            cls_position=0,
+            layer_norm_after_embedding=self.hparams.get(
+                "layer_norm_after_embedding", config.layer_norm_after_embedding
+            ),
+            use_cls=False,
+            cls_position=-1,
+            cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding),
         )

         head_activation = self.hparams.get("head_activation", config.head_activation)
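
Note: the Mambular base model now builds its embedding layer with use_cls=False, so no CLS token is concatenated at all; cls_position=-1 only records where one would go. For context, a sketch of the two placements as the wrapper docstrings describe them ("append a cls to the end of each 'sequence'"), with shapes assumed for illustration:

import torch

batch, seq_len, d_model = 4, 10, 8
tokens = torch.randn(batch, seq_len, d_model)       # embedded features
cls = torch.zeros(1, 1, d_model).expand(batch, -1, -1)

front = torch.cat([cls, tokens], dim=1)  # cls_position=0: prepend (FT-/TabTransformer)
back = torch.cat([tokens, cls], dim=1)   # cls_position=-1: append
assert front.shape == back.shape == (batch, seq_len + 1, d_model)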

mambular/base_models/tabtransformer.py

Lines changed: 4 additions & 1 deletion
@@ -139,9 +139,12 @@ def __init__(
             embedding_activation=self.hparams.get(
                 "embedding_activation", config.embedding_activation
             ),
-            layer_norm_after_embedding=self.hparams.get("layer_norm_after_embedding"),
+            layer_norm_after_embedding=self.hparams.get(
+                "layer_norm_after_embedding", config.layer_norm_after_embedding
+            ),
             use_cls=True,
             cls_position=0,
+            cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding),
         )

         head_activation = self.hparams.get("head_activation", config.head_activation)

mambular/configs/fttransformer_config.py

Lines changed: 3 additions & 1 deletion
@@ -58,6 +58,8 @@ class DefaultFTTransformerConfig:
         Epsilon value for layer normalization.
     transformer_dim_feedforward : int, default=512
         Dimensionality of the feed-forward layers in the transformer.
+    cat_encoding : str, default="int"
+        Whether to use integer encoding or one-hot encoding for categorical features.
     """

     lr: float = 1e-04
@@ -84,4 +86,4 @@ class DefaultFTTransformerConfig:
     transformer_activation: callable = ReGLU()
     layer_norm_eps: float = 1e-05
     transformer_dim_feedforward: int = 256
-    numerical_embedding: str = "ple"
+    cat_encoding: str = "int"

mambular/configs/mambular_config.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,11 @@ class DefaultMambularConfig:
7676
layer_norm_eps : float, default=1e-05
7777
Epsilon value for layer normalization.
7878
AD_weight_decay : bool, default=False
79-
whether weight decay is also applied to A-D matrices
79+
whether weight decay is also applied to A-D matrices.
8080
BC_layer_norm: bool, default=True
81-
whether to apply layer normalization to B-C matrices
81+
whether to apply layer normalization to B-C matrices.
82+
cat_encoding : str, default="int"
83+
whether to use integer encoding or one-hot encoding for cat features.
8284
"""
8385

8486
lr: float = 1e-04
@@ -116,3 +118,4 @@ class DefaultMambularConfig:
116118
layer_norm_eps: float = 1e-05
117119
AD_weight_decay: bool = False
118120
BC_layer_norm: bool = True
121+
cat_encoding: str = "int"

mambular/configs/tabtransformer_config.py

Lines changed: 3 additions & 0 deletions
@@ -58,6 +58,8 @@ class DefaultTabTransformerConfig:
         Epsilon value for layer normalization.
     transformer_dim_feedforward : int, default=512
         Dimensionality of the feed-forward layers in the transformer.
+    cat_encoding : str, default="int"
+        Whether to use integer encoding or one-hot encoding for categorical features.
     """

     lr: float = 1e-04
@@ -84,3 +86,4 @@ class DefaultTabTransformerConfig:
     transformer_activation: callable = ReGLU()
     layer_norm_eps: float = 1e-05
     transformer_dim_feedforward: int = 512
+    cat_encoding: str = "int"
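
Note: all three default configs now expose cat_encoding as a field defaulting to "int". A minimal override sketch, assuming the Default*Config classes are dataclasses (as the annotated field syntax suggests) and importable from the module paths shown above:

from dataclasses import replace

from mambular.configs.tabtransformer_config import DefaultTabTransformerConfig

# Override at construction, or derive a variant from an existing config.
config = DefaultTabTransformerConfig(cat_encoding="one-hot")
config_int = replace(config, cat_encoding="int")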

mambular/models/fttransformer.py

Lines changed: 6 additions & 0 deletions
@@ -64,6 +64,8 @@ class FTTransformerRegressor(SklearnBaseRegressor):
         Epsilon value for layer normalization.
     transformer_dim_feedforward : int, default=512
         Dimensionality of the feed-forward layers in the transformer.
+    cat_encoding : str, default="int"
+        Whether to use integer encoding or one-hot encoding for categorical features.
     n_bins : int, default=50
         The number of bins to use for numerical feature binning. This parameter is relevant
         only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
@@ -171,6 +173,8 @@ class FTTransformerClassifier(SklearnBaseClassifier):
         Epsilon value for layer normalization.
     transformer_dim_feedforward : int, default=512
         Dimensionality of the feed-forward layers in the transformer.
+    cat_encoding : str, default="int"
+        Whether to use integer encoding or one-hot encoding for categorical features.
     n_bins : int, default=50
         The number of bins to use for numerical feature binning. This parameter is relevant
         only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
@@ -278,6 +282,8 @@ class FTTransformerLSS(SklearnBaseLSS):
         Epsilon value for layer normalization.
     transformer_dim_feedforward : int, default=512
         Dimensionality of the feed-forward layers in the transformer.
+    cat_encoding : str, default="int"
+        Whether to use integer encoding or one-hot encoding for categorical features.
     n_bins : int, default=50
         The number of bins to use for numerical feature binning. This parameter is relevant
         only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.

mambular/models/mambular.py

Lines changed: 24 additions & 0 deletions
@@ -79,6 +79,14 @@ class MambularRegressor(SklearnBaseRegressor):
         Whether to append a cls to the end of each 'sequence'.
     shuffle_embeddings : bool, default=False
         Whether to shuffle the embeddings before being passed to the Mamba layers.
+    layer_norm_eps : float, default=1e-05
+        Epsilon value for layer normalization.
+    AD_weight_decay : bool, default=False
+        Whether weight decay is also applied to A-D matrices.
+    BC_layer_norm : bool, default=True
+        Whether to apply layer normalization to B-C matrices.
+    cat_encoding : str, default="int"
+        Whether to use integer encoding or one-hot encoding for categorical features.
     n_bins : int, default=50
         The number of bins to use for numerical feature binning. This parameter is relevant
         only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
@@ -198,6 +206,14 @@ class MambularClassifier(SklearnBaseClassifier):
         Whether to use learnable feature interactions before passing through mamba blocks.
     shuffle_embeddings : bool, default=False
         Whether to shuffle the embeddings before being passed to the Mamba layers.
+    layer_norm_eps : float, default=1e-05
+        Epsilon value for layer normalization.
+    AD_weight_decay : bool, default=False
+        Whether weight decay is also applied to A-D matrices.
+    BC_layer_norm : bool, default=True
+        Whether to apply layer normalization to B-C matrices.
+    cat_encoding : str, default="int"
+        Whether to use integer encoding or one-hot encoding for categorical features.
     n_bins : int, default=50
         The number of bins to use for numerical feature binning. This parameter is relevant
         only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
@@ -320,6 +336,14 @@ class MambularLSS(SklearnBaseLSS):
         only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
     shuffle_embeddings : bool, default=False
         Whether to shuffle the embeddings before being passed to the Mamba layers.
+    layer_norm_eps : float, default=1e-05
+        Epsilon value for layer normalization.
+    AD_weight_decay : bool, default=False
+        Whether weight decay is also applied to A-D matrices.
+    BC_layer_norm : bool, default=True
+        Whether to apply layer normalization to B-C matrices.
+    cat_encoding : str, default="int"
+        Whether to use integer encoding or one-hot encoding for categorical features.
     numerical_preprocessing : str, default="ple"
         The preprocessing strategy for numerical features. Valid options are
         'binning', 'one_hot', 'standardization', and 'normalization'.
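
Note: since the wrapper docstrings now document cat_encoding, the option should be selectable from the sklearn-style API as well. A usage sketch, assuming the wrappers forward keyword arguments into their config and expose the usual fit/predict interface implied by SklearnBaseRegressor (the data here is synthetic):

import numpy as np
import pandas as pd

from mambular.models.mambular import MambularRegressor

# Tiny synthetic frame: one numerical and one categorical column.
X = pd.DataFrame({
    "num": np.random.randn(100),
    "cat": np.random.choice(["a", "b", "c"], size=100),
})
y = np.random.randn(100)

# cat_encoding switches categorical embeddings between an integer
# lookup ("int", the default) and one-hot plus a bias-free projection.
model = MambularRegressor(cat_encoding="one-hot")
model.fit(X, y)
preds = model.predict(X)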
