Skip to content

Commit 8d412a1

Browse files
committed
fixing indentation and formatting
Signed-off-by: Vijay Vignesh Prasad Rao <vijayvigneshp02@gmail.com>
2 parents 34a6817 + 41e52c1 commit 8d412a1

2 files changed

Lines changed: 9 additions & 24 deletions

File tree

monai/metrics/meandice.py

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,14 @@
1313

1414
import numpy as np
1515
import torch
16-
import torch.nn.functional as F
17-
1816
from scipy.ndimage import distance_transform_edt, generate_binary_structure
1917
from scipy.ndimage import label as sn_label
2018

2119
from monai.metrics.utils import do_metric_reduction
2220
from monai.utils import MetricReduction, deprecated_arg
23-
from monai.utils.module import optional_import
2421

2522
from .metric import CumulativeIterationMetric
2623

27-
distance_transform_edt, has_ndimage = optional_import("scipy.ndimage", name="distance_transform_edt")
28-
sn_label, _ = optional_import("scipy.ndimage", name="label")
29-
3024
__all__ = ["DiceMetric", "compute_dice", "DiceHelper"]
3125

3226

@@ -309,15 +303,11 @@ def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None):
309303
Returns the ID of the nearest component for each voxel.
310304
311305
Args:
312-
labels (np.ndarray | torch.Tensor): Label map where values > 0 are seeds.
313-
connectivity (int): 6, 18, or 26 for 3D connectivity. Defaults to 26.
314-
sampling (tuple[float, ...] | None): Voxel spacing for anisotropic distances.
315-
316-
Returns:
317-
torch.Tensor: Voronoi region IDs (int32) on CPU.
306+
labels: input label map as a numpy array, where values > 0 are considered seeds for connected components.
307+
connectivity: 6/18/26 (3D)
308+
sampling: voxel spacing for anisotropic distances (scipy.ndimage.distance_transform_edt)
318309
"""
319-
if not has_ndimage:
320-
raise RuntimeError("scipy.ndimage is required for per_component Dice computation.")
310+
321311
x = np.asarray(labels)
322312
conn_rank = {6: 1, 18: 2, 26: 3}.get(connectivity, 3)
323313
structure = generate_binary_structure(rank=3, connectivity=conn_rank)
@@ -332,14 +322,12 @@ def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None):
332322

333323
def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
334324
"""
335-
Compute per-component Dice for a single batch item.
325+
Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately
326+
for each batch item and for each channel of those items.
336327
337328
Args:
338-
y_pred (torch.Tensor): Predictions with shape (1, 2, D, H, W).
339-
y (torch.Tensor): Ground truth with shape (1, 2, D, H, W).
340-
341-
Returns:
342-
torch.Tensor: Mean Dice over connected components.
329+
y_pred: input predictions with shape HW[D].
330+
y: ground truth with shape HW[D].
343331
"""
344332
data = []
345333
if y_pred.ndim == y.ndim:
@@ -349,9 +337,7 @@ def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
349337
y_pred_idx = y_pred
350338
y_idx = y
351339
if y_idx[0].sum() == 0:
352-
if self.ignore_empty:
353-
data.append(torch.tensor(float("nan"), device=y_idx.device))
354-
elif y_pred_idx.sum() == 0:
340+
if y_pred_idx.sum() == 0:
355341
data.append(torch.tensor(1.0, device=y_idx.device))
356342
else:
357343
data.append(torch.tensor(0.0, device=y_idx.device))

tests/metrics/test_compute_meandice.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from parameterized import parameterized
1919

2020
from monai.metrics import DiceHelper, DiceMetric, compute_dice
21-
from monai.metrics.fid import FIDMetric
2221

2322
_device = "cuda:0" if torch.cuda.is_available() else "cpu"
2423
# keep background

0 commit comments

Comments
 (0)