@@ -1111,8 +1111,208 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
11111111 focal_loss = self .focal (input , target )
11121112 total_loss : torch .Tensor = self .lambda_gdl * gdl_loss + self .lambda_focal * focal_loss
11131113 return total_loss
1114+
class SparseDiceLoss(_Loss):
    """
    Compute average Dice loss between two tensors. It can support both multi-classes and multi-labels tasks.
    The data `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]).

    Note that axis N of `input` is expected to be logits or probabilities for each class, if passing logits as input,
    must set `sigmoid=True` or `softmax=True`, or specifying `other_act`. And the same axis of `target`
    can be 1 or N (one-hot format).

    The `smooth_nr` and `smooth_dr` parameters are values added to the intersection and union components of
    the inter-over-union calculation to smooth results respectively, these values should be small.

    The original papers:

        Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks for Volumetric
        Medical Image Segmentation. 3DV 2016.

        Wang, Z. et. al. (2023) Jaccard Metric Losses: Optimizing the Jaccard Index with
        Soft Labels. NeurIPS 2023.

        Wang, Z. et. al. (2023) Dice Semimetric Losses: Optimizing the Dice Score with
        Soft Labels. MICCAI 2023.

    """

    def __init__(
        self,
        include_background: bool = True,
        to_onehot_y: bool = False,
        sigmoid: bool = False,
        softmax: bool = False,
        other_act: Callable | None = None,
        squared_pred: bool = False,
        jaccard: bool = False,
        reduction: LossReduction | str = LossReduction.MEAN,
        smooth_nr: float = 1e-5,
        smooth_dr: float = 1e-5,
        batch: bool = False,
        weight: Sequence[float] | float | int | torch.Tensor | None = None,
        soft_label: bool = False,
    ) -> None:
        """
        Args:
            include_background: if False, channel index 0 (background category) is excluded from the calculation.
                if the non-background segmentations are small compared to the total image size they can get overwhelmed
                by the signal from the background so excluding it in such cases helps convergence.
            to_onehot_y: whether to convert the ``target`` into the one-hot format,
                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
            sigmoid: if True, apply a sigmoid function to the prediction.
            softmax: if True, apply a softmax function to the prediction.
            other_act: callable function to execute other activation layers, Defaults to ``None``. for example:
                ``other_act = torch.tanh``.
            squared_pred: use squared versions of targets and predictions in the denominator or not.
            jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.

                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.

            smooth_nr: a small constant added to the numerator to avoid zero.
            smooth_dr: a small constant added to the denominator to avoid nan.
            batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
                Defaults to False, a Dice loss value is computed independently from each item in the batch
                before any `reduction`.
            weight: weights to apply to the voxels of each class. If None no weights are applied.
                The input can be a single value (same weight for all classes), a sequence of values (the length
                of the sequence should be the same as the number of classes. If not ``include_background``,
                the number of classes should not include the background category class 0).
                The value/values should be no less than 0. Defaults to None.
            soft_label: whether the target contains non-binary values (soft labels) or not.
                If True a soft label formulation of the loss will be used.

        Raises:
            TypeError: When ``other_act`` is not an ``Optional[Callable]``.
            ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
                Incompatible values.

        """
        super().__init__(reduction=LossReduction(reduction).value)
        if other_act is not None and not callable(other_act):
            raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
        # at most one activation may be selected; they are mutually exclusive
        if int(sigmoid) + int(softmax) + int(other_act is not None) > 1:
            raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].")
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.sigmoid = sigmoid
        self.softmax = softmax
        self.other_act = other_act
        self.squared_pred = squared_pred
        self.jaccard = jaccard
        self.smooth_nr = float(smooth_nr)
        self.smooth_dr = float(smooth_dr)
        self.batch = batch
        # register the class weight as a buffer so it follows .to(device) / state_dict
        weight = torch.as_tensor(weight) if weight is not None else None
        self.register_buffer("class_weight", weight)
        self.class_weight: None | torch.Tensor
        self.soft_label = soft_label

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD], where N is the number of classes.
            target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.

        Raises:
            AssertionError: When input and target (after one hot transform if set)
                have different shapes.
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].

        Example:
            >>> from monai.losses.dice import *  # NOQA
            >>> import torch
            >>> import numpy as np
            >>> B, C, H, W = 7, 5, 3, 2
            >>> input = torch.rand(B, C, H, W)
            >>> target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long()
            >>> target = one_hot(target_idx[:, None, ...], num_classes=C)
            >>> self = SparseDiceLoss(reduction='none')
            >>> loss = self(input, target)
            >>> assert np.broadcast_shapes(loss.shape, input.shape) == input.shape
        """
        if self.sigmoid:
            input = torch.sigmoid(input)

        n_pred_ch = input.shape[1]
        if self.softmax:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `softmax=True` ignored.")
            else:
                input = torch.softmax(input, 1)

        if self.other_act is not None:
            input = self.other_act(input)

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
                target = one_hot(target, num_classes=n_pred_ch)

        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `include_background=False` ignored.")
            else:
                # if skipping background, removing first channel
                target = target[:, 1:]
                input = input[:, 1:]

        if target.shape != input.shape:
            raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

        # reducing only spatial dimensions (not batch nor channels)
        reduce_axis: list[int] = list(range(2, len(input.shape)))
        if self.batch:
            # reducing spatial dimensions and batch
            reduce_axis = [0] + reduce_axis

        # `pow_order` rather than `ord` so the builtin is not shadowed
        pow_order = 2 if self.squared_pred else 1
        tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, pow_order, self.soft_label)
        if not self.jaccard:
            # Dice counts fp/fn at half weight relative to Jaccard (soft IoU)
            fp *= 0.5
            fn *= 0.5
        numerator = 2 * tp + self.smooth_nr
        denominator = 2 * (tp + fp + fn) + self.smooth_dr

        f: torch.Tensor = 1 - numerator / denominator

        num_of_classes = target.shape[1]
        if self.class_weight is not None and num_of_classes != 1:
            # make sure the lengths of weights are equal to the number of classes
            if self.class_weight.ndim == 0:
                # a scalar weight is expanded to one value per class; the buffer is
                # updated in place so the expansion happens only on the first call
                self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes)
            else:
                if self.class_weight.shape[0] != num_of_classes:
                    raise ValueError(
                        """the length of the `weight` sequence should be the same as the number of classes.
                        If `include_background=False`, the weight should not include
                        the background category class 0."""
                    )
            if self.class_weight.min() < 0:
                raise ValueError("the value/values of the `weight` should be no less than 0.")
            # apply class_weight to loss
            f = f * self.class_weight.to(f)

        if self.reduction == LossReduction.MEAN.value:
            f = torch.mean(f)  # the batch and channel average
        elif self.reduction == LossReduction.SUM.value:
            f = torch.sum(f)  # sum over the batch and channel dims
        elif self.reduction == LossReduction.NONE.value:
            # If we are not computing voxelwise loss components at least
            # make sure a none reduction maintains a broadcastable shape
            broadcast_shape = list(f.shape[0:2]) + [1] * (len(input.shape) - 2)
            f = f.view(broadcast_shape)
        else:
            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

        return f
11151314
# Short/lowercase public aliases for the loss classes defined in this module,
# kept for backward compatibility with callers using the older names.
sparse_dice_loss = SparseDiceLoss
Dice = DiceLoss
dice_ce = DiceCELoss
dice_focal = DiceFocalLoss
0 commit comments