Skip to content

Commit 092990d

Browse files
committed
Adjust execution order of activation and masking in MaskedDiceLoss
Signed-off-by: ytl0623 <david89062388@gmail.com>
1 parent 57fdd59 commit 092990d

2 files changed

Lines changed: 31 additions & 4 deletions

File tree

monai/losses/dice.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,34 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
253253
target: the shape should be BNH[WD].
254254
mask: the shape should B1H[WD] or 11H[WD].
255255
"""
256-
return self.spatial_weighted(input=input, target=target, mask=mask) # type: ignore[no-any-return]
256+
257+
if self.sigmoid:
258+
input = torch.sigmoid(input)
259+
260+
n_pred_ch = input.shape[1]
261+
if self.softmax:
262+
if n_pred_ch == 1:
263+
warnings.warn("single channel prediction, `softmax=True` ignored.")
264+
else:
265+
input = torch.softmax(input, 1)
266+
267+
if self.other_act is not None:
268+
input = self.other_act(input)
269+
270+
was_sigmoid = self.sigmoid
271+
was_softmax = self.softmax
272+
was_other_act = self.other_act
273+
274+
self.sigmoid = False
275+
self.softmax = False
276+
self.other_act = None
277+
278+
try:
279+
return self.spatial_weighted(input=input, target=target, mask=mask) # type: ignore[no-any-return]
280+
finally:
281+
self.sigmoid = was_sigmoid
282+
self.softmax = was_softmax
283+
self.other_act = was_other_act
257284

258285

259286
class GeneralizedDiceLoss(_Loss):

tests/losses/test_masked_dice_loss.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
"target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]),
2828
"mask": torch.tensor([[[[0.0, 0.0], [1.0, 1.0]]]]),
2929
},
30-
0.500,
30+
0.333333,
3131
],
3232
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
3333
{"include_background": True, "sigmoid": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4},
@@ -36,7 +36,7 @@
3636
"target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]),
3737
"mask": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 1.0], [0.0, 0.0]]]]),
3838
},
39-
0.422969,
39+
0.301128,
4040
],
4141
[ # shape: (2, 2, 3), (2, 1, 3)
4242
{"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0},
@@ -54,7 +54,7 @@
5454
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
5555
"mask": torch.tensor([[[1.0, 1.0, 0.0]]]),
5656
},
57-
0.47033,
57+
0.579184,
5858
],
5959
[ # shape: (2, 2, 3), (2, 1, 3)
6060
{

0 commit comments

Comments
 (0)