from collections.abc import Callable

import lightning as pl
import torch
import torch.nn as nn
import torchmetrics
98class TaskModel (pl .LightningModule ):
@@ -41,6 +40,8 @@ def __init__(
4140 pruning_epoch = 5 ,
4241 optimizer_type : str = "Adam" ,
4342 optimizer_args : dict | None = None ,
43+ train_metrics : dict [str , Callable ] | None = None ,
44+ val_metrics : dict [str , Callable ] | None = None ,
4445 ** kwargs ,
4546 ):
4647 super ().__init__ ()
@@ -53,6 +54,10 @@ def __init__(
5354 self .pruning_epoch = pruning_epoch
5455 self .val_losses = []
5556
57+ # Store custom metrics
58+ self .train_metrics = train_metrics or {}
59+ self .val_metrics = val_metrics or {}
60+
5661 self .optimizer_params = {
5762 k .replace ("optimizer_" , "" ): v
5863 for k , v in optimizer_args .items () # type: ignore
@@ -65,16 +70,10 @@ def __init__(
6570 if num_classes == 2 :
6671 if not self .loss_fct :
6772 self .loss_fct = nn .BCEWithLogitsLoss ()
68- self .acc = torchmetrics .Accuracy (task = "binary" )
69- self .auroc = torchmetrics .AUROC (task = "binary" )
70- self .precision = torchmetrics .Precision (task = "binary" )
7173 self .num_classes = 1
7274 elif num_classes > 2 :
7375 if not self .loss_fct :
7476 self .loss_fct = nn .CrossEntropyLoss ()
75- self .acc = torchmetrics .Accuracy (task = "multiclass" , num_classes = num_classes )
76- self .auroc = torchmetrics .AUROC (task = "multiclass" , num_classes = num_classes )
77- self .precision = torchmetrics .Precision (task = "multiclass" , num_classes = num_classes )
7877 else :
7978 self .loss_fct = nn .MSELoss ()
8079
@@ -187,7 +186,7 @@ def training_step(self, batch, batch_idx): # type: ignore
187186 Tensor
188187 Training loss.
189188 """
190- cat_features , num_features , labels = batch
189+ num_features , cat_features , labels = batch
191190
192191 # Check if the model has a `penalty_forward` method
193192 if hasattr (self .base_model , "penalty_forward" ):
@@ -200,18 +199,17 @@ def training_step(self, batch, batch_idx): # type: ignore
200199 # Log the training loss
201200 self .log ("train_loss" , loss , on_step = True , on_epoch = True , prog_bar = True , logger = True )
202201
203- # Log additional metrics
204- if not self .lss and not hasattr (self .base_model , "returns_ensemble" ):
205- if self .num_classes > 1 :
206- acc = self .acc (preds , labels )
207- self .log (
208- "train_acc" ,
209- acc ,
210- on_step = True ,
211- on_epoch = True ,
212- prog_bar = True ,
213- logger = True ,
214- )
202+ # Log custom training metrics
203+ for metric_name , metric_fn in self .train_metrics .items ():
204+ metric_value = metric_fn (preds , labels )
205+ self .log (
206+ f"train_{ metric_name } " ,
207+ metric_value ,
208+ on_step = True ,
209+ on_epoch = True ,
210+ prog_bar = True ,
211+ logger = True ,
212+ )
215213
216214 return loss
217215
@@ -231,7 +229,7 @@ def validation_step(self, batch, batch_idx): # type: ignore
231229 Validation loss.
232230 """
233231
234- cat_features , num_features , labels = batch
232+ num_features , cat_features , labels = batch
235233 preds = self (num_features = num_features , cat_features = cat_features )
236234 val_loss = self .compute_loss (preds , labels )
237235
@@ -244,18 +242,17 @@ def validation_step(self, batch, batch_idx): # type: ignore
244242 logger = True ,
245243 )
246244
247- # Log additional metrics
248- if not self .lss and not hasattr (self .base_model , "returns_ensemble" ):
249- if self .num_classes > 1 :
250- acc = self .acc (preds , labels )
251- self .log (
252- "val_acc" ,
253- acc ,
254- on_step = False ,
255- on_epoch = True ,
256- prog_bar = True ,
257- logger = True ,
258- )
245+ # Log custom validation metrics
246+ for metric_name , metric_fn in self .val_metrics .items ():
247+ metric_value = metric_fn (preds , labels )
248+ self .log (
249+ f"val_{ metric_name } " ,
250+ metric_value ,
251+ on_step = False ,
252+ on_epoch = True ,
253+ prog_bar = True ,
254+ logger = True ,
255+ )
259256
260257 return val_loss
261258
@@ -274,7 +271,7 @@ def test_step(self, batch, batch_idx): # type: ignore
274271 Tensor
275272 Test loss.
276273 """
277- cat_features , num_features , labels = batch
274+ num_features , cat_features , labels = batch
278275 preds = self (num_features = num_features , cat_features = cat_features )
279276 test_loss = self .compute_loss (preds , labels )
280277
@@ -287,21 +284,29 @@ def test_step(self, batch, batch_idx): # type: ignore
287284 logger = True ,
288285 )
289286
290- # Log additional metrics
291- if not self .lss and not hasattr (self .base_model , "returns_ensemble" ):
292- if self .num_classes > 1 :
293- acc = self .acc (preds , labels )
294- self .log (
295- "test_acc" ,
296- acc ,
297- on_step = False ,
298- on_epoch = True ,
299- prog_bar = True ,
300- logger = True ,
301- )
302-
303287 return test_loss
304288
289+ def predict_step (self , batch , batch_idx ):
290+ """Predict step for a single batch.
291+
292+ Parameters
293+ ----------
294+ batch : tuple
295+ Batch of data containing numerical features, categorical features, and labels.
296+ batch_idx : int
297+ Index of the batch.
298+
299+ Returns
300+ -------
301+ Tensor
302+ Predictions.
303+ """
304+
305+ num_features , cat_features = batch
306+ preds = self (num_features = num_features , cat_features = cat_features )
307+
308+ return preds
309+
305310 def on_validation_epoch_end (self ):
306311 """Callback executed at the end of each validation epoch.
307312
0 commit comments