Skip to content

Commit 25c2702

Browse files
committed
Refactor AUCMLoss implementation and improve tests
Signed-off-by: Shubham Chandravanshi <shubham.chandravanshi378@gmail.com>
1 parent 3623f96 commit 25c2702

2 files changed

Lines changed: 85 additions & 29 deletions

File tree

monai/losses/aucm_loss.py

Lines changed: 69 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -119,32 +119,83 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
119119
pos_mask = (target == 1).float()
120120
neg_mask = (target == 0).float()
121121

122+
mean_pos_sq = (input - self.a) ** 2
123+
mean_neg_sq = (input - self.b) ** 2
124+
125+
# Note:
126+
# v1 uses global expectations (normalized by total number of samples),
127+
# following the original LibAUC implementation.
128+
# v2 uses class-conditional expectations (normalized by number of samples
129+
# in each class), implemented via non-zero averaging.
130+
# These behaviors differ and should not be unified.
122131
if self.version == "v1":
123132
p = float(self.imratio) if self.imratio is not None else float(pos_mask.mean().item())
133+
p1 = 1.0 - p
134+
135+
mean_pos = self._global_mean(mean_pos_sq, pos_mask)
136+
mean_neg = self._global_mean(mean_neg_sq, neg_mask)
137+
138+
interaction = self._global_mean(p * input * neg_mask - p1 * input * pos_mask, pos_mask + neg_mask)
139+
124140
loss = (
125-
(1 - p) * self._safe_mean((input - self.a) ** 2, pos_mask)
126-
+ p * self._safe_mean((input - self.b) ** 2, neg_mask)
127-
+ 2
128-
* self.alpha
129-
* (
130-
p * (1 - p) * self.margin
131-
+ self._safe_mean(p * input * neg_mask - (1 - p) * input * pos_mask, pos_mask + neg_mask)
132-
)
133-
- p * (1 - p) * self.alpha**2
141+
p1 * mean_pos
142+
+ p * mean_neg
143+
+ 2 * self.alpha * (p * p1 * self.margin + interaction)
144+
- p * p1 * self.alpha**2
134145
)
135-
else:
146+
147+
else: # v2
148+
mean_pos = self._class_mean(mean_pos_sq, pos_mask)
149+
mean_neg = self._class_mean(mean_neg_sq, neg_mask)
150+
151+
mean_input_pos = self._class_mean(input, pos_mask)
152+
mean_input_neg = self._class_mean(input, neg_mask)
153+
136154
loss = (
137-
self._safe_mean((input - self.a) ** 2, pos_mask)
138-
+ self._safe_mean((input - self.b) ** 2, neg_mask)
139-
+ 2 * self.alpha * (self.margin + self._safe_mean(input, neg_mask) - self._safe_mean(input, pos_mask))
140-
- self.alpha**2
155+
mean_pos + mean_neg + 2 * self.alpha * (self.margin + mean_input_neg - mean_input_pos) - self.alpha**2
141156
)
142157

143158
return loss
144159

145-
def _safe_mean(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
146-
"""Compute mean safely over masked elements."""
160+
def _global_mean(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
161+
"""
162+
Compute the global mean of a masked tensor.
163+
164+
This computes the mean over all elements, where values outside the mask
165+
are zeroed out. The result is normalized by the total number of elements,
166+
not by the number of masked elements.
167+
168+
This corresponds to a global expectation:
169+
E[mask * tensor]
170+
171+
Args:
172+
tensor: Input tensor.
173+
mask: Binary mask tensor of the same shape as ``tensor``.
174+
175+
Returns:
176+
Scalar tensor representing the global mean.
177+
"""
178+
return (tensor * mask).mean()
179+
180+
def _class_mean(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
181+
"""
182+
Compute the class-conditional mean of a masked tensor.
183+
184+
This computes the mean over only the masked (non-zero) elements, i.e.,
185+
the result is normalized by the number of masked elements.
186+
187+
This corresponds to a class-conditional expectation:
188+
E[tensor | mask]
189+
190+
Args:
191+
tensor: Input tensor.
192+
mask: Binary mask tensor of the same shape as ``tensor``.
193+
194+
Returns:
195+
Scalar tensor representing the class-conditional mean.
196+
Returns 0 if no elements are selected by the mask.
197+
"""
147198
denom = mask.sum()
148-
if denom == 0:
149-
return torch.tensor(0.0, device=tensor.device, dtype=tensor.dtype)
199+
if denom.item() == 0:
200+
return torch.zeros((), dtype=tensor.dtype, device=tensor.device)
150201
return (tensor * mask).sum() / denom

tests/losses/test_aucm_loss.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,31 +14,36 @@
1414
import unittest
1515

1616
import torch
17+
from parameterized import parameterized
1718

1819
from monai.losses import AUCMLoss
1920
from tests.test_utils import test_script_save
2021

22+
FIXED_INPUT = torch.tensor([[1.0], [2.0]])
23+
FIXED_TARGET = torch.tensor([[1.0], [0.0]])
24+
25+
EXPECTED_V1 = 1.25
26+
EXPECTED_V2 = 5.0
27+
2128

2229
class TestAUCMLoss(unittest.TestCase):
2330
"""Test cases for AUCMLoss."""
2431

25-
def test_v1(self):
26-
"""Test AUCMLoss with version 'v1'."""
27-
loss_fn = AUCMLoss(version="v1")
32+
@parameterized.expand([("v1",), ("v2",)])
33+
def test_versions(self, version):
34+
"""Test AUCMLoss with different versions."""
35+
loss_fn = AUCMLoss(version=version)
2836
input = torch.randn(32, 1, requires_grad=True)
2937
target = torch.randint(0, 2, (32, 1)).float()
3038
loss = loss_fn(input, target)
3139
self.assertIsInstance(loss, torch.Tensor)
3240
self.assertEqual(loss.ndim, 0)
3341

34-
def test_v2(self):
35-
"""Test AUCMLoss with version 'v2'."""
36-
loss_fn = AUCMLoss(version="v2")
37-
input = torch.randn(32, 1, requires_grad=True)
38-
target = torch.randint(0, 2, (32, 1)).float()
39-
loss = loss_fn(input, target)
40-
self.assertIsInstance(loss, torch.Tensor)
41-
self.assertEqual(loss.ndim, 0)
42+
@parameterized.expand([("v1", EXPECTED_V1), ("v2", EXPECTED_V2)])
43+
def test_known_values(self, version, expected):
44+
"""Test AUCMLoss against fixed manually computed values."""
45+
loss = AUCMLoss(version=version)(FIXED_INPUT, FIXED_TARGET)
46+
self.assertAlmostEqual(loss.item(), expected, places=5)
4247

4348
def test_invalid_version(self):
4449
"""Test that invalid version raises ValueError."""

0 commit comments

Comments (0)