
Commit 4a76db9

Merge pull request #202 from basf/util_fixes
Util fixes
2 parents b398d13 + 75c2d1b

14 files changed

Lines changed: 732 additions & 434 deletions

mambular/arch_utils/layer_utils/embedding_layer.py

Lines changed: 30 additions & 20 deletions
@@ -22,8 +22,12 @@ def __init__(self, num_feature_info, cat_feature_info, config):
         super().__init__()
 
         self.d_model = getattr(config, "d_model", 128)
-        self.embedding_activation = getattr(config, "embedding_activation", nn.Identity())
-        self.layer_norm_after_embedding = getattr(config, "layer_norm_after_embedding", False)
+        self.embedding_activation = getattr(
+            config, "embedding_activation", nn.Identity()
+        )
+        self.layer_norm_after_embedding = getattr(
+            config, "layer_norm_after_embedding", False
+        )
         self.use_cls = getattr(config, "use_cls", False)
         self.cls_position = getattr(config, "cls_position", 0)
         self.embedding_dropout = (
@@ -71,22 +75,26 @@ def __init__(self, num_feature_info, cat_feature_info, config):
         # for splines and other embeddings
         # splines followed by linear if n_knots actual knots is less than the defined knots
         else:
-            raise ValueError("Invalid embedding_type. Choose from 'linear', 'ndt', or 'plr'.")
+            raise ValueError(
+                "Invalid embedding_type. Choose from 'linear', 'ndt', or 'plr'."
+            )
 
         self.cat_embeddings = nn.ModuleList(
             [
-                nn.Sequential(
-                    nn.Embedding(feature_info["categories"] + 1, self.d_model),
-                    self.embedding_activation,
-                )
-                if feature_info["dimension"] == 1
-                else nn.Sequential(
-                    nn.Linear(
-                        feature_info["dimension"],
-                        self.d_model,
-                        bias=self.embedding_bias,
-                    ),
-                    self.embedding_activation,
+                (
+                    nn.Sequential(
+                        nn.Embedding(feature_info["categories"] + 1, self.d_model),
+                        self.embedding_activation,
+                    )
+                    if feature_info["dimension"] == 1
+                    else nn.Sequential(
+                        nn.Linear(
+                            feature_info["dimension"],
+                            self.d_model,
+                            bias=self.embedding_bias,
+                        ),
+                        self.embedding_activation,
+                    )
                 )
                 for feature_name, feature_info in cat_feature_info.items()
             ]
@@ -124,17 +132,17 @@ def forward(self, num_features=None, cat_features=None):
         # Class token initialization
         if self.use_cls:
             batch_size = (
-                cat_features[0].size(  # type: ignore
-                    0
-                )
+                cat_features[0].size(0)  # type: ignore
                 if cat_features != []
                 else num_features[0].size(0)  # type: ignore
             )  # type: ignore
             cls_tokens = self.cls_token.expand(batch_size, -1, -1)
 
         # Process categorical embeddings
         if self.cat_embeddings and cat_features is not None:
-            cat_embeddings = [emb(cat_features[i]) for i, emb in enumerate(self.cat_embeddings)]
+            cat_embeddings = [
+                emb(cat_features[i]) for i, emb in enumerate(self.cat_embeddings)
+            ]
             cat_embeddings = torch.stack(cat_embeddings, dim=1)
             cat_embeddings = torch.squeeze(cat_embeddings, dim=2)
             if self.layer_norm_after_embedding:
@@ -182,7 +190,9 @@ def forward(self, num_features=None, cat_features=None):
         elif self.cls_position == 1:
            x = torch.cat([x, cls_tokens], dim=1)  # type: ignore
         else:
-            raise ValueError("Invalid cls_position value. It should be either 0 or 1.")
+            raise ValueError(
+                "Invalid cls_position value. It should be either 0 or 1."
+            )
 
         # Apply dropout to embeddings if specified in config
         if self.embedding_dropout is not None:
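
These hunks are formatting-only, but the restructured comprehension makes the dispatch rule easier to read: a categorical feature with dimension == 1 carries integer codes and gets an nn.Embedding, while a feature already encoded as a vector is projected with nn.Linear. A minimal sketch of that pattern (feature names, counts, and shapes below are illustrative, not taken from this commit):

    import torch
    import torch.nn as nn

    d_model = 128
    cat_feature_info = {
        "color": {"categories": 3, "dimension": 1},   # integer codes -> nn.Embedding
        "onehot": {"categories": 4, "dimension": 4},  # pre-encoded vector -> nn.Linear
    }

    cat_embeddings = nn.ModuleList(
        [
            (
                nn.Sequential(nn.Embedding(info["categories"] + 1, d_model))
                if info["dimension"] == 1
                else nn.Sequential(nn.Linear(info["dimension"], d_model))
            )
            for info in cat_feature_info.values()
        ]
    )

    codes = torch.tensor([[0], [2]])         # (batch, 1) integer codes
    vectors = torch.rand(2, 4)               # (batch, 4) dense encoding
    print(cat_embeddings[0](codes).shape)    # torch.Size([2, 1, 128])
    print(cat_embeddings[1](vectors).shape)  # torch.Size([2, 128])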

mambular/base_models/basemodel.py

Lines changed: 34 additions & 2 deletions
@@ -33,7 +33,11 @@ def save_hyperparameters(self, ignore=[]):
             List of keys to ignore while saving hyperparameters, by default [].
         """
         # Filter the config and extra hparams for ignored keys
-        config_hparams = {k: v for k, v in vars(self.config).items() if k not in ignore} if self.config else {}
+        config_hparams = (
+            {k: v for k, v in vars(self.config).items() if k not in ignore}
+            if self.config
+            else {}
+        )
         extra_hparams = {k: v for k, v in self.extra_hparams.items() if k not in ignore}
         config_hparams.update(extra_hparams)
 
@@ -148,7 +152,9 @@ def initialize_pooling_layers(self, config, n_inputs):
         """Initializes the layers needed for learnable pooling methods based on self.hparams.pooling_method."""
         if self.hparams.pooling_method == "learned_flatten":
             # Flattening + Linear layer
-            self.learned_flatten_pooling = nn.Linear(n_inputs * config.dim_feedforward, config.dim_feedforward)
+            self.learned_flatten_pooling = nn.Linear(
+                n_inputs * config.dim_feedforward, config.dim_feedforward
+            )
 
         elif self.hparams.pooling_method == "attention":
             # Attention-based pooling with learnable attention weights
@@ -216,3 +222,29 @@ def pool_sequence(self, out):
             return out
         else:
             raise ValueError(f"Invalid pooling method: {self.hparams.pooling_method}")
+
+    def encode(self, num_features, cat_features):
+        if not hasattr(self, "embedding_layer"):
+            raise ValueError("The model does not have an embedding layer")
+
+        # Check if at least one of the contextualized embedding methods exists
+        valid_layers = ["mamba", "rnn", "lstm", "encoder"]
+        available_layer = next(
+            (attr for attr in valid_layers if hasattr(self, attr)), None
+        )
+
+        if not available_layer:
+            raise ValueError("The model does not generate contextualized embeddings")
+
+        # Get the actual layer and call it
+        x = self.embedding_layer(num_features=num_features, cat_features=cat_features)
+
+        if getattr(self.hparams, "shuffle_embeddings", False):
+            x = x[:, self.perm, :]
+
+        layer = getattr(self, available_layer)
+        if available_layer == "rnn":
+            embeddings, _ = layer(x)
+        else:
+            embeddings = layer(x)
+        return embeddings
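
The new encode() method exposes token-level contextualized embeddings: it embeds the inputs, then dispatches to the first sequence block the model defines among ["mamba", "rnn", "lstm", "encoder"]. A self-contained toy showing the dispatch pattern (ToyModel and all dimensions are invented for illustration; only the lookup logic mirrors the commit):

    import torch
    import torch.nn as nn

    class ToyModel(nn.Module):
        def __init__(self, d_model=16):
            super().__init__()
            # Stand-ins for the real embedding layer and sequence block
            self.embedding_layer = nn.Linear(4, d_model)
            self.rnn = nn.GRU(d_model, d_model, batch_first=True)

        def encode(self, x):
            x = self.embedding_layer(x)
            valid_layers = ["mamba", "rnn", "lstm", "encoder"]
            available = next((a for a in valid_layers if hasattr(self, a)), None)
            if available is None:
                raise ValueError("The model does not generate contextualized embeddings")
            layer = getattr(self, available)
            if available == "rnn":
                embeddings, _ = layer(x)  # RNN-style blocks return (output, hidden)
            else:
                embeddings = layer(x)
            return embeddings

    out = ToyModel().encode(torch.rand(2, 5, 4))  # (batch=2, seq=5, features=4)
    print(out.shape)  # torch.Size([2, 5, 16])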

mambular/base_models/lightning_wrapper.py

Lines changed: 52 additions & 47 deletions
@@ -3,7 +3,6 @@
 import lightning as pl
 import torch
 import torch.nn as nn
-import torchmetrics
 
 
 class TaskModel(pl.LightningModule):
@@ -41,6 +40,8 @@ def __init__(
         pruning_epoch=5,
         optimizer_type: str = "Adam",
         optimizer_args: dict | None = None,
+        train_metrics: dict[str, Callable] | None = None,
+        val_metrics: dict[str, Callable] | None = None,
         **kwargs,
     ):
         super().__init__()
@@ -53,6 +54,10 @@ def __init__(
         self.pruning_epoch = pruning_epoch
         self.val_losses = []
 
+        # Store custom metrics
+        self.train_metrics = train_metrics or {}
+        self.val_metrics = val_metrics or {}
+
         self.optimizer_params = {
             k.replace("optimizer_", ""): v
             for k, v in optimizer_args.items()  # type: ignore
@@ -65,16 +70,10 @@ def __init__(
         if num_classes == 2:
             if not self.loss_fct:
                 self.loss_fct = nn.BCEWithLogitsLoss()
-            self.acc = torchmetrics.Accuracy(task="binary")
-            self.auroc = torchmetrics.AUROC(task="binary")
-            self.precision = torchmetrics.Precision(task="binary")
             self.num_classes = 1
         elif num_classes > 2:
             if not self.loss_fct:
                 self.loss_fct = nn.CrossEntropyLoss()
-            self.acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
-            self.auroc = torchmetrics.AUROC(task="multiclass", num_classes=num_classes)
-            self.precision = torchmetrics.Precision(task="multiclass", num_classes=num_classes)
         else:
             self.loss_fct = nn.MSELoss()
 
@@ -187,7 +186,7 @@ def training_step(self, batch, batch_idx):  # type: ignore
         Tensor
             Training loss.
         """
-        cat_features, num_features, labels = batch
+        num_features, cat_features, labels = batch
 
         # Check if the model has a `penalty_forward` method
         if hasattr(self.base_model, "penalty_forward"):
@@ -200,18 +199,17 @@ def training_step(self, batch, batch_idx):  # type: ignore
         # Log the training loss
         self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
 
-        # Log additional metrics
-        if not self.lss and not hasattr(self.base_model, "returns_ensemble"):
-            if self.num_classes > 1:
-                acc = self.acc(preds, labels)
-                self.log(
-                    "train_acc",
-                    acc,
-                    on_step=True,
-                    on_epoch=True,
-                    prog_bar=True,
-                    logger=True,
-                )
+        # Log custom training metrics
+        for metric_name, metric_fn in self.train_metrics.items():
+            metric_value = metric_fn(preds, labels)
+            self.log(
+                f"train_{metric_name}",
+                metric_value,
+                on_step=True,
+                on_epoch=True,
+                prog_bar=True,
+                logger=True,
+            )
 
         return loss
 
@@ -231,7 +229,7 @@ def validation_step(self, batch, batch_idx):  # type: ignore
             Validation loss.
         """
 
-        cat_features, num_features, labels = batch
+        num_features, cat_features, labels = batch
         preds = self(num_features=num_features, cat_features=cat_features)
         val_loss = self.compute_loss(preds, labels)
 
@@ -244,18 +242,17 @@ def validation_step(self, batch, batch_idx):  # type: ignore
             logger=True,
         )
 
-        # Log additional metrics
-        if not self.lss and not hasattr(self.base_model, "returns_ensemble"):
-            if self.num_classes > 1:
-                acc = self.acc(preds, labels)
-                self.log(
-                    "val_acc",
-                    acc,
-                    on_step=False,
-                    on_epoch=True,
-                    prog_bar=True,
-                    logger=True,
-                )
+        # Log custom validation metrics
+        for metric_name, metric_fn in self.val_metrics.items():
+            metric_value = metric_fn(preds, labels)
+            self.log(
+                f"val_{metric_name}",
+                metric_value,
+                on_step=False,
+                on_epoch=True,
+                prog_bar=True,
+                logger=True,
+            )
 
         return val_loss
 
@@ -274,7 +271,7 @@ def test_step(self, batch, batch_idx):  # type: ignore
         Tensor
             Test loss.
         """
-        cat_features, num_features, labels = batch
+        num_features, cat_features, labels = batch
         preds = self(num_features=num_features, cat_features=cat_features)
         test_loss = self.compute_loss(preds, labels)
 
@@ -287,21 +284,29 @@ def test_step(self, batch, batch_idx):  # type: ignore
             logger=True,
         )
 
-        # Log additional metrics
-        if not self.lss and not hasattr(self.base_model, "returns_ensemble"):
-            if self.num_classes > 1:
-                acc = self.acc(preds, labels)
-                self.log(
-                    "test_acc",
-                    acc,
-                    on_step=False,
-                    on_epoch=True,
-                    prog_bar=True,
-                    logger=True,
-                )
-
         return test_loss
 
+    def predict_step(self, batch, batch_idx):
+        """Predict step for a single batch.
+
+        Parameters
+        ----------
+        batch : tuple
+            Batch of data containing numerical features, categorical features, and labels.
+        batch_idx : int
+            Index of the batch.
+
+        Returns
+        -------
+        Tensor
+            Predictions.
+        """
+
+        num_features, cat_features = batch
+        preds = self(num_features=num_features, cat_features=cat_features)
+
+        return preds
+
     def on_validation_epoch_end(self):
         """Callback executed at the end of each validation epoch.
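
Taken together, the changes in this file swap the hard-coded torchmetrics accuracy/AUROC/precision logging for user-supplied metric callables, fix the batch unpacking order to (num_features, cat_features, labels), and add a predict_step so Lightning's trainer.predict can run on (num_features, cat_features) batches. A hedged sketch of the new hooks, assuming the torchmetrics >= 0.11 functional API and a binary task where preds are logits; the TaskModel constructor call is abbreviated because its remaining arguments are not shown in this diff:

    import torchmetrics.functional as tmf

    # Any callable (preds, labels) -> scalar tensor works; each entry is
    # logged as f"train_{name}" or f"val_{name}".
    train_metrics = {
        "acc": lambda preds, labels: tmf.accuracy(preds, labels, task="binary"),
    }
    val_metrics = {
        "acc": lambda preds, labels: tmf.accuracy(preds, labels, task="binary"),
        "auroc": lambda preds, labels: tmf.auroc(preds, labels, task="binary"),
    }

    # Hypothetical construction; other required arguments elided:
    # task_model = TaskModel(..., num_classes=2,
    #                        train_metrics=train_metrics, val_metrics=val_metrics)
    # trainer.predict(task_model, dataloader) then yields preds per
    # (num_features, cat_features) batch via the new predict_step.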
