
Commit d0650f3

Zeesejo and ericspod authored
fix: use register_buffer for kernel and kernel_vol in LocalNormalizedCrossCorrelationLoss (#8818)
### Description

This PR fixes an incorrect kernel initialization pattern in `LocalNormalizedCrossCorrelationLoss` in `monai/losses/image_dissimilarity.py`.

**Bug found:** The original code used plain attribute assignment with a typo:

```python
self.kernel = _kernel(self.kernel_size)
self.kernel.require_grads = False  # typo: 'require_grads' is NOT a valid tensor attribute
self.kernel_vol = self.get_kernel_vol()
```

`require_grads` (with an 's') is not a valid PyTorch tensor attribute; the correct attribute is `requires_grad`. The assignment therefore had no effect, so the kernel was **silently tracking gradients** throughout every forward pass, wasting memory and computation.

Furthermore, assigning the kernel as a plain attribute (`self.kernel = ...`) rather than registering it as a buffer means:

- The kernel tensor does **not** automatically move to the correct device when `.to(device)` / `.cuda()` / `.half()` is called on the loss module
- The kernel is **not** saved/restored correctly via `state_dict()` / `load_state_dict()`

**Fix:** Replace both assignments with `register_buffer`, the correct PyTorch `nn.Module` pattern for constant tensors:

```python
self.register_buffer("kernel", _kernel(self.kernel_size))
self.register_buffer("kernel_vol", self.get_kernel_vol())
```

This removes the need for any manual `requires_grad` management and ensures proper device placement.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: Zeesejo <92383127+Zeesejo@users.noreply.github.com>

---------

Signed-off-by: Zeeshan Modi <92383127+Zeesejo@users.noreply.github.com>
Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
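A minimal standalone repro of the typo's effect (illustrative only, not code from MONAI): assigning a misspelled attribute on a tensor does not raise, it simply attaches a new Python attribute, so autograd behavior is unchanged.

```python
import torch

# A tensor that tracks gradients, standing in for the loss kernel.
kernel = torch.ones(3, requires_grad=True)

# The typo from the original code: this silently creates a brand-new
# attribute 'require_grads' on the tensor and does NOT touch autograd.
kernel.require_grads = False

print(kernel.requires_grad)  # True: gradients are still being tracked
```

This is why the bug went unnoticed: Python allows arbitrary attribute assignment on `torch.Tensor` objects, so no error or warning was ever produced.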
1 parent 49b1012 commit d0650f3

1 file changed

Lines changed: 8 additions & 4 deletions


monai/losses/image_dissimilarity.py

```diff
@@ -111,14 +111,16 @@ def __init__(
             raise ValueError(f"kernel_size must be odd, got {self.kernel_size}")

         _kernel = look_up_option(kernel_type, kernel_dict)
-        self.kernel = _kernel(self.kernel_size)
-        self.kernel.require_grads = False
-        self.kernel_vol = self.get_kernel_vol()
+        self.kernel: torch.Tensor
+        self.kernel_vol: torch.Tensor
+        self.register_buffer("kernel", _kernel(self.kernel_size), persistent=False)
+        self.register_buffer("kernel_vol", self.get_kernel_vol(), persistent=False)

         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)

-    def get_kernel_vol(self):
+    def get_kernel_vol(self) -> torch.Tensor:
+        assert self.kernel is not None
         vol = self.kernel
         for _ in range(self.ndim - 1):
             vol = torch.matmul(vol.unsqueeze(-1), self.kernel.unsqueeze(0))
@@ -138,6 +140,8 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             raise ValueError(f"ground truth has differing shape ({target.shape}) from pred ({pred.shape})")

         t2, p2, tp = target * target, pred * pred, target * pred
+        assert self.kernel is not None
+        assert self.kernel_vol is not None
         kernel, kernel_vol = self.kernel.to(pred), self.kernel_vol.to(pred)
         kernels = [kernel] * self.ndim
         # sum over kernel
```
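The merged diff registers both buffers with `persistent=False`, so the constants follow the module across `.to()` / `.half()` / `.cuda()` calls while staying out of `state_dict()`. A toy module (illustrative, not the MONAI class) sketches the pattern:

```python
import torch
from torch import nn


class ToyLoss(nn.Module):
    """Toy stand-in for a loss module that carries a constant kernel."""

    def __init__(self) -> None:
        super().__init__()
        # register_buffer ties the tensor to the module's device/dtype
        # handling; persistent=False keeps this derived constant out of
        # checkpoints, since it can always be rebuilt in __init__.
        self.register_buffer("kernel", torch.ones(3), persistent=False)


loss = ToyLoss()
print(loss.kernel.requires_grad)          # False: buffers do not track grads
print("kernel" in loss.state_dict())      # False: non-persistent buffer

half = loss.to(torch.float16)             # buffer follows the dtype conversion
print(half.kernel.dtype)                  # torch.float16
```

With a plain `self.kernel = torch.ones(3)` assignment, the `.to(torch.float16)` call above would leave the tensor untouched, which is exactly the device/dtype mismatch the fix eliminates.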
