Skip to content

Commit 780b567

Browse files
committed
fix: complete ignore_index implementation with proper one-hot masking
Signed-off-by: Rusheel Sharma <rusheelhere@gmail.com>
1 parent 187da14 commit 780b567

7 files changed

Lines changed: 93 additions & 31 deletions

File tree

monai/losses/unified_focal_loss.py

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,22 +59,36 @@ def __init__(
5959
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
6060
n_pred_ch = y_pred.shape[1]
6161

62+
# Save original for masking
63+
original_y_true = y_true if self.ignore_index is not None else None
64+
6265
if self.to_onehot_y:
6366
if n_pred_ch == 1:
6467
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
6568
else:
69+
if self.ignore_index is not None:
70+
# Replace ignore_index with valid class before one_hot
71+
y_true = torch.where(y_true == self.ignore_index, torch.tensor(0, device=y_true.device), y_true)
6672
y_true = one_hot(y_true, num_classes=n_pred_ch)
6773

6874
if y_true.shape != y_pred.shape:
6975
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
7076

71-
# Handle ignore_index:
77+
# Build mask after one_hot conversion
7278
mask = torch.ones_like(y_true)
7379
if self.ignore_index is not None:
74-
# Identify valid pixels: where at least one channel is 1
75-
spatial_mask = (torch.sum(y_true, dim=1, keepdim=True) > 0).float()
80+
if original_y_true is not None and self.to_onehot_y:
81+
# Use original labels to build spatial mask
82+
spatial_mask = (original_y_true != self.ignore_index).float()
83+
elif self.ignore_index < y_true.shape[1]:
84+
# For already one-hot targets: mask is 1 minus the ignore_index channel
85+
spatial_mask = 1.0 - y_true[:, self.ignore_index : self.ignore_index + 1]
86+
else:
87+
# For sentinel values (ignore_index outside the channel range): a pixel is valid when its channels sum to > 0
88+
spatial_mask = (y_true.sum(dim=1, keepdim=True) > 0).float()
7689
mask = spatial_mask.expand_as(y_true)
7790
y_pred = y_pred * mask
91+
y_true = y_true * mask
7892

7993
y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
8094
axis = list(range(2, len(y_pred.shape)))
@@ -137,15 +151,16 @@ def __init__(
137151
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
138152
n_pred_ch = y_pred.shape[1]
139153

154+
# Save original for masking
155+
original_y_true = y_true if self.ignore_index is not None else None
156+
140157
if self.to_onehot_y:
141158
if n_pred_ch == 1:
142159
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
143-
elif self.ignore_index is not None:
144-
mask = (y_true != self.ignore_index).float()
145-
y_true_clean = torch.where(y_true == self.ignore_index, 0, y_true)
146-
y_true = one_hot(y_true_clean, num_classes=n_pred_ch)
147-
y_true = y_true * mask
148160
else:
161+
if self.ignore_index is not None:
162+
# Replace ignore_index with valid class before one_hot
163+
y_true = torch.where(y_true == self.ignore_index, torch.tensor(0, device=y_true.device), y_true)
149164
y_true = one_hot(y_true, num_classes=n_pred_ch)
150165

151166
if y_true.shape != y_pred.shape:
@@ -154,9 +169,16 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
154169
y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
155170
cross_entropy = -y_true * torch.log(y_pred)
156171

172+
# Build mask from original labels if available
173+
spatial_mask = None
157174
if self.ignore_index is not None:
158-
spatial_mask = (torch.sum(y_true, dim=1, keepdim=True) > 0).float()
159-
cross_entropy = cross_entropy * spatial_mask
175+
if original_y_true is not None and self.to_onehot_y:
176+
spatial_mask = (original_y_true != self.ignore_index).float()
177+
elif self.ignore_index < y_true.shape[1]:
178+
spatial_mask = 1.0 - y_true[:, self.ignore_index : self.ignore_index + 1]
179+
else:
180+
spatial_mask = (y_true.sum(dim=1, keepdim=True) > 0).float()
181+
cross_entropy = cross_entropy * spatial_mask.expand_as(cross_entropy)
160182

161183
back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
162184
back_ce = (1 - self.delta) * back_ce
@@ -165,10 +187,13 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
165187
fore_ce = self.delta * fore_ce
166188

167189
loss = torch.stack([back_ce, fore_ce], dim=1) # [B, 2, H, W]
190+
168191
if self.reduction == LossReduction.MEAN.value:
169-
if self.ignore_index is not None:
170-
# Normalize by the number of non-ignored pixels
171-
return loss.sum() / spatial_mask.sum().clamp(min=1e-5)
192+
if self.ignore_index is not None and spatial_mask is not None:
193+
# Apply mask to loss, then average over valid elements only
194+
# loss has shape [B, 2, H, W], spatial_mask has shape [B, 1, H, W]
195+
masked_loss = loss * spatial_mask.expand_as(loss)
196+
return masked_loss.sum() / (spatial_mask.expand_as(loss).sum().clamp(min=1e-5))
172197
return loss.mean()
173198
if self.reduction == LossReduction.SUM.value:
174199
return loss.sum()

monai/metrics/generalized_dice.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,12 @@ def compute_generalized_dice(
156156

157157
# Apply ignore_index masking
158158
if ignore_index is not None:
159-
mask = (y != ignore_index).all(dim=1, keepdim=True).float()
159+
if ignore_index < y.shape[1]:
160+
# For one-hot: use the ignored class channel
161+
mask = 1.0 - y[:, ignore_index : ignore_index + 1]
162+
else:
163+
# For sentinel values, check if any channel is valid
164+
mask = (y.sum(dim=1, keepdim=True) > 0).float()
160165
y_pred = y_pred * mask
161166
y = y * mask
162167

monai/metrics/meandice.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,17 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
338338
# Create global mask for ignored voxels if ignore_index is set
339339
mask = None
340340
if self.ignore_index is not None:
341-
mask = y != self.ignore_index
341+
if y.shape[1] == 1:
342+
# Single channel - values are class indices
343+
mask = y != self.ignore_index
344+
else:
345+
# Multi-channel (one-hot or class probabilities)
346+
if self.ignore_index < n_pred_ch:
347+
# Class-based ignore: keep only voxels whose ignore_index channel is 0
348+
mask = y[:, self.ignore_index : self.ignore_index + 1] == 0
349+
else:
350+
# Sentinel-based ignore: ignore where all channels are 0
351+
mask = y.sum(dim=1, keepdim=True) > 0
342352

343353
first_ch = 0 if self.include_background else 1
344354
data = []

monai/metrics/meaniou.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,17 @@ def compute_iou(
144144
raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.")
145145

146146
if ignore_index is not None:
147-
mask = (y != ignore_index).float()
148-
if mask.shape != y_pred.shape:
147+
if ignore_index < y.shape[1]:
148+
# For one-hot: mask based on the ignored class channel
149+
mask = 1.0 - y[:, ignore_index : ignore_index + 1]
150+
if mask.shape != y_pred.shape:
151+
mask = mask.expand_as(y_pred)
152+
else:
153+
# For sentinel values, check if any channel is valid
154+
mask = (y.sum(dim=1, keepdim=True) > 0).float()
149155
mask = mask.expand_as(y_pred)
150156
y_pred = y_pred * mask
151-
y = torch.where(y == ignore_index, torch.tensor(0, device=y.device), y)
157+
y = y * mask
152158

153159
# reducing only spatial dimensions (not batch nor channels)
154160
n_len = len(y_pred.shape)

monai/metrics/surface_dice.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,12 @@ def compute_surface_dice(
221221
:math:`b` and class :math:`c`.
222222
"""
223223
if ignore_index is not None:
224-
mask = (y != ignore_index).all(dim=1, keepdim=True).float()
225-
224+
if ignore_index < y.shape[1]:
225+
# For one-hot: mask based on the ignored class channel
226+
mask = 1.0 - y[:, ignore_index : ignore_index + 1]
227+
else:
228+
# For sentinel values, check if any channel is valid
229+
mask = (y.sum(dim=1, keepdim=True) > 0).float()
226230
y_pred = y_pred * mask
227231
y = y * mask
228232

@@ -291,7 +295,7 @@ def compute_surface_dice(
291295
boundary_complete = areas_gt.sum() + areas_pred.sum()
292296
gt_true = areas_gt[distances_gt_pred <= class_thresholds[c]].sum() if len(areas_gt) > 0 else 0.0
293297
pred_true = areas_pred[distances_pred_gt <= class_thresholds[c]].sum() if len(areas_pred) > 0 else 0.0
294-
boundary_correct = gt_true + pred_true
298+
boundary_correct = gt_true + pred_true # type: ignore[assignment,operator]
295299
if boundary_complete == 0:
296300
# the class is neither present in the prediction, nor in the reference segmentation
297301
nsd[b, c] = torch.tensor(np.nan)

monai/metrics/utils.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -309,8 +309,13 @@ def get_surface_distance(
309309
raise ValueError(f"distance_metric {distance_metric} is not implemented.")
310310

311311
dis = convert_to_dst_type(dis, seg_pred, dtype=lib.float32)[0]
312-
out = dis[seg_pred.bool()]
313-
return out if out is not None else dis.new_empty((0,))
312+
if isinstance(seg_pred, torch.Tensor):
313+
out = dis[seg_pred.bool()] # type: ignore[union-attr]
314+
return out if out is not None else np.empty((0,), dtype=dis.dtype) # type: ignore[union-attr,no-any-return]
315+
else:
316+
# NumPy array
317+
out = dis[seg_pred.astype(bool)] # type: ignore[union-attr]
318+
return out if out is not None else np.empty((0,), dtype=dis.dtype) # type: ignore[union-attr]
314319

315320

316321
def get_edge_surface_distance(
@@ -363,16 +368,19 @@ def get_edge_surface_distance(
363368
edges_pred = edges_pred & mask
364369
edges_gt = edges_gt & mask
365370

366-
distances: tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor]
371+
distances_raw: tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor]
367372
if symmetric:
368-
distances = (
373+
distances_raw = (
369374
get_surface_distance(edges_pred, edges_gt, distance_metric, spacing),
370375
get_surface_distance(edges_gt, edges_pred, distance_metric, spacing),
371376
) # type: ignore
372377
else:
373-
distances = (get_surface_distance(edges_pred, edges_gt, distance_metric, spacing),) # type: ignore
378+
distances_raw = (get_surface_distance(edges_pred, edges_gt, distance_metric, spacing),) # type: ignore
374379

375-
distances = tuple(d if d is not None else edges_pred.new_empty((0,)) for d in distances)
380+
distances_list = [d if d is not None else edges_pred.new_empty((0,)) for d in distances_raw]
381+
distances: tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor] = (
382+
tuple(distances_list) if len(distances_list) == 2 else (distances_list[0],) # type: ignore[assignment]
383+
)
376384

377385
areas = edge_results[2:] if use_subvoxels else ()
378386

@@ -389,7 +397,7 @@ def get_edge_surface_distance(
389397
if out is None:
390398
out = torch.empty((0,), device=y_pred.device)
391399

392-
return out
400+
return out # type: ignore[return-value,no-any-return]
393401

394402

395403
def is_binary_tensor(input: torch.Tensor, name: str) -> None:

tests/metrics/test_ignore_index_metrics.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,14 @@ def test_metric_ignore_consistency(self, metric_class, kwargs):
6060
y_pred2 = y_pred1.clone()
6161
y_pred2[:, 1, 2:4, :] = 1.0 # Bottom half prediction (different!)
6262

63-
# Target: Top half is valid (0/1), Bottom half is 255
63+
# Target: Top half is valid (0/1), Bottom half should be ignored
64+
# For ignore_index=255 (sentinel), we need to mark ignored pixels differently
65+
# Option 1: Use ignore_index as a class index (e.g., ignore_index=1)
66+
# Option 2: Keep one-hot but set ignored region to all zeros
6467
y = torch.zeros((1, 2, 4, 4))
65-
y[:, 1, 0:2, 0:2] = 1.0
66-
y[:, :, 2:4, :] = 255
68+
y[:, 1, 0:2, 0:2] = 1.0 # Top-left is class 1
69+
y[:, 0, 0:2, 2:4] = 1.0 # Top-right is class 0
70+
# Bottom half: leave as all zeros to indicate "no valid class"
6771

6872
# Run metric for both predictions
6973
metric.reset()

0 commit comments

Comments
 (0)