1414import numpy as np
1515import torch
1616import torch .nn .functional as F
17+
1718from scipy .ndimage import distance_transform_edt , generate_binary_structure
1819from scipy .ndimage import label as sn_label
1920
2021from monai .metrics .utils import do_metric_reduction
2122from monai .utils import MetricReduction , deprecated_arg
23+ from monai .utils .module import optional_import
2224
2325from .metric import CumulativeIterationMetric
2426
27+ distance_transform_edt , has_ndimage = optional_import ("scipy.ndimage" , name = "distance_transform_edt" )
28+ sn_label , _ = optional_import ("scipy.ndimage" , name = "label" )
29+
2530__all__ = ["DiceMetric" , "compute_dice" , "DiceHelper" ]
2631
2732
@@ -304,11 +309,15 @@ def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None):
304309 Returns the ID of the nearest component for each voxel.
305310
306311 Args:
307- labels: input label map as a numpy array, where values > 0 are considered seeds for connected components.
308- connectivity: 6/18/26 (3D)
309- sampling: voxel spacing for anisotropic distances (scipy.ndimage.distance_transform_edt)
310- """
312+ labels (np.ndarray | torch.Tensor): Label map where values > 0 are seeds.
313+ connectivity (int): 6, 18, or 26 for 3D connectivity. Defaults to 26.
314+ sampling (tuple[float, ...] | None): Voxel spacing for anisotropic distances.
311315
316+ Returns:
317+ torch.Tensor: Voronoi region IDs (int32) on CPU.
318+ """
319+ if not has_ndimage :
320+ raise RuntimeError ("scipy.ndimage is required for per_component Dice computation." )
312321 x = np .asarray (labels )
313322 conn_rank = {6 : 1 , 18 : 2 , 26 : 3 }.get (connectivity , 3 )
314323 structure = generate_binary_structure (rank = 3 , connectivity = conn_rank )
@@ -323,12 +332,14 @@ def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None):
323332
324333 def compute_cc_dice (self , y_pred : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
325334 """
326- Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately
327- for each batch item and for each channel of those items.
335+ Compute per-component Dice for a single batch item.
328336
329337 Args:
330- y_pred: input predictions with shape HW[D].
331- y: ground truth with shape HW[D].
338+ y_pred (torch.Tensor): Predictions with shape (1, 2, D, H, W).
339+ y (torch.Tensor): Ground truth with shape (1, 2, D, H, W).
340+
341+ Returns:
342+ torch.Tensor: Mean Dice over connected components.
332343 """
333344 data = []
334345 if y_pred .ndim == y .ndim :
@@ -338,7 +349,9 @@ def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
338349 y_pred_idx = y_pred
339350 y_idx = y
340351 if y_idx [0 ].sum () == 0 :
341- if y_pred_idx .sum () == 0 :
352+ if self .ignore_empty :
353+ data .append (torch .tensor (float ("nan" ), device = y_idx .device ))
354+ elif y_pred_idx .sum () == 0 :
342355 data .append (torch .tensor (1.0 , device = y_idx .device ))
343356 else :
344357 data .append (torch .tensor (0.0 , device = y_idx .device ))
0 commit comments