1313
1414import unittest
1515
16+ import numpy as np
1617import torch
1718from parameterized import parameterized
1819
1920from monai .losses import AUCMLoss
2021from tests .test_utils import test_script_save
2122
22- FIXED_INPUT = torch . tensor ([[ 1.0 ], [ 2.0 ]])
23- FIXED_TARGET = torch . tensor ([[ 1.0 ], [ 0.0 ]] )
24-
25- EXPECTED_V1 = 1.25
26- EXPECTED_V2 = 5.0
23+ TEST_CASES = [
24+ # small deterministic cases (with expected values )
25+ ( "v1" , torch . tensor ([[ 1.0 ], [ 2.0 ]]), torch . tensor ([[ 1.0 ], [ 0.0 ]]), 1.25 ),
26+ ( "v2" , torch . tensor ([[ 1.0 ], [ 2.0 ]]), torch . tensor ([[ 1.0 ], [ 0.0 ]]), 5.0 ),
27+ ]
2728
2829
2930class TestAUCMLoss (unittest .TestCase ):
30- """Test cases for AUCMLoss."""
31+ """Unit tests for AUCMLoss covering correctness, edge cases, and scriptability ."""
3132
3233 @parameterized .expand ([("v1" ,), ("v2" ,)])
3334 def test_versions (self , version ):
3435 """Test AUCMLoss with different versions."""
3536 loss_fn = AUCMLoss (version = version )
36- input = torch .randn (32 , 1 , requires_grad = True )
37+ pred = torch .randn (32 , 1 , requires_grad = True )
3738 target = torch .randint (0 , 2 , (32 , 1 )).float ()
38- loss = loss_fn (input , target )
39+ loss = loss_fn (pred , target )
3940 self .assertIsInstance (loss , torch .Tensor )
4041 self .assertEqual (loss .ndim , 0 )
4142
42- @parameterized .expand ([( "v1" , EXPECTED_V1 ), ( "v2" , EXPECTED_V2 )] )
43- def test_known_values (self , version , expected ):
43+ @parameterized .expand (TEST_CASES )
44+ def test_known_values (self , version , pred , target , expected ):
4445 """Test AUCMLoss against fixed manually computed values."""
45- loss = AUCMLoss (version = version )(FIXED_INPUT , FIXED_TARGET )
46- self .assertAlmostEqual (loss .item (), expected , places = 5 )
46+ loss = AUCMLoss (version = version )(pred , target )
47+ np .testing .assert_allclose (loss .detach ().cpu ().numpy (), expected , atol = 1e-5 , rtol = 1e-5 )
48+
49+ @parameterized .expand ([("v1" ,), ("v2" ,)])
50+ def test_high_dimensional (self , version ):
51+ """Test AUCMLoss with higher dimensional preds (e.g., segmentation)."""
52+ loss_fn = AUCMLoss (version = version )
53+
54+ pred = torch .randn (2 , 1 , 8 , 8 , requires_grad = True )
55+ target = torch .randint (0 , 2 , (2 , 1 , 8 , 8 )).float ()
56+
57+ loss = loss_fn (pred , target )
58+
59+ self .assertIsInstance (loss , torch .Tensor )
60+ self .assertEqual (loss .ndim , 0 )
61+
62+ def test_imbalanced (self ):
63+ """Test AUCMLoss with highly imbalanced targets."""
64+ loss_fn = AUCMLoss (version = "v1" )
65+
66+ pred = torch .randn (32 , 1 )
67+ target = torch .zeros (32 , 1 )
68+ target [0 ] = 1.0 # only one positive
69+
70+ loss = loss_fn (pred , target )
71+
72+ self .assertIsInstance (loss , torch .Tensor )
4773
4874 def test_invalid_version (self ):
4975 """Test that invalid version raises ValueError."""
@@ -57,54 +83,54 @@ def test_invalid_imratio(self):
5783 with self .assertRaises (ValueError ):
5884 AUCMLoss (imratio = - 0.1 )
5985
60- def test_invalid_input_shape (self ):
61- """Test that invalid input shape raises ValueError."""
86+ def test_invalid_pred_shape (self ):
87+ """Test that invalid pred shape raises ValueError."""
6288 loss_fn = AUCMLoss ()
63- input = torch .randn (32 , 2 ) # Wrong channel
89+ pred = torch .randn (32 , 2 ) # Wrong channel
6490 target = torch .randint (0 , 2 , (32 , 1 )).float ()
6591 with self .assertRaises (ValueError ):
66- loss_fn (input , target )
92+ loss_fn (pred , target )
6793
6894 def test_invalid_target_shape (self ):
6995 """Test that invalid target shape raises ValueError."""
7096 loss_fn = AUCMLoss ()
71- input = torch .randn (32 , 1 )
97+ pred = torch .randn (32 , 1 )
7298 target = torch .randint (0 , 2 , (32 , 2 )).float () # Wrong channel
7399 with self .assertRaises (ValueError ):
74- loss_fn (input , target )
100+ loss_fn (pred , target )
75101
76102 def test_insufficient_dimensions (self ):
77103 """Test that tensors with insufficient dimensions raise ValueError."""
78104 loss_fn = AUCMLoss ()
79- input = torch .randn (32 ) # 1D tensor
105+ pred = torch .randn (32 ) # 1D tensor
80106 target = torch .randint (0 , 2 , (32 , 1 )).float ()
81107 with self .assertRaises (ValueError ):
82- loss_fn (input , target )
108+ loss_fn (pred , target )
83109
84110 def test_shape_mismatch (self ):
85111 """Test that mismatched shapes raise ValueError."""
86112 loss_fn = AUCMLoss ()
87- input = torch .randn (32 , 1 )
113+ pred = torch .randn (32 , 1 )
88114 target = torch .randint (0 , 2 , (16 , 1 )).float ()
89115 with self .assertRaises (ValueError ):
90- loss_fn (input , target )
116+ loss_fn (pred , target )
91117
92118 def test_non_binary_target (self ):
93119 """Test that non-binary target values raise ValueError."""
94120 loss_fn = AUCMLoss ()
95- input = torch .randn (32 , 1 )
121+ pred = torch .randn (32 , 1 )
96122 target = torch .tensor ([[0.5 ], [1.0 ], [2.0 ], [0.0 ]] * 8 ) # 32x1, still non-binary
97123 with self .assertRaises (ValueError ):
98- loss_fn (input , target )
124+ loss_fn (pred , target )
99125
100126 def test_backward (self ):
101127 """Test that gradients can be computed."""
102128 loss_fn = AUCMLoss ()
103- input = torch .randn (32 , 1 , requires_grad = True )
129+ pred = torch .randn (32 , 1 , requires_grad = True )
104130 target = torch .randint (0 , 2 , (32 , 1 )).float ()
105- loss = loss_fn (input , target )
131+ loss = loss_fn (pred , target )
106132 loss .backward ()
107- self .assertIsNotNone (input .grad )
133+ self .assertIsNotNone (pred .grad )
108134
109135 def test_script_save (self ):
110136 """Test that the loss can be saved as TorchScript."""
0 commit comments