@@ -59,22 +59,36 @@ def __init__(
5959 def forward (self , y_pred : torch .Tensor , y_true : torch .Tensor ) -> torch .Tensor :
6060 n_pred_ch = y_pred .shape [1 ]
6161
62+ # Save original for masking
63+ original_y_true = y_true if self .ignore_index is not None else None
64+
6265 if self .to_onehot_y :
6366 if n_pred_ch == 1 :
6467 warnings .warn ("single channel prediction, `to_onehot_y=True` ignored." )
6568 else :
69+ if self .ignore_index is not None :
70+ # Replace ignore_index with valid class before one_hot
71+ y_true = torch .where (y_true == self .ignore_index , torch .tensor (0 , device = y_true .device ), y_true )
6672 y_true = one_hot (y_true , num_classes = n_pred_ch )
6773
6874 if y_true .shape != y_pred .shape :
6975 raise ValueError (f"ground truth has different shape ({ y_true .shape } ) from input ({ y_pred .shape } )" )
7076
71- # Handle ignore_index:
77+ # Build mask after one_hot conversion
7278 mask = torch .ones_like (y_true )
7379 if self .ignore_index is not None :
74- # Identify valid pixels: where at least one channel is 1
75- spatial_mask = (torch .sum (y_true , dim = 1 , keepdim = True ) > 0 ).float ()
80+ if original_y_true is not None and self .to_onehot_y :
81+ # Use original labels to build spatial mask
82+ spatial_mask = (original_y_true != self .ignore_index ).float ()
83+ elif self .ignore_index < y_true .shape [1 ]:
84+ # For already one-hot: use ignored class channel
85+ spatial_mask = 1.0 - y_true [:, self .ignore_index : self .ignore_index + 1 ]
86+ else :
87+ # For sentinel values: any valid channel
88+ spatial_mask = (y_true .sum (dim = 1 , keepdim = True ) > 0 ).float ()
7689 mask = spatial_mask .expand_as (y_true )
7790 y_pred = y_pred * mask
91+ y_true = y_true * mask
7892
7993 y_pred = torch .clamp (y_pred , self .epsilon , 1.0 - self .epsilon )
8094 axis = list (range (2 , len (y_pred .shape )))
@@ -137,15 +151,16 @@ def __init__(
137151 def forward (self , y_pred : torch .Tensor , y_true : torch .Tensor ) -> torch .Tensor :
138152 n_pred_ch = y_pred .shape [1 ]
139153
154+ # Save original for masking
155+ original_y_true = y_true if self .ignore_index is not None else None
156+
140157 if self .to_onehot_y :
141158 if n_pred_ch == 1 :
142159 warnings .warn ("single channel prediction, `to_onehot_y=True` ignored." )
143- elif self .ignore_index is not None :
144- mask = (y_true != self .ignore_index ).float ()
145- y_true_clean = torch .where (y_true == self .ignore_index , 0 , y_true )
146- y_true = one_hot (y_true_clean , num_classes = n_pred_ch )
147- y_true = y_true * mask
148160 else :
161+ if self .ignore_index is not None :
162+ # Replace ignore_index with valid class before one_hot
163+ y_true = torch .where (y_true == self .ignore_index , torch .tensor (0 , device = y_true .device ), y_true )
149164 y_true = one_hot (y_true , num_classes = n_pred_ch )
150165
151166 if y_true .shape != y_pred .shape :
@@ -154,9 +169,16 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
154169 y_pred = torch .clamp (y_pred , self .epsilon , 1.0 - self .epsilon )
155170 cross_entropy = - y_true * torch .log (y_pred )
156171
172+ # Build mask from original labels if available
173+ spatial_mask = None
157174 if self .ignore_index is not None :
158- spatial_mask = (torch .sum (y_true , dim = 1 , keepdim = True ) > 0 ).float ()
159- cross_entropy = cross_entropy * spatial_mask
175+ if original_y_true is not None and self .to_onehot_y :
176+ spatial_mask = (original_y_true != self .ignore_index ).float ()
177+ elif self .ignore_index < y_true .shape [1 ]:
178+ spatial_mask = 1.0 - y_true [:, self .ignore_index : self .ignore_index + 1 ]
179+ else :
180+ spatial_mask = (y_true .sum (dim = 1 , keepdim = True ) > 0 ).float ()
181+ cross_entropy = cross_entropy * spatial_mask .expand_as (cross_entropy )
160182
161183 back_ce = torch .pow (1 - y_pred [:, 0 ], self .gamma ) * cross_entropy [:, 0 ]
162184 back_ce = (1 - self .delta ) * back_ce
@@ -165,10 +187,13 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
165187 fore_ce = self .delta * fore_ce
166188
167189 loss = torch .stack ([back_ce , fore_ce ], dim = 1 ) # [B, 2, H, W]
190+
168191 if self .reduction == LossReduction .MEAN .value :
169- if self .ignore_index is not None :
170- # Normalize by the number of non-ignored pixels
171- return loss .sum () / spatial_mask .sum ().clamp (min = 1e-5 )
192+ if self .ignore_index is not None and spatial_mask is not None :
193+ # Apply mask to loss, then average over valid elements only
194+ # loss has shape [B, 2, H, W], spatial_mask has shape [B, 1, H, W]
195+ masked_loss = loss * spatial_mask .expand_as (loss )
196+ return masked_loss .sum () / (spatial_mask .expand_as (loss ).sum ().clamp (min = 1e-5 ))
172197 return loss .mean ()
173198 if self .reduction == LossReduction .SUM .value :
174199 return loss .sum ()
0 commit comments