Commit f01cbc4

fix: resolve shape issues and CI fails
Signed-off-by: Rusheel Sharma <rusheelhere@gmail.com>
Parent: a1f6ef4

2 files changed: 94 additions & 117 deletions

File tree

monai/losses/dice.py: 37 additions & 58 deletions
@@ -101,8 +101,8 @@ def __init__(
                 The value/values should be no less than 0. Defaults to None.
             soft_label: whether the target contains non-binary values (soft labels) or not.
                 If True a soft label formulation of the loss will be used.
-            ignore_index: if not None, specifies a target index that is ignored and does not contribute to
-                the input gradient. Defaults to None.
+            ignore_index: class index to ignore from the loss computation.
+
         Raises:
             TypeError: When ``other_act`` is not an ``Optional[Callable]``.
             ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
@@ -142,7 +142,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].

         Example:
-            >>> from monai.losses.dice import * # NOQA
+            >>> from monai.losses.dice import *  # NOQA
             >>> import torch
             >>> from monai.losses.dice import DiceLoss
             >>> B, C, H, W = 7, 5, 3, 2
@@ -166,7 +166,6 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         if self.other_act is not None:
             input = self.other_act(input)

-        # mask the ignore_index if specified, must be done before one_hot
         mask: torch.Tensor | None = None
         if self.ignore_index is not None:
             mask = (target != self.ignore_index).float()
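For context on what this hunk keeps in DiceLoss (only the comment is dropped; the masking itself stays), a minimal usage sketch of the retained ``ignore_index`` argument; the shapes and the ignored label value below are illustrative assumptions, not taken from the commit:

    import torch
    from monai.losses import DiceLoss

    # hypothetical usage: label value 2 marks voxels to exclude from the loss
    loss = DiceLoss(sigmoid=True, ignore_index=2)
    pred = torch.randn(2, 1, 4, 4)                      # B, C, H, W logits
    target = torch.randint(0, 3, (2, 1, 4, 4)).float()  # labels 0/1 plus ignored 2
    print(loss(pred, target))  # voxels labelled 2 are masked out before one-hot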
@@ -343,7 +342,6 @@ def __init__(
         smooth_dr: float = 1e-5,
         batch: bool = False,
         soft_label: bool = False,
-        ignore_index: int | None = None,
     ) -> None:
         """
         Args:
@@ -369,8 +367,6 @@ def __init__(
                 If True, the class-weighted intersection and union areas are first summed across the batches.
             soft_label: whether the target contains non-binary values (soft labels) or not.
                 If True a soft label formulation of the loss will be used.
-            ignore_index: if not None, specifies a target index that is ignored and does not contribute to
-                the input gradient.

         Raises:
             TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -396,7 +392,6 @@ def __init__(
         self.smooth_dr = float(smooth_dr)
         self.batch = batch
         self.soft_label = soft_label
-        self.ignore_index = ignore_index

     def w_func(self, grnd):
         if self.w_type == str(Weight.SIMPLE):
@@ -427,11 +422,6 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         if self.other_act is not None:
             input = self.other_act(input)

-        # Prepare mask before potential one-hot conversion
-        mask: torch.Tensor | None = None
-        if self.ignore_index is not None:
-            mask = (target != self.ignore_index).float()
-
         if self.to_onehot_y:
             if n_pred_ch == 1:
                 warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
@@ -442,17 +432,14 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             if n_pred_ch == 1:
                 warnings.warn("single channel prediction, `include_background=False` ignored.")
             else:
+                # if skipping background, removing first channel
                 target = target[:, 1:]
                 input = input[:, 1:]

         if target.shape != input.shape:
             raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

-        # Exclude ignored regions from calculations
-        if mask is not None:
-            input = input * mask
-            target = target * mask
-
+        # reducing only spatial dimensions (not batch nor channels)
         reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
         if self.batch:
             reduce_axis = [0] + reduce_axis
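The restored comment about spatial-only reduction corresponds to this small computation (the 5D shape is assumed for illustration):

    import torch

    input_5d = torch.rand(2, 3, 8, 8, 8)  # B, C, H, W, D
    reduce_axis = torch.arange(2, len(input_5d.shape)).tolist()
    print(reduce_axis)        # [2, 3, 4]: spatial dims only
    print([0] + reduce_axis)  # [0, 2, 3, 4]: batch=True also folds in the batch dim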
@@ -479,10 +466,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         f: torch.Tensor = 1.0 - (numer / denom)

         if self.reduction == LossReduction.MEAN.value:
-            f = torch.mean(f)
+            f = torch.mean(f)  # the batch and channel average
         elif self.reduction == LossReduction.SUM.value:
-            f = torch.sum(f)
+            f = torch.sum(f)  # sum over the batch and channel dims
         elif self.reduction == LossReduction.NONE.value:
+            # If we are not computing voxelwise loss components at least
+            # make sure a none reduction maintains a broadcastable shape
             broadcast_shape = list(f.shape[0:2]) + [1] * (len(input.shape) - 2)
             f = f.view(broadcast_shape)
         else:
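A small sketch of the broadcastable shape the new comments describe; the (B, 1) shape of ``f`` is an assumption for illustration (GeneralizedDiceLoss aggregates over classes before this branch when ``batch=False``):

    import torch

    B, C, H, W = 2, 3, 4, 4  # assumed input layout
    f = torch.rand(B, 1)     # per-batch loss after class aggregation
    input_shape = (B, C, H, W)
    broadcast_shape = list(f.shape[0:2]) + [1] * (len(input_shape) - 2)
    print(f.view(broadcast_shape).shape)  # torch.Size([2, 1, 1, 1])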
@@ -515,12 +504,11 @@ def __init__(
         reduction: LossReduction | str = LossReduction.MEAN,
         smooth_nr: float = 1e-5,
         smooth_dr: float = 1e-5,
-        ignore_index: int | None = None,
     ) -> None:
         """
         Args:
             dist_matrix: 2d tensor or 2d numpy array; matrix of distances between the classes.
-            It must have dimension C x C where C is the number of classes.
+                It must have dimension C x C where C is the number of classes.
             weighting_mode: {``"default"``, ``"GDL"``}
                 Specifies how to weight the class-specific sum of errors.
                 Default to ``"default"``.
@@ -540,11 +528,27 @@ def __init__(
             - ``"sum"``: the output will be summed.
             smooth_nr: a small constant added to the numerator to avoid zero.
             smooth_dr: a small constant added to the denominator to avoid nan.
-            ignore_index: if not None, specifies a target index that is ignored and does not contribute to
-                the input gradient.

         Raises:
             ValueError: When ``dist_matrix`` is not a square matrix.
+
+        Example:
+            .. code-block:: python
+
+                import torch
+                import numpy as np
+                from monai.losses import GeneralizedWassersteinDiceLoss
+
+                # Example with 3 classes (including the background: label 0).
+                # The distance between the background class (label 0) and the other classes is the maximum, equal to 1.
+                # The distance between class 1 and class 2 is 0.5.
+                dist_mat = np.array([[0.0, 1.0, 1.0], [1.0, 0.0, 0.5], [1.0, 0.5, 0.0]], dtype=np.float32)
+                wass_loss = GeneralizedWassersteinDiceLoss(dist_matrix=dist_mat)
+
+                pred_score = torch.tensor([[1000, 0, 0], [0, 1000, 0], [0, 0, 1000]], dtype=torch.float32)
+                grnd = torch.tensor([0, 1, 2], dtype=torch.int64)
+                wass_loss(pred_score, grnd)  # 0
+
         """
         super().__init__(reduction=LossReduction(reduction).value)
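A brief note on why the new docstring example evaluates to roughly zero (our reasoning, not part of the commit): logits of magnitude 1000 saturate the softmax into an effectively one-hot prediction, so every sample matches its target exactly and the Wasserstein Dice loss vanishes:

    import torch

    # softmax of a strongly peaked logit vector is one-hot up to float precision
    probs = torch.softmax(torch.tensor([[1000.0, 0.0, 0.0]]), dim=1)
    print(probs)  # tensor([[1., 0., 0.]])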

@@ -563,13 +567,13 @@ def __init__(
         self.num_classes = self.m.size(0)
         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)
-        self.ignore_index = ignore_index

     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
         Args:
             input: the shape should be BNH[WD].
             target: the shape should be BNH[WD].
+
         """
         # Aggregate spatial dimensions
         flat_input = input.reshape(input.size(0), input.size(1), -1)
@@ -581,20 +585,18 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         # Compute the Wasserstein distance map
         wass_dist_map = self.wasserstein_distance_map(probs, flat_target)

-        # Apply masking for ignore_index
-        if self.ignore_index is not None:
-            mask = (flat_target != self.ignore_index).float()
-            wass_dist_map = wass_dist_map * mask
-
         # Compute the values of alpha to use
         alpha = self._compute_alpha_generalized_true_positives(flat_target)

         # Compute the numerator and denominator of the generalized Wasserstein Dice loss
         if self.alpha_mode == "GDL":
             # use GDL-style alpha weights (i.e. normalize by the volume of each class)
+            # contrary to the original definition we also use alpha in the "generalized all error".
             true_pos = self._compute_generalized_true_positive(alpha, flat_target, wass_dist_map)
             denom = self._compute_denominator(alpha, flat_target, wass_dist_map)
         else:  # default: as in the original paper
+            # (i.e. alpha=1 for all foreground classes and 0 for the background).
+            # Compute the generalised number of true positives
             true_pos = self._compute_generalized_true_positive(alpha, flat_target, wass_dist_map)
             all_error = torch.sum(wass_dist_map, dim=1)
             denom = 2 * true_pos + all_error
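The two alpha conventions named in the restored comments can be sketched as follows; these helpers are illustrative stand-ins, not MONAI's private ``_compute_alpha_generalized_true_positives``, and the exact GDL normalisation form is an assumption:

    import torch
    import torch.nn.functional as F

    def alpha_default(flat_target: torch.Tensor, num_classes: int) -> torch.Tensor:
        # paper convention: alpha = 1 for foreground classes, 0 for background (class 0)
        alpha = torch.ones(flat_target.size(0), num_classes)
        alpha[:, 0] = 0.0
        return alpha

    def alpha_gdl(flat_target: torch.Tensor, num_classes: int) -> torch.Tensor:
        # GDL-style: down-weight classes by their volume (exact form assumed)
        one_hot = F.one_hot(flat_target, num_classes).float()  # (B, S, C)
        volumes = one_hot.sum(dim=1)                           # per-class voxel counts
        return 1.0 / (volumes + 1.0)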
@@ -604,10 +606,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         wass_dice_loss: torch.Tensor = 1.0 - wass_dice

         if self.reduction == LossReduction.MEAN.value:
-            wass_dice_loss = torch.mean(wass_dice_loss)
+            wass_dice_loss = torch.mean(wass_dice_loss)  # the batch and channel average
         elif self.reduction == LossReduction.SUM.value:
-            wass_dice_loss = torch.sum(wass_dice_loss)
+            wass_dice_loss = torch.sum(wass_dice_loss)  # sum over the batch and channel dims
         elif self.reduction == LossReduction.NONE.value:
+            # GWDL aggregates over classes internally, so wass_dice_loss has shape (B,)
             pass
         else:
             raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
@@ -733,7 +736,6 @@ def __init__(
         lambda_dice: float = 1.0,
         lambda_ce: float = 1.0,
         label_smoothing: float = 0.0,
-        ignore_index: int | None = None,
     ) -> None:
         """
         Args:
@@ -775,8 +777,6 @@ def __init__(
             label_smoothing: a value in [0, 1] range. If > 0, the labels are smoothed
                 by the given factor to reduce overfitting.
                 Defaults to 0.0.
-            ignore_index: if not None, specifies a target index that is ignored and does not contribute to
-                the input gradient.

         """
         super().__init__()
@@ -799,22 +799,15 @@ def __init__(
             smooth_dr=smooth_dr,
             batch=batch,
             weight=dice_weight,
-            ignore_index=ignore_index,
-        )
-        self.cross_entropy = nn.CrossEntropyLoss(
-            weight=weight,
-            reduction=reduction,
-            label_smoothing=label_smoothing,
-            ignore_index=ignore_index if ignore_index is not None else -100,
         )
+        self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction, label_smoothing=label_smoothing)
         self.binary_cross_entropy = nn.BCEWithLogitsLoss(pos_weight=weight, reduction=reduction)
         if lambda_dice < 0.0:
             raise ValueError("lambda_dice should be no less than 0.0.")
         if lambda_ce < 0.0:
             raise ValueError("lambda_ce should be no less than 0.0.")
         self.lambda_dice = lambda_dice
         self.lambda_ce = lambda_ce
-        self.ignore_index = ignore_index

     def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -870,21 +863,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         )

         dice_loss = self.dice(input, target)
-
-        if input.shape[1] != 1:
-            # CrossEntropyLoss handles ignore_index natively
-            ce_loss = self.ce(input, target)
-        else:
-            # BCEWithLogitsLoss does not support ignore_index, handle manually
-            ce_loss = self.bce(input, target)
-            if self.ignore_index is not None:
-                mask = (target != self.ignore_index).float()
-                ce_loss = ce_loss * mask
-                if self.dice.reduction == "mean":
-                    ce_loss = torch.mean(ce_loss)
-                elif self.dice.reduction == "sum":
-                    ce_loss = torch.sum(ce_loss)
-
+        ce_loss = self.ce(input, target) if input.shape[1] != 1 else self.bce(input, target)
         total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_ce * ce_loss

         return total_loss
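The simplified dispatch above keeps only the channel-count branching; a self-contained sketch of the same logic (the helper name and target handling are ours, not MONAI's ``ce``/``bce`` wrappers):

    import torch
    import torch.nn as nn

    def combined_ce(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # multi-channel logits -> cross entropy; single channel -> binary cross entropy
        if input.shape[1] != 1:
            return nn.CrossEntropyLoss()(input, target.squeeze(1).long())
        return nn.BCEWithLogitsLoss()(input, target.float())

    logits = torch.randn(2, 3, 4, 4)
    labels = torch.randint(0, 3, (2, 1, 4, 4))
    print(combined_ce(logits, labels))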
