Skip to content

Commit 3623f96

Browse files
committed
Validate imratio and input shape, added test cases for it and fix non-binary target test
Signed-off-by: Shubham Chandravanshi <shubham.chandravanshi378@gmail.com>
1 parent 38d8169 commit 3623f96

2 files changed

Lines changed: 22 additions & 1 deletion

File tree

monai/losses/aucm_loss.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def __init__(
6565
6666
Raises:
6767
ValueError: When ``version`` is not one of ["v1", "v2"].
68+
ValueError: When ``imratio`` is not in [0, 1].
6869
6970
Example:
7071
>>> import torch
@@ -77,6 +78,8 @@ def __init__(
7778
super().__init__(reduction=LossReduction(reduction).value)
7879
if version not in ["v1", "v2"]:
7980
raise ValueError(f"version should be 'v1' or 'v2', got {version}")
81+
if imratio is not None and not (0.0 <= imratio <= 1.0):
82+
raise ValueError(f"imratio must be in [0, 1], got {imratio}")
8083
self.margin = margin
8184
self.imratio = imratio
8285
self.version = version
@@ -95,8 +98,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
9598
9699
Raises:
97100
ValueError: When input or target have incorrect shapes.
101+
ValueError: When input or target have fewer than 2 dimensions.
98102
ValueError: When target contains non-binary values.
99103
"""
104+
if input.ndim < 2 or target.ndim < 2:
105+
raise ValueError("Input and target must have at least 2 dimensions (B, C, ...)")
100106
if input.shape[1] != 1:
101107
raise ValueError(f"Input should have 1 channel for binary classification, got {input.shape[1]}")
102108
if target.shape[1] != 1:

tests/losses/test_aucm_loss.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@ def test_invalid_version(self):
4545
with self.assertRaises(ValueError):
4646
AUCMLoss(version="invalid")
4747

48+
def test_invalid_imratio(self):
49+
"""Test that invalid imratio raises ValueError."""
50+
with self.assertRaises(ValueError):
51+
AUCMLoss(imratio=1.5)
52+
with self.assertRaises(ValueError):
53+
AUCMLoss(imratio=-0.1)
54+
4855
def test_invalid_input_shape(self):
4956
"""Test that invalid input shape raises ValueError."""
5057
loss_fn = AUCMLoss()
@@ -61,6 +68,14 @@ def test_invalid_target_shape(self):
6168
with self.assertRaises(ValueError):
6269
loss_fn(input, target)
6370

71+
def test_insufficient_dimensions(self):
72+
"""Test that tensors with insufficient dimensions raise ValueError."""
73+
loss_fn = AUCMLoss()
74+
input = torch.randn(32) # 1D tensor
75+
target = torch.randint(0, 2, (32, 1)).float()
76+
with self.assertRaises(ValueError):
77+
loss_fn(input, target)
78+
6479
def test_shape_mismatch(self):
6580
"""Test that mismatched shapes raise ValueError."""
6681
loss_fn = AUCMLoss()
@@ -73,7 +88,7 @@ def test_non_binary_target(self):
7388
"""Test that non-binary target values raise ValueError."""
7489
loss_fn = AUCMLoss()
7590
input = torch.randn(32, 1)
76-
target = torch.tensor([[0.5], [1.0], [2.0]] * 10 + [[0.0]]) # Contains non-binary values
91+
target = torch.tensor([[0.5], [1.0], [2.0], [0.0]] * 8) # 32x1, still non-binary
7792
with self.assertRaises(ValueError):
7893
loss_fn(input, target)
7994

0 commit comments

Comments
 (0)