1313
1414import numpy as np
1515import torch
16- import torch .nn .functional as F
17-
1816from scipy .ndimage import distance_transform_edt , generate_binary_structure
1917from scipy .ndimage import label as sn_label
2018
2119from monai .metrics .utils import do_metric_reduction
2220from monai .utils import MetricReduction , deprecated_arg
23- from monai .utils .module import optional_import
2421
2522from .metric import CumulativeIterationMetric
2623
27- distance_transform_edt , has_ndimage = optional_import ("scipy.ndimage" , name = "distance_transform_edt" )
28- sn_label , _ = optional_import ("scipy.ndimage" , name = "label" )
29-
3024__all__ = ["DiceMetric" , "compute_dice" , "DiceHelper" ]
3125
3226
@@ -309,15 +303,11 @@ def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None):
309303 Returns the ID of the nearest component for each voxel.
310304
311305 Args:
312- labels (np.ndarray | torch.Tensor): Label map where values > 0 are seeds.
313- connectivity (int): 6, 18, or 26 for 3D connectivity. Defaults to 26.
314- sampling (tuple[float, ...] | None): Voxel spacing for anisotropic distances.
315-
316- Returns:
317- torch.Tensor: Voronoi region IDs (int32) on CPU.
306+ labels: input label map as a numpy array, where values > 0 are considered seeds for connected components.
307+ connectivity: 6/18/26 (3D)
308+ sampling: voxel spacing for anisotropic distances (scipy.ndimage.distance_transform_edt)
318309 """
319- if not has_ndimage :
320- raise RuntimeError ("scipy.ndimage is required for per_component Dice computation." )
310+
321311 x = np .asarray (labels )
322312 conn_rank = {6 : 1 , 18 : 2 , 26 : 3 }.get (connectivity , 3 )
323313 structure = generate_binary_structure (rank = 3 , connectivity = conn_rank )
@@ -332,14 +322,12 @@ def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None):
332322
333323 def compute_cc_dice (self , y_pred : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
334324 """
335- Compute per-component Dice for a single batch item.
325+ Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately
326+ for each batch item and for each channel of those items.
336327
337328 Args:
338- y_pred (torch.Tensor): Predictions with shape (1, 2, D, H, W).
339- y (torch.Tensor): Ground truth with shape (1, 2, D, H, W).
340-
341- Returns:
342- torch.Tensor: Mean Dice over connected components.
329+ y_pred: input predictions with shape HW[D].
330+ y: ground truth with shape HW[D].
343331 """
344332 data = []
345333 if y_pred .ndim == y .ndim :
@@ -349,9 +337,7 @@ def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
349337 y_pred_idx = y_pred
350338 y_idx = y
351339 if y_idx [0 ].sum () == 0 :
352- if self .ignore_empty :
353- data .append (torch .tensor (float ("nan" ), device = y_idx .device ))
354- elif y_pred_idx .sum () == 0 :
340+ if y_pred_idx .sum () == 0 :
355341 data .append (torch .tensor (1.0 , device = y_idx .device ))
356342 else :
357343 data .append (torch .tensor (0.0 , device = y_idx .device ))
0 commit comments