@@ -101,8 +101,8 @@ def __init__(
101101 The value/values should be no less than 0. Defaults to None.
102102 soft_label: whether the target contains non-binary values (soft labels) or not.
103103 If True a soft label formulation of the loss will be used.
104- ignore_index: if not None, specifies a target index that is ignored and does not contribute to
105- the input gradient. Defaults to None.
104+ ignore_index: class index to ignore from the loss computation.
105+
106106 Raises:
107107 TypeError: When ``other_act`` is not an ``Optional[Callable]``.
108108 ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
@@ -142,7 +142,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
142142 ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
143143
144144 Example:
145- >>> from monai.losses.dice import * # NOQA
145+ >>> from monai.losses.dice import * # NOQA
146146 >>> import torch
147147 >>> from monai.losses.dice import DiceLoss
148148 >>> B, C, H, W = 7, 5, 3, 2
@@ -166,7 +166,6 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
166166 if self .other_act is not None :
167167 input = self .other_act (input )
168168
169- # mask the ignore_index if specified, must be done before one_hot
170169 mask : torch .Tensor | None = None
171170 if self .ignore_index is not None :
172171 mask = (target != self .ignore_index ).float ()
@@ -343,7 +342,6 @@ def __init__(
343342 smooth_dr : float = 1e-5 ,
344343 batch : bool = False ,
345344 soft_label : bool = False ,
346- ignore_index : int | None = None ,
347345 ) -> None :
348346 """
349347 Args:
@@ -369,8 +367,6 @@ def __init__(
369367 If True, the class-weighted intersection and union areas are first summed across the batches.
370368 soft_label: whether the target contains non-binary values (soft labels) or not.
371369 If True a soft label formulation of the loss will be used.
372- ignore_index: if not None, specifies a target index that is ignored and does not contribute to
373- the input gradient.
374370
375371 Raises:
376372 TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -396,7 +392,6 @@ def __init__(
396392 self .smooth_dr = float (smooth_dr )
397393 self .batch = batch
398394 self .soft_label = soft_label
399- self .ignore_index = ignore_index
400395
401396 def w_func (self , grnd ):
402397 if self .w_type == str (Weight .SIMPLE ):
@@ -427,11 +422,6 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
427422 if self .other_act is not None :
428423 input = self .other_act (input )
429424
430- # Prepare mask before potential one-hot conversion
431- mask : torch .Tensor | None = None
432- if self .ignore_index is not None :
433- mask = (target != self .ignore_index ).float ()
434-
435425 if self .to_onehot_y :
436426 if n_pred_ch == 1 :
437427 warnings .warn ("single channel prediction, `to_onehot_y=True` ignored." )
@@ -442,17 +432,14 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
442432 if n_pred_ch == 1 :
443433 warnings .warn ("single channel prediction, `include_background=False` ignored." )
444434 else :
435+ # if skipping background, removing first channel
445436 target = target [:, 1 :]
446437 input = input [:, 1 :]
447438
448439 if target .shape != input .shape :
449440 raise AssertionError (f"ground truth has differing shape ({ target .shape } ) from input ({ input .shape } )" )
450441
451- # Exclude ignored regions from calculations
452- if mask is not None :
453- input = input * mask
454- target = target * mask
455-
442+ # reducing only spatial dimensions (not batch nor channels)
456443 reduce_axis : list [int ] = torch .arange (2 , len (input .shape )).tolist ()
457444 if self .batch :
458445 reduce_axis = [0 ] + reduce_axis
@@ -479,10 +466,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
479466 f : torch .Tensor = 1.0 - (numer / denom )
480467
481468 if self .reduction == LossReduction .MEAN .value :
482- f = torch .mean (f )
469+ f = torch .mean (f ) # the batch and channel average
483470 elif self .reduction == LossReduction .SUM .value :
484- f = torch .sum (f )
471+ f = torch .sum (f ) # sum over the batch and channel dims
485472 elif self .reduction == LossReduction .NONE .value :
473+ # If we are not computing voxelwise loss components at least
474+ # make sure a none reduction maintains a broadcastable shape
486475 broadcast_shape = list (f .shape [0 :2 ]) + [1 ] * (len (input .shape ) - 2 )
487476 f = f .view (broadcast_shape )
488477 else :
@@ -515,12 +504,11 @@ def __init__(
515504 reduction : LossReduction | str = LossReduction .MEAN ,
516505 smooth_nr : float = 1e-5 ,
517506 smooth_dr : float = 1e-5 ,
518- ignore_index : int | None = None ,
519507 ) -> None :
520508 """
521509 Args:
522510 dist_matrix: 2d tensor or 2d numpy array; matrix of distances between the classes.
523- It must have dimension C x C where C is the number of classes.
511+ It must have dimension C x C where C is the number of classes.
524512 weighting_mode: {``"default"``, ``"GDL"``}
525513 Specifies how to weight the class-specific sum of errors.
526514 Default to ``"default"``.
@@ -540,11 +528,27 @@ def __init__(
540528 - ``"sum"``: the output will be summed.
541529 smooth_nr: a small constant added to the numerator to avoid zero.
542530 smooth_dr: a small constant added to the denominator to avoid nan.
543- ignore_index: if not None, specifies a target index that is ignored and does not contribute to
544- the input gradient.
545531
546532 Raises:
547533 ValueError: When ``dist_matrix`` is not a square matrix.
534+
535+ Example:
536+ .. code-block:: python
537+
538+ import torch
539+ import numpy as np
540+ from monai.losses import GeneralizedWassersteinDiceLoss
541+
542+ # Example with 3 classes (including the background: label 0).
543+ # The distance between the background class (label 0) and the other classes is the maximum, equal to 1.
544+ # The distance between class 1 and class 2 is 0.5.
545+ dist_mat = np.array([[0.0, 1.0, 1.0], [1.0, 0.0, 0.5], [1.0, 0.5, 0.0]], dtype=np.float32)
546+ wass_loss = GeneralizedWassersteinDiceLoss(dist_matrix=dist_mat)
547+
548+ pred_score = torch.tensor([[1000, 0, 0], [0, 1000, 0], [0, 0, 1000]], dtype=torch.float32)
549+ grnd = torch.tensor([0, 1, 2], dtype=torch.int64)
550+ wass_loss(pred_score, grnd) # 0
551+
548552 """
549553 super ().__init__ (reduction = LossReduction (reduction ).value )
550554
@@ -563,13 +567,13 @@ def __init__(
563567 self .num_classes = self .m .size (0 )
564568 self .smooth_nr = float (smooth_nr )
565569 self .smooth_dr = float (smooth_dr )
566- self .ignore_index = ignore_index
567570
568571 def forward (self , input : torch .Tensor , target : torch .Tensor ) -> torch .Tensor :
569572 """
570573 Args:
571574 input: the shape should be BNH[WD].
572575 target: the shape should be BNH[WD].
576+
573577 """
574578 # Aggregate spatial dimensions
575579 flat_input = input .reshape (input .size (0 ), input .size (1 ), - 1 )
@@ -581,20 +585,18 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
581585 # Compute the Wasserstein distance map
582586 wass_dist_map = self .wasserstein_distance_map (probs , flat_target )
583587
584- # Apply masking for ignore_index
585- if self .ignore_index is not None :
586- mask = (flat_target != self .ignore_index ).float ()
587- wass_dist_map = wass_dist_map * mask
588-
589588 # Compute the values of alpha to use
590589 alpha = self ._compute_alpha_generalized_true_positives (flat_target )
591590
592591 # Compute the numerator and denominator of the generalized Wasserstein Dice loss
593592 if self .alpha_mode == "GDL" :
594593 # use GDL-style alpha weights (i.e. normalize by the volume of each class)
594+ # contrary to the original definition we also use alpha in the "generalized all error".
595595 true_pos = self ._compute_generalized_true_positive (alpha , flat_target , wass_dist_map )
596596 denom = self ._compute_denominator (alpha , flat_target , wass_dist_map )
597597 else : # default: as in the original paper
598+ # (i.e. alpha=1 for all foreground classes and 0 for the background).
599+ # Compute the generalised number of true positives
598600 true_pos = self ._compute_generalized_true_positive (alpha , flat_target , wass_dist_map )
599601 all_error = torch .sum (wass_dist_map , dim = 1 )
600602 denom = 2 * true_pos + all_error
@@ -604,10 +606,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
604606 wass_dice_loss : torch .Tensor = 1.0 - wass_dice
605607
606608 if self .reduction == LossReduction .MEAN .value :
607- wass_dice_loss = torch .mean (wass_dice_loss )
609+ wass_dice_loss = torch .mean (wass_dice_loss ) # the batch and channel average
608610 elif self .reduction == LossReduction .SUM .value :
609- wass_dice_loss = torch .sum (wass_dice_loss )
611+ wass_dice_loss = torch .sum (wass_dice_loss ) # sum over the batch and channel dims
610612 elif self .reduction == LossReduction .NONE .value :
613+ # GWDL aggregates over classes internally, so wass_dice_loss has shape (B,)
611614 pass
612615 else :
613616 raise ValueError (f'Unsupported reduction: { self .reduction } , available options are ["mean", "sum", "none"].' )
@@ -733,7 +736,6 @@ def __init__(
733736 lambda_dice : float = 1.0 ,
734737 lambda_ce : float = 1.0 ,
735738 label_smoothing : float = 0.0 ,
736- ignore_index : int | None = None ,
737739 ) -> None :
738740 """
739741 Args:
@@ -775,8 +777,6 @@ def __init__(
775777 label_smoothing: a value in [0, 1] range. If > 0, the labels are smoothed
776778 by the given factor to reduce overfitting.
777779 Defaults to 0.0.
778- ignore_index: if not None, specifies a target index that is ignored and does not contribute to
779- the input gradient.
780780
781781 """
782782 super ().__init__ ()
@@ -799,22 +799,15 @@ def __init__(
799799 smooth_dr = smooth_dr ,
800800 batch = batch ,
801801 weight = dice_weight ,
802- ignore_index = ignore_index ,
803- )
804- self .cross_entropy = nn .CrossEntropyLoss (
805- weight = weight ,
806- reduction = reduction ,
807- label_smoothing = label_smoothing ,
808- ignore_index = ignore_index if ignore_index is not None else - 100 ,
809802 )
803+ self .cross_entropy = nn .CrossEntropyLoss (weight = weight , reduction = reduction , label_smoothing = label_smoothing )
810804 self .binary_cross_entropy = nn .BCEWithLogitsLoss (pos_weight = weight , reduction = reduction )
811805 if lambda_dice < 0.0 :
812806 raise ValueError ("lambda_dice should be no less than 0.0." )
813807 if lambda_ce < 0.0 :
814808 raise ValueError ("lambda_ce should be no less than 0.0." )
815809 self .lambda_dice = lambda_dice
816810 self .lambda_ce = lambda_ce
817- self .ignore_index = ignore_index
818811
819812 def ce (self , input : torch .Tensor , target : torch .Tensor ) -> torch .Tensor :
820813 """
@@ -870,21 +863,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
870863 )
871864
872865 dice_loss = self .dice (input , target )
873-
874- if input .shape [1 ] != 1 :
875- # CrossEntropyLoss handles ignore_index natively
876- ce_loss = self .ce (input , target )
877- else :
878- # BCEWithLogitsLoss does not support ignore_index, handle manually
879- ce_loss = self .bce (input , target )
880- if self .ignore_index is not None :
881- mask = (target != self .ignore_index ).float ()
882- ce_loss = ce_loss * mask
883- if self .dice .reduction == "mean" :
884- ce_loss = torch .mean (ce_loss )
885- elif self .dice .reduction == "sum" :
886- ce_loss = torch .sum (ce_loss )
887-
866+ ce_loss = self .ce (input , target ) if input .shape [1 ] != 1 else self .bce (input , target )
888867 total_loss : torch .Tensor = self .lambda_dice * dice_loss + self .lambda_ce * ce_loss
889868
890869 return total_loss
0 commit comments