1717import torch
1818from parameterized import parameterized
1919
20- from monai .losses .image_dissimilarity import LocalNormalizedCrossCorrelationLoss
20+ from monai .losses .image_dissimilarity import LocalNormalizedCrossCorrelationLoss , make_gaussian_kernel
2121
2222device = "cuda" if torch .cuda .is_available () else "cpu"
2323
113113 },
114114 - 0.95406944 ,
115115 ],
116+ # Regression tests for gh-8780: gaussian kernel_size > 3 was broken due to
117+ # truncated parameter being passed as pixel radius instead of sigma multiplier.
118+ # Identical images must yield loss == -1.0 for any kernel size.
119+ [
120+ {"spatial_dims" : 1 , "kernel_type" : "gaussian" , "kernel_size" : 5 },
121+ {
122+ "pred" : torch .arange (0 , 5 ).reshape (1 , 1 , - 1 ).to (dtype = torch .float , device = device ),
123+ "target" : torch .arange (0 , 5 ).reshape (1 , 1 , - 1 ).to (dtype = torch .float , device = device ),
124+ },
125+ - 1.0 ,
126+ ],
127+ [
128+ {"spatial_dims" : 1 , "kernel_type" : "gaussian" , "kernel_size" : 9 },
129+ {
130+ "pred" : torch .arange (0 , 9 ).reshape (1 , 1 , - 1 ).to (dtype = torch .float , device = device ),
131+ "target" : torch .arange (0 , 9 ).reshape (1 , 1 , - 1 ).to (dtype = torch .float , device = device ),
132+ },
133+ - 1.0 ,
134+ ],
116135]
117136
118137
@@ -138,6 +157,15 @@ def test_ill_shape(self):
138157 torch .ones ((1 , 3 , 4 , 4 , 4 ), dtype = torch .float , device = device ),
139158 )
140159
160+ def test_gaussian_kernel_shape_and_symmetry (self ):
161+ # gh-8780: kernel must have correct length, be symmetric, and peak at center
162+ for kernel_size in [3 , 5 , 7 , 9 , 11 , 15 ]:
163+ k = make_gaussian_kernel (kernel_size )
164+ self .assertEqual (len (k ), kernel_size )
165+ self .assertTrue (torch .allclose (k , k .flip (0 )), f"kernel_size={ kernel_size } not symmetric" )
166+ self .assertEqual (k .argmax ().item (), kernel_size // 2 )
167+ np .testing .assert_allclose (k .max ().item (), 1.0 , rtol = 1e-6 )
168+
141169 def test_ill_opts (self ):
142170 pred = torch .ones ((1 , 3 , 3 , 3 , 3 ), dtype = torch .float )
143171 target = torch .ones ((1 , 3 , 3 , 3 , 3 ), dtype = torch .float )
0 commit comments