Skip to content

Commit 91bb2e5

Browse files
committed
fix: resolve mypy union-attr error in unified_focal_loss
Signed-off-by: Rusheel Sharma <rusheelhere@gmail.com>
1 parent 780b567 commit 91bb2e5

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

monai/losses/unified_focal_loss.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,15 +170,17 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
170170
cross_entropy = -y_true * torch.log(y_pred)
171171

172172
# Build mask from original labels if available
173-
spatial_mask = None
173+
spatial_mask: torch.Tensor | None = None
174174
if self.ignore_index is not None:
175175
if original_y_true is not None and self.to_onehot_y:
176176
spatial_mask = (original_y_true != self.ignore_index).float()
177177
elif self.ignore_index < y_true.shape[1]:
178178
spatial_mask = 1.0 - y_true[:, self.ignore_index : self.ignore_index + 1]
179179
else:
180180
spatial_mask = (y_true.sum(dim=1, keepdim=True) > 0).float()
181-
cross_entropy = cross_entropy * spatial_mask.expand_as(cross_entropy)
181+
182+
if spatial_mask is not None:
183+
cross_entropy = cross_entropy * spatial_mask.expand_as(cross_entropy)
182184

183185
back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
184186
back_ce = (1 - self.delta) * back_ce

0 commit comments

Comments
 (0)