1616import numpy as np
1717import torch
1818
19- from monai .metrics .utils import do_metric_reduction
19+ from monai .metrics .utils import compute_voronoi_regions_fast , do_metric_reduction
2020from monai .utils import MetricReduction , deprecated_arg
2121from monai .utils .module import optional_import
2222
@@ -306,67 +306,6 @@ def __init__(
306306 self .num_classes = num_classes
307307 self .per_component = per_component
308308
309- def compute_voronoi_regions_fast (self , labels ):
310- """
311- Voronoi assignment to connected components (CPU, single EDT) without cc3d.
312- Returns the ID of the nearest component for each voxel.
313-
314- Args:
315- labels (np.ndarray | torch.Tensor): Label map where values > 0 are seeds.
316-
317- Raises:
318- RuntimeError: when `scipy.ndimage` is not available.
319- ValueError: when `labels` has fewer than two dimensions.
320-
321- Returns:
322- torch.Tensor: Voronoi region IDs (int32) on CPU.
323- """
324- if isinstance (labels , torch .Tensor ) and labels .is_cuda and has_cupy and has_cupy_ndimage :
325- xp = cupy
326- nd_distance_transform_edt = cupy_ndimage .distance_transform_edt
327- nd_generate_binary_structure = cupy_ndimage .generate_binary_structure
328- nd_label = cupy_ndimage .label
329- x = cupy .asarray (labels .detach ())
330- else :
331- xp = np
332- nd_distance_transform_edt = scipy_ndimage .distance_transform_edt
333- nd_generate_binary_structure = scipy_ndimage .generate_binary_structure
334- nd_label = scipy_ndimage .label
335-
336- if not has_scipy_ndimage :
337- raise RuntimeError ("scipy.ndimage is required for per_component Dice computation." )
338-
339- if isinstance (labels , torch .Tensor ):
340- warnings .warn (
341- "Voronoi computation is running on CPU. "
342- "To accelerate, move the input tensor to GPU and ensure 'cupy' with 'cupyx.scipy.ndimage' is installed."
343- )
344- x = labels .cpu ().numpy ()
345- else :
346- x = np .asarray (labels )
347- rank = x .ndim
348- if rank == 3 :
349- conn_map = {6 : 1 , 18 : 2 , 26 : 3 }
350- connectivity = 26
351- elif rank == 2 :
352- conn_map = {4 : 1 , 8 : 2 }
353- connectivity = 8
354- else :
355- raise ValueError ("Only 2D or 3D inputs supported" )
356- conn_rank = conn_map .get (connectivity , max (conn_map .values ()))
357- structure = nd_generate_binary_structure (rank = rank , connectivity = conn_rank )
358- cc , num = nd_label (x > 0 , structure = structure )
359- if num == 0 :
360- return torch .zeros_like (torch .from_numpy (x ), dtype = torch .int32 )
361- edt_input = xp .ones (cc .shape , dtype = xp .uint8 )
362- edt_input [cc > 0 ] = 0
363- indices = nd_distance_transform_edt (edt_input , sampling = None , return_distances = False , return_indices = True )
364- voronoi = cc [tuple (indices )]
365- if xp is cupy :
366- return torch .as_tensor (cupy .asnumpy (voronoi ), dtype = torch .int32 )
367- else :
368- return torch .as_tensor (voronoi , dtype = torch .int32 )
369-
370309 def compute_cc_dice (self , y_pred : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
371310 """
372311 Compute per-component Dice for a single batch item.
@@ -378,7 +317,6 @@ def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
378317 Returns:
379318 torch.Tensor: Mean Dice over connected components.
380319 """
381- data = []
382320 if y_pred .ndim == y .ndim :
383321 y_pred_idx = torch .argmax (y_pred , dim = 1 )
384322 y_idx = torch .argmax (y , dim = 1 )
@@ -387,13 +325,13 @@ def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
387325 y_idx = y
388326 if y_idx [0 ].sum () == 0 :
389327 if self .ignore_empty :
390- data . append ( torch .tensor (float ("nan" ), device = y_idx .device ) )
328+ data = torch .tensor (float ("nan" ), device = y_idx .device )
391329 elif y_pred_idx .sum () == 0 :
392- data . append ( torch .tensor (1.0 , device = y_idx .device ) )
330+ data = torch .tensor (1.0 , device = y_idx .device )
393331 else :
394- data . append ( torch .tensor (0.0 , device = y_idx .device ) )
332+ data = torch .tensor (0.0 , device = y_idx .device )
395333 else :
396- cc_assignment = self . compute_voronoi_regions_fast (y_idx [0 ])
334+ cc_assignment = compute_voronoi_regions_fast (y_idx [0 ])
397335 if cc_assignment .device != y_idx .device :
398336 cc_assignment = cc_assignment .to (y_idx .device )
399337 uniq , inv = torch .unique (cc_assignment .view (- 1 ), return_inverse = True )
@@ -406,15 +344,10 @@ def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
406344 dice_scores = torch .where (
407345 denom > 0 , (2 * tp ).float () / denom .float (), torch .tensor (1.0 , device = denom .device )
408346 )
409- data .append (dice_scores .unsqueeze (- 1 ))
410- data = [
411- torch .where (torch .isinf (x ), torch .tensor (0.0 , dtype = torch .float32 , device = x .device ), x ) for x in data
412- ]
413- data = [
414- torch .where (torch .isnan (x ), torch .tensor (0.0 , dtype = torch .float32 , device = x .device ), x ) for x in data
415- ]
416- data = [x .reshape (- 1 , 1 ) for x in data ]
417- return torch .stack ([x .mean () for x in data ])
347+ data = dice_scores .unsqueeze (- 1 )
348+ data = torch .nan_to_num (data )
349+ data = data .reshape (- 1 , 1 )
350+ return torch .stack ([data .mean ()])
418351
419352 def compute_channel (self , y_pred : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
420353 """
@@ -486,8 +419,6 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
486419 x_pred = (y_pred [b , 0 ] == c ) if (y_pred .shape [1 ] == 1 ) else y_pred [b , c ].bool ()
487420 x = (y [b , 0 ] == c ) if (y .shape [1 ] == 1 ) else y [b , c ]
488421 c_list .append (self .compute_channel (x_pred , x ))
489- # if self.per_component:
490- # c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))]
491422 data .append (torch .stack (c_list ))
492423
493424 data = torch .stack (data , dim = 0 ).contiguous () # type: ignore
0 commit comments