Skip to content

Commit 925e431

Browse files
committed
moving compute_voronoi_regions_fast to utils, removing hardcoded conn_rank, modifying data assignment in compute_cc_dice and using nan_to_num
Signed-off-by: Vijay Vignesh Prasad Rao <vijayvigneshp02@gmail.com>
1 parent 4e6def7 commit 925e431

2 files changed

Lines changed: 67 additions & 79 deletions

File tree

monai/metrics/meandice.py

Lines changed: 9 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import numpy as np
1717
import torch
1818

19-
from monai.metrics.utils import do_metric_reduction
19+
from monai.metrics.utils import compute_voronoi_regions_fast, do_metric_reduction
2020
from monai.utils import MetricReduction, deprecated_arg
2121
from monai.utils.module import optional_import
2222

@@ -306,67 +306,6 @@ def __init__(
306306
self.num_classes = num_classes
307307
self.per_component = per_component
308308

309-
def compute_voronoi_regions_fast(self, labels):
310-
"""
311-
Voronoi assignment to connected components (CPU, single EDT) without cc3d.
312-
Returns the ID of the nearest component for each voxel.
313-
314-
Args:
315-
labels (np.ndarray | torch.Tensor): Label map where values > 0 are seeds.
316-
317-
Raises:
318-
RuntimeError: when `scipy.ndimage` is not available.
319-
ValueError: when `labels` has fewer than two dimensions.
320-
321-
Returns:
322-
torch.Tensor: Voronoi region IDs (int32) on CPU.
323-
"""
324-
if isinstance(labels, torch.Tensor) and labels.is_cuda and has_cupy and has_cupy_ndimage:
325-
xp = cupy
326-
nd_distance_transform_edt = cupy_ndimage.distance_transform_edt
327-
nd_generate_binary_structure = cupy_ndimage.generate_binary_structure
328-
nd_label = cupy_ndimage.label
329-
x = cupy.asarray(labels.detach())
330-
else:
331-
xp = np
332-
nd_distance_transform_edt = scipy_ndimage.distance_transform_edt
333-
nd_generate_binary_structure = scipy_ndimage.generate_binary_structure
334-
nd_label = scipy_ndimage.label
335-
336-
if not has_scipy_ndimage:
337-
raise RuntimeError("scipy.ndimage is required for per_component Dice computation.")
338-
339-
if isinstance(labels, torch.Tensor):
340-
warnings.warn(
341-
"Voronoi computation is running on CPU. "
342-
"To accelerate, move the input tensor to GPU and ensure 'cupy' with 'cupyx.scipy.ndimage' is installed."
343-
)
344-
x = labels.cpu().numpy()
345-
else:
346-
x = np.asarray(labels)
347-
rank = x.ndim
348-
if rank == 3:
349-
conn_map = {6: 1, 18: 2, 26: 3}
350-
connectivity = 26
351-
elif rank == 2:
352-
conn_map = {4: 1, 8: 2}
353-
connectivity = 8
354-
else:
355-
raise ValueError("Only 2D or 3D inputs supported")
356-
conn_rank = conn_map.get(connectivity, max(conn_map.values()))
357-
structure = nd_generate_binary_structure(rank=rank, connectivity=conn_rank)
358-
cc, num = nd_label(x > 0, structure=structure)
359-
if num == 0:
360-
return torch.zeros_like(torch.from_numpy(x), dtype=torch.int32)
361-
edt_input = xp.ones(cc.shape, dtype=xp.uint8)
362-
edt_input[cc > 0] = 0
363-
indices = nd_distance_transform_edt(edt_input, sampling=None, return_distances=False, return_indices=True)
364-
voronoi = cc[tuple(indices)]
365-
if xp is cupy:
366-
return torch.as_tensor(cupy.asnumpy(voronoi), dtype=torch.int32)
367-
else:
368-
return torch.as_tensor(voronoi, dtype=torch.int32)
369-
370309
def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
371310
"""
372311
Compute per-component Dice for a single batch item.
@@ -378,7 +317,6 @@ def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
378317
Returns:
379318
torch.Tensor: Mean Dice over connected components.
380319
"""
381-
data = []
382320
if y_pred.ndim == y.ndim:
383321
y_pred_idx = torch.argmax(y_pred, dim=1)
384322
y_idx = torch.argmax(y, dim=1)
@@ -387,13 +325,13 @@ def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
387325
y_idx = y
388326
if y_idx[0].sum() == 0:
389327
if self.ignore_empty:
390-
data.append(torch.tensor(float("nan"), device=y_idx.device))
328+
data = torch.tensor(float("nan"), device=y_idx.device)
391329
elif y_pred_idx.sum() == 0:
392-
data.append(torch.tensor(1.0, device=y_idx.device))
330+
data = torch.tensor(1.0, device=y_idx.device)
393331
else:
394-
data.append(torch.tensor(0.0, device=y_idx.device))
332+
data = torch.tensor(0.0, device=y_idx.device)
395333
else:
396-
cc_assignment = self.compute_voronoi_regions_fast(y_idx[0])
334+
cc_assignment = compute_voronoi_regions_fast(y_idx[0])
397335
if cc_assignment.device != y_idx.device:
398336
cc_assignment = cc_assignment.to(y_idx.device)
399337
uniq, inv = torch.unique(cc_assignment.view(-1), return_inverse=True)
@@ -406,15 +344,10 @@ def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
406344
dice_scores = torch.where(
407345
denom > 0, (2 * tp).float() / denom.float(), torch.tensor(1.0, device=denom.device)
408346
)
409-
data.append(dice_scores.unsqueeze(-1))
410-
data = [
411-
torch.where(torch.isinf(x), torch.tensor(0.0, dtype=torch.float32, device=x.device), x) for x in data
412-
]
413-
data = [
414-
torch.where(torch.isnan(x), torch.tensor(0.0, dtype=torch.float32, device=x.device), x) for x in data
415-
]
416-
data = [x.reshape(-1, 1) for x in data]
417-
return torch.stack([x.mean() for x in data])
347+
data = dice_scores.unsqueeze(-1)
348+
data = torch.nan_to_num(data)
349+
data = data.reshape(-1, 1)
350+
return torch.stack([data.mean()])
418351

419352
def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
420353
"""
@@ -486,8 +419,6 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
486419
x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
487420
x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
488421
c_list.append(self.compute_channel(x_pred, x))
489-
# if self.per_component:
490-
# c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))]
491422
data.append(torch.stack(c_list))
492423

493424
data = torch.stack(data, dim=0).contiguous() # type: ignore

monai/metrics/utils.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@
3939
distance_transform_edt, _ = optional_import("scipy.ndimage", name="distance_transform_edt")
4040
distance_transform_cdt, _ = optional_import("scipy.ndimage", name="distance_transform_cdt")
4141

42+
scipy_ndimage, has_scipy_ndimage = optional_import("scipy.ndimage")
43+
cupy, has_cupy = optional_import("cupy")
44+
cupy_ndimage, has_cupy_ndimage = optional_import("cupyx.scipy.ndimage")
45+
4246
__all__ = [
4347
"ignore_background",
4448
"do_metric_reduction",
@@ -320,7 +324,7 @@ def get_edge_surface_distance(
320324
edges_spacing = None
321325
if use_subvoxels:
322326
edges_spacing = spacing if spacing is not None else ([1] * len(y_pred.shape))
323-
(edges_pred, edges_gt, *areas) = get_mask_edges(
327+
edges_pred, edges_gt, *areas = get_mask_edges(
324328
y_pred, y, crop=True, spacing=edges_spacing, always_return_as_numpy=False
325329
)
326330
if not edges_gt.any():
@@ -462,6 +466,59 @@ def prepare_spacing(
462466
)
463467

464468

469+
def compute_voronoi_regions_fast(labels: np.ndarray | torch.Tensor) -> torch.Tensor:
470+
"""
471+
Voronoi assignment to connected components (CPU, single EDT) without cc3d.
472+
Returns the ID of the nearest component for each voxel.
473+
474+
Args:
475+
labels (np.ndarray | torch.Tensor): Label map where values > 0 are seeds.
476+
477+
Raises:
478+
RuntimeError: when `scipy.ndimage` is not available.
479+
ValueError: when `labels` has fewer than two dimensions.
480+
481+
Returns:
482+
torch.Tensor: Voronoi region IDs (int32) on CPU.
483+
"""
484+
if isinstance(labels, torch.Tensor) and labels.is_cuda and has_cupy and has_cupy_ndimage:
485+
xp = cupy
486+
nd_distance_transform_edt = cupy_ndimage.distance_transform_edt
487+
nd_generate_binary_structure = cupy_ndimage.generate_binary_structure
488+
nd_label = cupy_ndimage.label
489+
x = cupy.asarray(labels.detach())
490+
else:
491+
xp = np
492+
nd_distance_transform_edt = scipy_ndimage.distance_transform_edt
493+
nd_generate_binary_structure = scipy_ndimage.generate_binary_structure
494+
nd_label = scipy_ndimage.label
495+
496+
if not has_scipy_ndimage:
497+
raise RuntimeError("scipy.ndimage is required for per_component Dice computation.")
498+
499+
if isinstance(labels, torch.Tensor):
500+
warnings.warn(
501+
"Voronoi computation is running on CPU. "
502+
"To accelerate, move the input tensor to GPU and ensure 'cupy' with 'cupyx.scipy.ndimage' is installed."
503+
)
504+
x = labels.cpu().numpy()
505+
else:
506+
x = np.asarray(labels)
507+
rank = conn_rank = x.ndim
508+
structure = nd_generate_binary_structure(rank=rank, connectivity=conn_rank)
509+
cc, num = nd_label(x > 0, structure=structure)
510+
if num == 0:
511+
return torch.zeros_like(torch.from_numpy(x), dtype=torch.int32)
512+
edt_input = xp.ones(cc.shape, dtype=xp.uint8)
513+
edt_input[cc > 0] = 0
514+
indices = nd_distance_transform_edt(edt_input, sampling=None, return_distances=False, return_indices=True)
515+
voronoi = cc[tuple(indices)]
516+
if xp is cupy:
517+
return torch.as_tensor(cupy.asnumpy(voronoi), dtype=torch.int32)
518+
else:
519+
return torch.as_tensor(voronoi, dtype=torch.int32)
520+
521+
465522
ENCODING_KERNEL = {2: [[8, 4], [2, 1]], 3: [[[128, 64], [32, 16]], [[8, 4], [2, 1]]]}
466523

467524

0 commit comments

Comments
 (0)