Skip to content

Commit 28f99d3

Browse files
committed
Sparse Dice Loss
Began implementation of functions for sparse dice loss, which is dice loss computed on sparse datasets. Initial commit is to track issue #8731.
1 parent 9ddd5e6 commit 28f99d3

2 files changed

Lines changed: 427 additions & 0 deletions

File tree

monai/losses/dice.py

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,8 +1111,208 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
11111111
focal_loss = self.focal(input, target)
11121112
total_loss: torch.Tensor = self.lambda_gdl * gdl_loss + self.lambda_focal * focal_loss
11131113
return total_loss
1114+
1115+
class SparseDiceLoss(_Loss):
    """
    Compute average Dice loss between two tensors. It can support both multi-classes and multi-labels tasks.
    The data `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]).

    Note that axis N of `input` is expected to be logits or probabilities for each class, if passing logits as input,
    must set `sigmoid=True` or `softmax=True`, or specifying `other_act`. And the same axis of `target`
    can be 1 or N (one-hot format).

    The `smooth_nr` and `smooth_dr` parameters are values added to the intersection and union components of
    the inter-over-union calculation to smooth results respectively, these values should be small.

    The original papers:

        Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks for Volumetric
        Medical Image Segmentation. 3DV 2016.

        Wang, Z. et. al. (2023) Jaccard Metric Losses: Optimizing the Jaccard Index with
        Soft Labels. NeurIPS 2023.

        Wang, Z. et. al. (2023) Dice Semimetric Losses: Optimizing the Dice Score with
        Soft Labels. MICCAI 2023.

    """

    def __init__(
        self,
        include_background: bool = True,
        to_onehot_y: bool = False,
        sigmoid: bool = False,
        softmax: bool = False,
        other_act: Callable | None = None,
        squared_pred: bool = False,
        jaccard: bool = False,
        reduction: LossReduction | str = LossReduction.MEAN,
        smooth_nr: float = 1e-5,
        smooth_dr: float = 1e-5,
        batch: bool = False,
        weight: Sequence[float] | float | int | torch.Tensor | None = None,
        soft_label: bool = False,
    ) -> None:
        """
        Args:
            include_background: if False, channel index 0 (background category) is excluded from the calculation.
                if the non-background segmentations are small compared to the total image size they can get overwhelmed
                by the signal from the background so excluding it in such cases helps convergence.
            to_onehot_y: whether to convert the ``target`` into the one-hot format,
                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
            sigmoid: if True, apply a sigmoid function to the prediction.
            softmax: if True, apply a softmax function to the prediction.
            other_act: callable function to execute other activation layers, Defaults to ``None``. for example:
                ``other_act = torch.tanh``.
            squared_pred: use squared versions of targets and predictions in the denominator or not.
            jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.

                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.

            smooth_nr: a small constant added to the numerator to avoid zero.
            smooth_dr: a small constant added to the denominator to avoid nan.
            batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
                Defaults to False, a Dice loss value is computed independently from each item in the batch
                before any `reduction`.
            weight: weights to apply to the voxels of each class. If None no weights are applied.
                The input can be a single value (same weight for all classes), a sequence of values (the length
                of the sequence should be the same as the number of classes. If not ``include_background``,
                the number of classes should not include the background category class 0).
                The value/values should be no less than 0. Defaults to None.
            soft_label: whether the target contains non-binary values (soft labels) or not.
                If True a soft label formulation of the loss will be used.

        Raises:
            TypeError: When ``other_act`` is not an ``Optional[Callable]``.
            ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
                Incompatible values.

        """
        super().__init__(reduction=LossReduction(reduction).value)
        if other_act is not None and not callable(other_act):
            raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
        if int(sigmoid) + int(softmax) + int(other_act is not None) > 1:
            raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].")
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.sigmoid = sigmoid
        self.softmax = softmax
        self.other_act = other_act
        self.squared_pred = squared_pred
        self.jaccard = jaccard
        self.smooth_nr = float(smooth_nr)
        self.smooth_dr = float(smooth_dr)
        self.batch = batch
        # registered as a buffer so it moves with the module (`.to(device)`) and is saved in state_dict
        weight = torch.as_tensor(weight) if weight is not None else None
        self.register_buffer("class_weight", weight)
        self.class_weight: None | torch.Tensor
        self.soft_label = soft_label

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD], where N is the number of classes.
            target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.

        Raises:
            AssertionError: When input and target (after one hot transform if set)
                have different shapes.
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].

        Example:
            >>> from monai.losses.dice import *  # NOQA
            >>> import torch
            >>> from monai.losses.dice import SparseDiceLoss
            >>> B, C, H, W = 7, 5, 3, 2
            >>> input = torch.rand(B, C, H, W)
            >>> target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long()
            >>> target = one_hot(target_idx[:, None, ...], num_classes=C)
            >>> self = SparseDiceLoss(reduction='none')
            >>> loss = self(input, target)
            >>> assert np.broadcast_shapes(loss.shape, input.shape) == input.shape
        """
        if self.sigmoid:
            input = torch.sigmoid(input)

        n_pred_ch = input.shape[1]
        if self.softmax:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `softmax=True` ignored.")
            else:
                input = torch.softmax(input, 1)

        if self.other_act is not None:
            input = self.other_act(input)

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
                target = one_hot(target, num_classes=n_pred_ch)

        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `include_background=False` ignored.")
            else:
                # if skipping background, removing first channel
                target = target[:, 1:]
                input = input[:, 1:]

        if target.shape != input.shape:
            raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

        # reducing only spatial dimensions (not batch nor channels)
        reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
        if self.batch:
            # reducing spatial dimensions and batch
            reduce_axis = [0] + reduce_axis

        # exponent applied to predictions/targets in the denominator (2 when squared_pred)
        pred_ord = 2 if self.squared_pred else 1
        tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, pred_ord, self.soft_label)
        if not self.jaccard:
            # dice denominator is 2*tp + fp + fn; halving fp/fn lets us reuse the
            # jaccard-style 2*(tp + fp + fn) expression below for both losses
            fp *= 0.5
            fn *= 0.5
        numerator = 2 * tp + self.smooth_nr
        denominator = 2 * (tp + fp + fn) + self.smooth_dr

        f: torch.Tensor = 1 - numerator / denominator

        num_of_classes = target.shape[1]
        if self.class_weight is not None and num_of_classes != 1:
            # validate/broadcast the weights locally instead of overwriting the
            # registered buffer, so forward() does not mutate module state
            class_weight = self.class_weight
            if class_weight.ndim == 0:
                # a scalar weight applies uniformly to every class
                class_weight = torch.as_tensor([class_weight] * num_of_classes)
            elif class_weight.shape[0] != num_of_classes:
                raise ValueError(
                    """the length of the `weight` sequence should be the same as the number of classes.
                    If `include_background=False`, the weight should not include
                    the background category class 0."""
                )
            if class_weight.min() < 0:
                raise ValueError("the value/values of the `weight` should be no less than 0.")
            # apply class_weight to loss (cast to the loss dtype/device)
            f = f * class_weight.to(f)

        if self.reduction == LossReduction.MEAN.value:
            f = torch.mean(f)  # the batch and channel average
        elif self.reduction == LossReduction.SUM.value:
            f = torch.sum(f)  # sum over the batch and channel dims
        elif self.reduction == LossReduction.NONE.value:
            # If we are not computing voxelwise loss components at least
            # make sure a none reduction maintains a broadcastable shape
            broadcast_shape = list(f.shape[0:2]) + [1] * (len(input.shape) - 2)
            f = f.view(broadcast_shape)
        else:
            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

        return f

1315+
sparse_dice_loss = SparseDiceLoss
11161316
Dice = DiceLoss
11171317
dice_ce = DiceCELoss
11181318
dice_focal = DiceFocalLoss

0 commit comments

Comments
 (0)