Skip to content
Merged
59 changes: 56 additions & 3 deletions monai/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import warnings
from collections.abc import Callable, Sequence
from typing import Any

import numpy as np
import torch
Expand Down Expand Up @@ -239,11 +238,52 @@ class MaskedDiceLoss(DiceLoss):

"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
def __init__(
self,
include_background: bool = True,
to_onehot_y: bool = False,
sigmoid: bool = False,
softmax: bool = False,
other_act: Callable | None = None,
squared_pred: bool = False,
jaccard: bool = False,
reduction: LossReduction | str = LossReduction.MEAN,
smooth_nr: float = 1e-5,
smooth_dr: float = 1e-5,
batch: bool = False,
weight: Sequence[float] | float | int | torch.Tensor | None = None,
soft_label: bool = False,
) -> None:
"""
Args follow :py:class:`monai.losses.DiceLoss`.
"""
Comment thread
ericspod marked this conversation as resolved.
super().__init__(*args, **kwargs)
if other_act is not None and not callable(other_act):
raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
if sigmoid and softmax:
raise ValueError("Incompatible values: sigmoid=True and softmax=True.")
if other_act is not None and (sigmoid or softmax):
raise ValueError("Incompatible values: other_act is not None and sigmoid=True or softmax=True.")

self.pre_sigmoid = sigmoid
self.pre_softmax = softmax
self.pre_other_act = other_act
Comment thread
ytl0623 marked this conversation as resolved.

super().__init__(
include_background=include_background,
to_onehot_y=to_onehot_y,
sigmoid=False,
softmax=False,
other_act=None,
squared_pred=squared_pred,
jaccard=jaccard,
reduction=reduction,
smooth_nr=smooth_nr,
smooth_dr=smooth_dr,
batch=batch,
weight=weight,
soft_label=soft_label,
)

self.spatial_weighted = MaskedLoss(loss=super().forward)

def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
Expand All @@ -253,6 +293,19 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
target: the shape should be BNH[WD].
mask: the shape should B1H[WD] or 11H[WD].
"""

if self.pre_sigmoid:
input = torch.sigmoid(input)

n_pred_ch = input.shape[1]
if self.pre_softmax:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `softmax=True` ignored.", stacklevel=2)
else:
input = torch.softmax(input, 1)

if self.pre_other_act is not None:
input = self.pre_other_act(input)
return self.spatial_weighted(input=input, target=target, mask=mask) # type: ignore[no-any-return]


Expand Down
6 changes: 3 additions & 3 deletions tests/losses/test_masked_dice_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]),
"mask": torch.tensor([[[[0.0, 0.0], [1.0, 1.0]]]]),
},
0.500,
0.333333,
],
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
{"include_background": True, "sigmoid": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4},
Expand All @@ -36,7 +36,7 @@
"target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]),
"mask": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 1.0], [0.0, 0.0]]]]),
},
0.422969,
0.301128,
],
[ # shape: (2, 2, 3), (2, 1, 3)
{"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0},
Expand All @@ -54,7 +54,7 @@
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
"mask": torch.tensor([[[1.0, 1.0, 0.0]]]),
},
0.47033,
0.579184,
],
[ # shape: (2, 2, 3), (2, 1, 3)
{
Expand Down
Loading