Commit d0650f3
fix: use register_buffer for kernel and kernel_vol in LocalNormalizedCrossCorrelationLoss (#8818)
### Description
This PR fixes an incorrect kernel initialization pattern in
`LocalNormalizedCrossCorrelationLoss` introduced in
`monai/losses/image_dissimilarity.py`.
**Bug found:** The original code used plain attribute assignment with a
typo:
```python
self.kernel = _kernel(self.kernel_size)
self.kernel.require_grads = False # typo: 'require_grads' is NOT a valid tensor attribute
self.kernel_vol = self.get_kernel_vol()
```
`require_grads` (with an 's') is not a valid PyTorch tensor attribute.
The correct attribute is `requires_grad`. This means the kernel was
**silently tracking gradients** throughout every forward pass, wasting
memory and computation.
Furthermore, assigning the kernel as a plain attribute (`self.kernel =
...`) rather than registering it as a buffer means:
- The kernel tensor does **not** automatically move to the correct
device when `.to(device)` / `.cuda()` / `.half()` is called on the loss
module
- The kernel is **not saved/restored** correctly via `state_dict()` /
`load_state_dict()`
**Fix:** Replace both assignments with `register_buffer`, which is the
correct PyTorch `nn.Module` pattern for constant tensors:
```python
self.register_buffer("kernel", _kernel(self.kernel_size))
self.register_buffer("kernel_vol", self.get_kernel_vol())
```
This removes the need for any manual `requires_grad` management and
ensures proper device placement.
### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
Signed-off-by: Zeesejo <92383127+Zeesejo@users.noreply.github.com>
---------
Signed-off-by: Zeeshan Modi <92383127+Zeesejo@users.noreply.github.com>
Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>1 parent 49b1012 commit d0650f3
1 file changed
Lines changed: 8 additions & 4 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
111 | 111 | | |
112 | 112 | | |
113 | 113 | | |
114 | | - | |
115 | | - | |
116 | | - | |
| 114 | + | |
| 115 | + | |
| 116 | + | |
| 117 | + | |
117 | 118 | | |
118 | 119 | | |
119 | 120 | | |
120 | 121 | | |
121 | | - | |
| 122 | + | |
| 123 | + | |
122 | 124 | | |
123 | 125 | | |
124 | 126 | | |
| |||
138 | 140 | | |
139 | 141 | | |
140 | 142 | | |
| 143 | + | |
| 144 | + | |
141 | 145 | | |
142 | 146 | | |
143 | 147 | | |
| |||
0 commit comments