1313
1414import torch
1515
16- from monai .metrics .utils import do_metric_reduction
16+ from monai .metrics .utils import compute_voronoi_regions_fast , do_metric_reduction
1717from monai .utils import MetricReduction , deprecated_arg
18+ from monai .utils .module import optional_import
1819
1920from .metric import CumulativeIterationMetric
2021
22+ scipy_ndimage , has_scipy_ndimage = optional_import ("scipy.ndimage" )
23+ cupy , has_cupy = optional_import ("cupy" )
24+ cupy_ndimage , has_cupy_ndimage = optional_import ("cupyx.scipy.ndimage" )
25+
26+
2127__all__ = ["DiceMetric" , "compute_dice" , "DiceHelper" ]
2228
2329
@@ -41,6 +47,18 @@ class DiceMetric(CumulativeIterationMetric):
4147 image size they can get overwhelmed by the signal from the background. This assumes the shape of both prediction
4248 and ground truth is BCHW[D].
4349
50+ The `per_component=True` approach computes the Dice metric on a per-connected component basis in the ground truth segmentation,
51+ ensuring equal weighting for each component regardless of its size. This method eliminates biases in traditional metrics,
52+ providing a more balanced evaluation, particularly in scenarios where object size does not correlate with clinical relevance.
53+ This provides a more granular evaluation of segmentation quality, especially useful when dealing with fragmented or
54+ disconnected objects in the foreground.
55+ Note:
56+ - The input prediction (`y_pred`) and ground truth (`y`) must both have 2 channels (foreground/background),
57+ with binary segmentation (0 for background, 1 for foreground). That is, this assumes the shape of both prediction
58+ and ground truth is B2HW[D].
59+ - This method cannot be used with multiclass segmentation.
60+ For more information, refer to the original paper: https://arxiv.org/abs/2410.18684
61+
4462 The typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
4563
4664 Further information can be found in the official
@@ -95,6 +113,9 @@ class DiceMetric(CumulativeIterationMetric):
95113 If `True`, use "label_{index}" as the key corresponding to C channels; if ``include_background`` is True,
96114 the index begins at "0", otherwise at "1". It can also take a list of label names.
97115 The outcome will then be returned as a dictionary.
116+ per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
117+ computed for each connected component in the ground truth, and then averaged. This requires binary
118+ segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.
98119
99120 """
100121
@@ -106,6 +127,7 @@ def __init__(
106127 ignore_empty : bool = True ,
107128 num_classes : int | None = None ,
108129 return_with_label : bool | list [str ] = False ,
130+ per_component : bool = False ,
109131 ) -> None :
110132 super ().__init__ ()
111133 self .include_background = include_background
@@ -114,13 +136,15 @@ def __init__(
114136 self .ignore_empty = ignore_empty
115137 self .num_classes = num_classes
116138 self .return_with_label = return_with_label
139+ self .per_component = per_component
117140 self .dice_helper = DiceHelper (
118141 include_background = self .include_background ,
119142 reduction = MetricReduction .NONE ,
120143 get_not_nans = False ,
121144 apply_argmax = False ,
122145 ignore_empty = self .ignore_empty ,
123146 num_classes = self .num_classes ,
147+ per_component = self .per_component ,
124148 )
125149
126150 def _compute_tensor (self , y_pred : torch .Tensor , y : torch .Tensor ) -> torch .Tensor : # type: ignore[override]
@@ -175,6 +199,7 @@ def compute_dice(
175199 include_background : bool = True ,
176200 ignore_empty : bool = True ,
177201 num_classes : int | None = None ,
202+ per_component : bool = False ,
178203) -> torch .Tensor :
179204 """
180205 Computes Dice score metric for a batch of predictions. This performs the same computation as
@@ -192,6 +217,9 @@ def compute_dice(
192217 num_classes: number of input channels (always including the background). When this is ``None``,
193218 ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
194219 single-channel class indices and the number of classes is not automatically inferred from data.
220+ per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
221+ computed for each connected component in the ground truth, and then averaged. This requires binary
222+ segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.
195223
196224 Returns:
197225 Dice scores per batch and per class, (shape: [batch_size, num_classes]).
@@ -204,6 +232,7 @@ def compute_dice(
204232 apply_argmax = False ,
205233 ignore_empty = ignore_empty ,
206234 num_classes = num_classes ,
235+ per_component = per_component ,
207236 )(y_pred = y_pred , y = y )
208237
209238
@@ -246,6 +275,9 @@ class DiceHelper:
246275 num_classes: number of input channels (always including the background). When this is ``None``,
247276 ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
248277 single-channel class indices and the number of classes is not automatically inferred from data.
278+ per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
279+ computed for each connected component in the ground truth, and then averaged. This requires binary
280+ segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.
249281 """
250282
251283 @deprecated_arg ("softmax" , "1.5" , "1.7" , "Use `apply_argmax` instead." , new_name = "apply_argmax" )
@@ -262,6 +294,7 @@ def __init__(
262294 num_classes : int | None = None ,
263295 sigmoid : bool | None = None ,
264296 softmax : bool | None = None ,
297+ per_component : bool = False ,
265298 ) -> None :
266299 # handling deprecated arguments
267300 if sigmoid is not None :
@@ -277,6 +310,50 @@ def __init__(
277310 self .activate = activate
278311 self .ignore_empty = ignore_empty
279312 self .num_classes = num_classes
313+ self .per_component = per_component
314+
315+ def compute_cc_dice (self , y_pred : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
316+ """
317+ Compute per-component Dice for a single batch item.
318+
319+ Args:
320+ y_pred (torch.Tensor): Predictions with shape (1, 2, D, H, W) or (1, 2, H, W).
321+ y (torch.Tensor): Ground truth with shape (1, 2, D, H, W) or (1, 2, H, W).
322+
323+ Returns:
324+ torch.Tensor: Mean Dice over connected components.
325+ """
326+ if y_pred .ndim == y .ndim :
327+ y_pred_idx = torch .argmax (y_pred , dim = 1 )
328+ y_idx = torch .argmax (y , dim = 1 )
329+ else :
330+ y_pred_idx = y_pred
331+ y_idx = y
332+ if y_idx [0 ].sum () == 0 :
333+ if self .ignore_empty :
334+ data = torch .tensor (float ("nan" ), device = y_idx .device )
335+ elif y_pred_idx .sum () == 0 :
336+ data = torch .tensor (1.0 , device = y_idx .device )
337+ else :
338+ data = torch .tensor (0.0 , device = y_idx .device )
339+ else :
340+ cc_assignment = compute_voronoi_regions_fast (y_idx [0 ])
341+ if cc_assignment .device != y_idx .device :
342+ cc_assignment = cc_assignment .to (y_idx .device )
343+ uniq , inv = torch .unique (cc_assignment .view (- 1 ), return_inverse = True )
344+ nof_components = uniq .numel ()
345+ code = (y_idx .view (- 1 ) << 1 ) | y_pred_idx .view (- 1 )
346+ idx = (inv << 2 ) | code
347+ hist = torch .bincount (idx , minlength = nof_components * 4 ).reshape (- 1 , 4 )
348+ _ , fp , fn , tp = hist [:, 0 ], hist [:, 1 ], hist [:, 2 ], hist [:, 3 ]
349+ denom = 2 * tp + fp + fn
350+ dice_scores = torch .where (
351+ denom > 0 , (2 * tp ).float () / denom .float (), torch .tensor (1.0 , device = denom .device )
352+ )
353+ data = dice_scores .unsqueeze (- 1 )
354+ data = torch .nan_to_num (data )
355+ data = data .reshape (- 1 , 1 )
356+ return torch .stack ([data .mean ()])
280357
281358 def compute_channel (self , y_pred : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
282359 """
@@ -305,6 +382,9 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
305382 y_pred: input predictions with shape (batch_size, num_classes or 1, spatial_dims...).
306383 the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``.
307384 y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
385+
386+ Raises:
387+ ValueError: when the shapes of `y_pred` and `y` are not compatible for the per-component computation.
308388 """
309389 _apply_argmax , _threshold = self .apply_argmax , self .threshold
310390 if self .num_classes is None :
@@ -322,15 +402,31 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
322402 y_pred = torch .sigmoid (y_pred )
323403 y_pred = y_pred > 0.5
324404
325- first_ch = 0 if self .include_background else 1
405+ if self .per_component :
406+ if y_pred .ndim not in (4 , 5 ) or y .ndim not in (4 , 5 ) or y_pred .shape [1 ] != 2 or y .shape [1 ] != 2 :
407+ same_rank = y_pred .ndim == y .ndim and y_pred .ndim in (4 , 5 )
408+ binary_channels = y_pred .shape [1 ] == 2 and y .shape [1 ] == 2
409+ same_shape = y_pred .shape == y .shape
410+ if not (same_rank and binary_channels and same_shape ):
411+ raise ValueError (
412+ "per_component requires matching 4D/5D binary tensors "
413+ "(B, 2, H, W) or (B, 2, D, H, W). "
414+ f"Got y_pred={ tuple (y_pred .shape )} , y={ tuple (y .shape )} ."
415+ )
416+
417+ first_ch = 0 if self .include_background and not self .per_component else 1
326418 data = []
327419 for b in range (y_pred .shape [0 ]):
420+ if self .per_component :
421+ data .append (self .compute_cc_dice (y_pred = y_pred [b ].unsqueeze (0 ), y = y [b ].unsqueeze (0 )).reshape (- 1 ))
422+ continue
328423 c_list = []
329424 for c in range (first_ch , n_pred_ch ) if n_pred_ch > 1 else [1 ]:
330425 x_pred = (y_pred [b , 0 ] == c ) if (y_pred .shape [1 ] == 1 ) else y_pred [b , c ].bool ()
331426 x = (y [b , 0 ] == c ) if (y .shape [1 ] == 1 ) else y [b , c ]
332427 c_list .append (self .compute_channel (x_pred , x ))
333428 data .append (torch .stack (c_list ))
429+
334430 data = torch .stack (data , dim = 0 ).contiguous () # type: ignore
335431
336432 f , not_nans = do_metric_reduction (data , self .reduction ) # type: ignore
0 commit comments