Skip to content

Commit 34a6817

Browse files
committed
Adding optional import for scipy and fixing issues raised by coderabbitai - docstring issues, ignore_empty bug
Signed-off-by: Vijay Vignesh Prasad Rao <vijayvigneshp02@gmail.com>
1 parent c110e2a commit 34a6817

1 file changed

Lines changed: 22 additions & 9 deletions

File tree

monai/metrics/meandice.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,19 @@
1414
import numpy as np
1515
import torch
1616
import torch.nn.functional as F
17+
1718
from scipy.ndimage import distance_transform_edt, generate_binary_structure
1819
from scipy.ndimage import label as sn_label
1920

2021
from monai.metrics.utils import do_metric_reduction
2122
from monai.utils import MetricReduction, deprecated_arg
23+
from monai.utils.module import optional_import
2224

2325
from .metric import CumulativeIterationMetric
2426

27+
distance_transform_edt, has_ndimage = optional_import("scipy.ndimage", name="distance_transform_edt")
28+
sn_label, _ = optional_import("scipy.ndimage", name="label")
29+
2530
__all__ = ["DiceMetric", "compute_dice", "DiceHelper"]
2631

2732

@@ -304,11 +309,15 @@ def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None):
304309
Returns the ID of the nearest component for each voxel.
305310
306311
Args:
307-
labels: input label map as a numpy array, where values > 0 are considered seeds for connected components.
308-
connectivity: 6/18/26 (3D)
309-
sampling: voxel spacing for anisotropic distances (scipy.ndimage.distance_transform_edt)
310-
"""
312+
labels (np.ndarray | torch.Tensor): Label map where values > 0 are seeds.
313+
connectivity (int): 6, 18, or 26 for 3D connectivity. Defaults to 26.
314+
sampling (tuple[float, ...] | None): Voxel spacing for anisotropic distances.
311315
316+
Returns:
317+
torch.Tensor: Voronoi region IDs (int32) on CPU.
318+
"""
319+
if not has_ndimage:
320+
raise RuntimeError("scipy.ndimage is required for per_component Dice computation.")
312321
x = np.asarray(labels)
313322
conn_rank = {6: 1, 18: 2, 26: 3}.get(connectivity, 3)
314323
structure = generate_binary_structure(rank=3, connectivity=conn_rank)
@@ -323,12 +332,14 @@ def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None):
323332

324333
def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
325334
"""
326-
Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately
327-
for each batch item and for each channel of those items.
335+
Compute per-component Dice for a single batch item.
328336
329337
Args:
330-
y_pred: input predictions with shape HW[D].
331-
y: ground truth with shape HW[D].
338+
y_pred (torch.Tensor): Predictions with shape (1, 2, D, H, W).
339+
y (torch.Tensor): Ground truth with shape (1, 2, D, H, W).
340+
341+
Returns:
342+
torch.Tensor: Mean Dice over connected components.
332343
"""
333344
data = []
334345
if y_pred.ndim == y.ndim:
@@ -338,7 +349,9 @@ def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
338349
y_pred_idx = y_pred
339350
y_idx = y
340351
if y_idx[0].sum() == 0:
341-
if y_pred_idx.sum() == 0:
352+
if self.ignore_empty:
353+
data.append(torch.tensor(float("nan"), device=y_idx.device))
354+
elif y_pred_idx.sum() == 0:
342355
data.append(torch.tensor(1.0, device=y_idx.device))
343356
else:
344357
data.append(torch.tensor(0.0, device=y_idx.device))

0 commit comments

Comments
 (0)