Skip to content

Commit a2f5101

Browse files
committed
Address coderabbitai review comments for AUCMLoss and tests
Signed-off-by: Shubham Chandravanshi <shubham.chandravanshi378@gmail.com>
1 parent 25c2702 commit a2f5101

2 files changed

Lines changed: 60 additions & 28 deletions

File tree

monai/losses/aucm_loss.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
113113
input = input.flatten()
114114
target = target.flatten()
115115

116+
if input.numel() == 0:
117+
raise ValueError("Input and target must contain at least one element.")
118+
116119
if not torch.all((target == 0) | (target == 1)):
117120
raise ValueError("Target must contain only binary values (0 or 1)")
118121

@@ -175,7 +178,10 @@ def _global_mean(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor
175178
Returns:
176179
Scalar tensor representing the global mean.
177180
"""
178-
return (tensor * mask).mean()
181+
masked = tensor * mask
182+
if masked.numel() == 0:
183+
return torch.zeros((), dtype=tensor.dtype, device=tensor.device)
184+
return masked.mean()
179185

180186
def _class_mean(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
181187
"""

tests/losses/test_aucm_loss.py

Lines changed: 53 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,37 +13,63 @@
1313

1414
import unittest
1515

16+
import numpy as np
1617
import torch
1718
from parameterized import parameterized
1819

1920
from monai.losses import AUCMLoss
2021
from tests.test_utils import test_script_save
2122

22-
FIXED_INPUT = torch.tensor([[1.0], [2.0]])
23-
FIXED_TARGET = torch.tensor([[1.0], [0.0]])
24-
25-
EXPECTED_V1 = 1.25
26-
EXPECTED_V2 = 5.0
23+
TEST_CASES = [
24+
# small deterministic cases (with expected values)
25+
("v1", torch.tensor([[1.0], [2.0]]), torch.tensor([[1.0], [0.0]]), 1.25),
26+
("v2", torch.tensor([[1.0], [2.0]]), torch.tensor([[1.0], [0.0]]), 5.0),
27+
]
2728

2829

2930
class TestAUCMLoss(unittest.TestCase):
30-
"""Test cases for AUCMLoss."""
31+
"""Unit tests for AUCMLoss covering correctness, edge cases, and scriptability."""
3132

3233
@parameterized.expand([("v1",), ("v2",)])
3334
def test_versions(self, version):
3435
"""Test AUCMLoss with different versions."""
3536
loss_fn = AUCMLoss(version=version)
36-
input = torch.randn(32, 1, requires_grad=True)
37+
pred = torch.randn(32, 1, requires_grad=True)
3738
target = torch.randint(0, 2, (32, 1)).float()
38-
loss = loss_fn(input, target)
39+
loss = loss_fn(pred, target)
3940
self.assertIsInstance(loss, torch.Tensor)
4041
self.assertEqual(loss.ndim, 0)
4142

42-
@parameterized.expand([("v1", EXPECTED_V1), ("v2", EXPECTED_V2)])
43-
def test_known_values(self, version, expected):
43+
@parameterized.expand(TEST_CASES)
44+
def test_known_values(self, version, pred, target, expected):
4445
"""Test AUCMLoss against fixed manually computed values."""
45-
loss = AUCMLoss(version=version)(FIXED_INPUT, FIXED_TARGET)
46-
self.assertAlmostEqual(loss.item(), expected, places=5)
46+
loss = AUCMLoss(version=version)(pred, target)
47+
np.testing.assert_allclose(loss.detach().cpu().numpy(), expected, atol=1e-5, rtol=1e-5)
48+
49+
@parameterized.expand([("v1",), ("v2",)])
50+
def test_high_dimensional(self, version):
51+
"""Test AUCMLoss with higher dimensional preds (e.g., segmentation)."""
52+
loss_fn = AUCMLoss(version=version)
53+
54+
pred = torch.randn(2, 1, 8, 8, requires_grad=True)
55+
target = torch.randint(0, 2, (2, 1, 8, 8)).float()
56+
57+
loss = loss_fn(pred, target)
58+
59+
self.assertIsInstance(loss, torch.Tensor)
60+
self.assertEqual(loss.ndim, 0)
61+
62+
def test_imbalanced(self):
63+
"""Test AUCMLoss with highly imbalanced targets."""
64+
loss_fn = AUCMLoss(version="v1")
65+
66+
pred = torch.randn(32, 1)
67+
target = torch.zeros(32, 1)
68+
target[0] = 1.0 # only one positive
69+
70+
loss = loss_fn(pred, target)
71+
72+
self.assertIsInstance(loss, torch.Tensor)
4773

4874
def test_invalid_version(self):
4975
"""Test that invalid version raises ValueError."""
@@ -57,54 +83,54 @@ def test_invalid_imratio(self):
5783
with self.assertRaises(ValueError):
5884
AUCMLoss(imratio=-0.1)
5985

60-
def test_invalid_input_shape(self):
61-
"""Test that invalid input shape raises ValueError."""
86+
def test_invalid_pred_shape(self):
87+
"""Test that invalid pred shape raises ValueError."""
6288
loss_fn = AUCMLoss()
63-
input = torch.randn(32, 2) # Wrong channel
89+
pred = torch.randn(32, 2) # Wrong channel
6490
target = torch.randint(0, 2, (32, 1)).float()
6591
with self.assertRaises(ValueError):
66-
loss_fn(input, target)
92+
loss_fn(pred, target)
6793

6894
def test_invalid_target_shape(self):
6995
"""Test that invalid target shape raises ValueError."""
7096
loss_fn = AUCMLoss()
71-
input = torch.randn(32, 1)
97+
pred = torch.randn(32, 1)
7298
target = torch.randint(0, 2, (32, 2)).float() # Wrong channel
7399
with self.assertRaises(ValueError):
74-
loss_fn(input, target)
100+
loss_fn(pred, target)
75101

76102
def test_insufficient_dimensions(self):
77103
"""Test that tensors with insufficient dimensions raise ValueError."""
78104
loss_fn = AUCMLoss()
79-
input = torch.randn(32) # 1D tensor
105+
pred = torch.randn(32) # 1D tensor
80106
target = torch.randint(0, 2, (32, 1)).float()
81107
with self.assertRaises(ValueError):
82-
loss_fn(input, target)
108+
loss_fn(pred, target)
83109

84110
def test_shape_mismatch(self):
85111
"""Test that mismatched shapes raise ValueError."""
86112
loss_fn = AUCMLoss()
87-
input = torch.randn(32, 1)
113+
pred = torch.randn(32, 1)
88114
target = torch.randint(0, 2, (16, 1)).float()
89115
with self.assertRaises(ValueError):
90-
loss_fn(input, target)
116+
loss_fn(pred, target)
91117

92118
def test_non_binary_target(self):
93119
"""Test that non-binary target values raise ValueError."""
94120
loss_fn = AUCMLoss()
95-
input = torch.randn(32, 1)
121+
pred = torch.randn(32, 1)
96122
target = torch.tensor([[0.5], [1.0], [2.0], [0.0]] * 8) # 32x1, still non-binary
97123
with self.assertRaises(ValueError):
98-
loss_fn(input, target)
124+
loss_fn(pred, target)
99125

100126
def test_backward(self):
101127
"""Test that gradients can be computed."""
102128
loss_fn = AUCMLoss()
103-
input = torch.randn(32, 1, requires_grad=True)
129+
pred = torch.randn(32, 1, requires_grad=True)
104130
target = torch.randint(0, 2, (32, 1)).float()
105-
loss = loss_fn(input, target)
131+
loss = loss_fn(pred, target)
106132
loss.backward()
107-
self.assertIsNotNone(input.grad)
133+
self.assertIsNotNone(pred.grad)
108134

109135
def test_script_save(self):
110136
"""Test that the loss can be saved as TorchScript."""

0 commit comments

Comments
 (0)