Skip to content

Commit c80eeeb

Browse files
committed
fix: address CodeRabbit critical and major issues
Signed-off-by: Rusheel Sharma <rusheelhere@gmail.com>
1 parent 170f34a commit c80eeeb

2 files changed

Lines changed: 23 additions & 17 deletions

File tree

monai/metrics/generalized_dice.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,11 @@ def compute_generalized_dice(
156156

157157
# Apply ignore_index masking
158158
if ignore_index is not None:
159-
if ignore_index < y.shape[1]:
159+
if 0 <= ignore_index < y.shape[1]:
160160
# For one-hot: use the ignored class channel
161161
mask = 1.0 - y[:, ignore_index : ignore_index + 1]
162162
else:
163-
# For sentinel values, check if any channel is valid
163+
# For sentinel values (like 255 or -100), check if any channel is valid
164164
mask = (y.sum(dim=1, keepdim=True) > 0).float()
165165
y_pred = y_pred * mask
166166
y = y * mask
@@ -171,7 +171,7 @@ def compute_generalized_dice(
171171
if not include_background:
172172
channels_to_use.pop(0)
173173

174-
if ignore_index is not None:
174+
if ignore_index is not None and 0 <= ignore_index < n_channels:
175175
# If background was 0 and we ignore class 2, we need the correct absolute index
176176
if ignore_index in channels_to_use:
177177
channels_to_use.remove(ignore_index)
@@ -181,35 +181,33 @@ def compute_generalized_dice(
181181

182182
# Reducing only spatial dimensions (not batch nor channels), compute the intersection and non-weighted denominator
183183
reduce_axis = list(range(2, y_pred.dim()))
184-
y_o_full = torch.sum(y, dim=reduce_axis) # shape: (B, C)
185184
intersection = torch.sum(y[:, channels_to_use, ...] * y_pred[:, channels_to_use, ...], dim=reduce_axis)
186185
y_o = torch.sum(y[:, channels_to_use, ...], dim=reduce_axis)
187186
y_pred_o = torch.sum(y_pred[:, channels_to_use, ...], dim=reduce_axis)
188187

189188
denominator = y_o + y_pred_o
190189

191190
# Set the class weights
191+
# Set the class weights (computed from scored channels only)
192192
weight_type = look_up_option(weight_type, Weight)
193-
y_o_float = y_o_full.float()
193+
y_o_float = y_o.float()
194194

195195
if weight_type == Weight.SIMPLE:
196-
w_full = torch.reciprocal(y_o_float)
196+
w = torch.reciprocal(y_o_float)
197197
elif weight_type == Weight.SQUARE:
198-
w_full = torch.reciprocal(y_o_float * y_o_float)
198+
w = torch.reciprocal(y_o_float * y_o_float)
199199
else:
200-
w_full = torch.ones_like(y_o_float)
200+
w = torch.ones_like(y_o_float)
201201

202202
# Replace infinite values for non-appearing classes by the maximum weight
203-
for b_idx in range(w_full.shape[0]):
204-
batch_w = w_full[b_idx]
203+
for b_idx in range(w.shape[0]):
204+
batch_w = w[b_idx]
205205
infs = torch.isinf(batch_w)
206206
if infs.any():
207207
batch_w[infs] = 0
208208
max_w = torch.max(batch_w)
209209
batch_w[infs] = max_w if max_w > 0 else 1.0
210210

211-
w = w_full[:, channels_to_use]
212-
213211
if sum_over_classes:
214212
intersection = (intersection * w).sum(dim=1, keepdim=True)
215213
denominator = (denominator * w).sum(dim=1, keepdim=True)

monai/metrics/utils.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,17 @@ def ignore_index_mask(
7575
"""
7676
Masks out the specified ignore_index from both predictions and ground truth.
7777
This is a helper for #8667 to allow 'Ignore Class' functionality in metrics.
78+
79+
Note: This function assumes label-encoded tensors. For one-hot encoded inputs,
80+
use the two-case masking approach as shown in compute_generalized_dice.
81+
82+
Args:
83+
y_pred: predicted tensor with shape (B, C, H, W, [D])
84+
y: ground truth tensor with shape (B, C, H, W, [D])
85+
ignore_index: class index to ignore. If None, returns inputs unchanged.
86+
87+
Returns:
88+
Tuple of (masked y_pred, masked y) where ignored regions are zeroed out.
7889
"""
7990
if ignore_index is None:
8091
return y_pred, y
@@ -284,7 +295,6 @@ def get_surface_distance(
284295
(1) If a single number, isotropic spacing with that value is used.
285296
(2) If a sequence of numbers, the length of the sequence must be equal to the image dimensions.
286297
(3) If ``None``, spacing of unity is used. Defaults to ``None``.
287-
mask: optional boolean mask. Pixels where mask is False will be ignored in the distance computation.
288298
289299
Note:
290300
If seg_pred or seg_gt is all 0, may result in nan/inf distance.
@@ -308,12 +318,10 @@ def get_surface_distance(
308318

309319
dis = convert_to_dst_type(dis, seg_pred, dtype=lib.float32)[0]
310320
if isinstance(seg_pred, torch.Tensor):
311-
out = dis[seg_pred.bool()] # type: ignore[union-attr]
312-
return out if out is not None else np.empty((0,), dtype=dis.dtype) # type: ignore[union-attr,no-any-return]
321+
return dis[seg_pred.bool()] # type: ignore[union-attr]
313322
else:
314323
# NumPy array
315-
out = dis[seg_pred.astype(bool)] # type: ignore[union-attr]
316-
return out if out is not None else np.empty((0,), dtype=dis.dtype) # type: ignore[union-attr]
324+
return dis[seg_pred.astype(bool)] # type: ignore[union-attr,no-any-return]
317325

318326

319327
def get_edge_surface_distance(

0 commit comments

Comments
 (0)