Skip to content

Commit e2b99da

Browse files
committed
Merge remote-tracking branch 'upstream/dev' into 8587-test-erros-on-pytorch-release-2508-on-series-50
2 parents e57f64e + 853f702 commit e2b99da

26 files changed

Lines changed: 1286 additions & 82 deletions

.github/workflows/setupapp.yml

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ jobs:
8181
runs-on: ubuntu-latest
8282
strategy:
8383
matrix:
84-
python-version: ['3.9', '3.10', '3.11']
84+
python-version: ['3.10', '3.11', '3.12']
8585
steps:
8686
- uses: actions/checkout@v6
8787
with:
@@ -90,23 +90,12 @@ jobs:
9090
uses: actions/setup-python@v6
9191
with:
9292
python-version: ${{ matrix.python-version }}
93-
- name: cache weekly timestamp
94-
id: pip-cache
95-
run: |
96-
echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
97-
- name: cache for pip
98-
uses: actions/cache@v5
99-
id: cache
100-
with:
101-
path: |
102-
~/.cache/pip
103-
~/.cache/torch
104-
key: ${{ runner.os }}-${{ matrix.python-version }}-pip-${{ steps.pip-cache.outputs.datew }}
93+
cache: pip
10594
- name: Install the dependencies
10695
run: |
10796
find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \;
10897
python -m pip install --upgrade pip wheel
109-
python -m pip install -r requirements-dev.txt
98+
python -m pip install --no-build-isolation -r requirements-dev.txt
11099
- name: Run quick tests CPU ubuntu
111100
env:
112101
NGC_API_KEY: ${{ secrets.NGC_API_KEY }}
@@ -115,8 +104,8 @@ jobs:
115104
run: |
116105
python -m pip list
117106
python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'
118-
BUILD_MONAI=0 ./runtests.sh --build --quick --unittests
119-
BUILD_MONAI=1 ./runtests.sh --build --quick --min
107+
BUILD_MONAI=0 ./runtests.sh --build --coverage --quick --unittests
108+
BUILD_MONAI=1 ./runtests.sh --build --coverage --quick --min
120109
coverage xml --ignore-errors
121110
- name: Upload coverage
122111
uses: codecov/codecov-action@v5

docs/source/losses.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,11 @@ Segmentation Losses
9898
.. autoclass:: NACLLoss
9999
:members:
100100

101+
`MCCLoss`
102+
~~~~~~~~~
103+
.. autoclass:: MCCLoss
104+
:members:
105+
101106
Registration Losses
102107
-------------------
103108

monai/apps/reconstruction/transforms/dictionary.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from monai.apps.reconstruction.transforms.array import EquispacedKspaceMask, RandomKspaceMask
2121
from monai.config import DtypeLike, KeysCollection
2222
from monai.config.type_definitions import NdarrayOrTensor
23+
from monai.data.meta_tensor import MetaTensor
2324
from monai.transforms import InvertibleTransform
2425
from monai.transforms.croppad.array import SpatialCrop
2526
from monai.transforms.intensity.array import NormalizeIntensity
@@ -33,15 +34,36 @@ class ExtractDataKeyFromMetaKeyd(MapTransform):
3334
Moves keys from meta to data. It is useful when a dataset of paired samples
3435
is loaded and certain keys should be moved from meta to data.
3536
37+
This transform supports two modes:
38+
39+
1. When ``meta_key`` references a metadata dictionary in the data (e.g., when
40+
``image_only=False`` was used with ``LoadImaged``), the requested keys are
41+
extracted directly from that dictionary.
42+
43+
2. When ``meta_key`` references a ``MetaTensor`` in the data (e.g., when
44+
``image_only=True`` was used with ``LoadImaged``), the requested keys are
45+
extracted from its ``.meta`` attribute.
46+
3647
Args:
3748
keys: keys to be transferred from meta to data
38-
meta_key: the meta key where all the meta-data is stored
49+
meta_key: the key in the data dictionary where the metadata source is
50+
stored. This can be either a metadata dictionary or a ``MetaTensor``.
3951
allow_missing_keys: don't raise exception if key is missing
4052
4153
Example:
4254
When the fastMRI dataset is loaded, "kspace" is stored in the data dictionary,
4355
but the ground-truth image with the key "reconstruction_rss" is stored in the meta data.
4456
In this case, ExtractDataKeyFromMetaKeyd moves "reconstruction_rss" to data.
57+
58+
When ``LoadImaged`` is used with ``image_only=True`` (the default), the loaded
59+
data is a ``MetaTensor`` with metadata accessible via ``.meta``. In this case,
60+
set ``meta_key`` to the key of the ``MetaTensor`` itself::
61+
62+
li = LoadImaged(keys="image") # image_only=True by default
63+
dat = li({"image": "image.nii"})
64+
e = ExtractDataKeyFromMetaKeyd("filename_or_obj", meta_key="image")
65+
dat = e(dat)
66+
assert dat["image"].meta["filename_or_obj"] == dat["filename_or_obj"]
4567
"""
4668

4769
def __init__(self, keys: KeysCollection, meta_key: str, allow_missing_keys: bool = False) -> None:
@@ -58,9 +80,18 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, T
5880
the new data dictionary
5981
"""
6082
d = dict(data)
83+
meta_obj = d[self.meta_key]
84+
85+
# If meta_key references a MetaTensor, extract from its .meta attribute;
86+
# otherwise treat it as a metadata dictionary directly.
87+
if isinstance(meta_obj, MetaTensor):
88+
meta_dict: dict = meta_obj.meta
89+
else:
90+
meta_dict = dict(meta_obj)
91+
6192
for key in self.keys:
62-
if key in d[self.meta_key]:
63-
d[key] = d[self.meta_key][key] # type: ignore
93+
if key in meta_dict:
94+
d[key] = meta_dict[key] # type: ignore
6495
elif not self.allow_missing_keys:
6596
raise KeyError(
6697
f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the meta data"

monai/auto3dseg/analyzer.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -468,21 +468,28 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
468468
"""
469469
d: dict[Hashable, MetaTensor] = dict(data)
470470
start = time.time()
471-
if isinstance(d[self.image_key], (torch.Tensor, MetaTensor)) and d[self.image_key].device.type == "cuda":
472-
using_cuda = True
473-
else:
474-
using_cuda = False
471+
image_tensor = d[self.image_key]
472+
label_tensor = d[self.label_key]
473+
using_cuda = any(
474+
isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda" for t in (image_tensor, label_tensor)
475+
)
475476
restore_grad_state = torch.is_grad_enabled()
476477
torch.set_grad_enabled(False)
477478

478-
ndas: list[MetaTensor] = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])] # type: ignore
479-
ndas_label: MetaTensor = d[self.label_key].astype(torch.int16) # (H,W,D)
479+
if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance(
480+
label_tensor, (MetaTensor, torch.Tensor)
481+
):
482+
if label_tensor.device != image_tensor.device:
483+
label_tensor = label_tensor.to(image_tensor.device) # type: ignore
484+
485+
ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])] # type: ignore
486+
ndas_label: MetaTensor = label_tensor.astype(torch.int16) # (H,W,D)
480487

481488
if ndas_label.shape != ndas[0].shape:
482489
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")
483490

484491
nda_foregrounds: list[torch.Tensor] = [get_foreground_label(nda, ndas_label) for nda in ndas]
485-
nda_foregrounds = [nda if nda.numel() > 0 else torch.Tensor([0]) for nda in nda_foregrounds]
492+
nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]
486493

487494
unique_label = unique(ndas_label)
488495
if isinstance(ndas_label, (MetaTensor, torch.Tensor)):

monai/data/image_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def convert_to_channel_last(
324324
data = data[..., 0, :]
325325
# if desired, remove trailing singleton dimensions
326326
while squeeze_end_dims and data.shape[-1] == 1:
327-
data = np.squeeze(data, -1)
327+
data = data.squeeze(-1)
328328
if contiguous:
329329
data = ascontiguousarray(data)
330330
return data

monai/engines/trainer.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,12 @@ class SupervisedTrainer(Trainer):
131131
`torch.Tensor` before forward pass, then converted back afterward with copied meta information.
132132
compile_kwargs: dict of the args for `torch.compile()` API, for more details:
133133
https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile.
134+
accumulation_steps: number of mini-batches over which to accumulate gradients before
135+
calling ``optimizer.step()``, effectively simulating a larger batch size on
136+
memory-constrained hardware. Must be a positive integer. Default: 1 (no accumulation).
137+
When ``epoch_length`` is known and not divisible by ``accumulation_steps``, a flush
138+
(optimizer step) is performed at the end of each epoch so no gradients are silently
139+
discarded. The loss stored in ``engine.state.output`` is always the **unscaled** value.
134140
"""
135141

136142
def __init__(
@@ -160,7 +166,10 @@ def __init__(
160166
amp_kwargs: dict | None = None,
161167
compile: bool = False,
162168
compile_kwargs: dict | None = None,
169+
accumulation_steps: int = 1,
163170
) -> None:
171+
if accumulation_steps < 1:
172+
raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.")
164173
super().__init__(
165174
device=device,
166175
max_epochs=max_epochs,
@@ -190,6 +199,7 @@ def __init__(
190199
self.loss_function = loss_function
191200
self.inferer = SimpleInferer() if inferer is None else inferer
192201
self.optim_set_to_none = optim_set_to_none
202+
self.accumulation_steps = accumulation_steps
193203

194204
def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tensor]) -> dict:
195205
"""
@@ -245,21 +255,42 @@ def _compute_pred_loss():
245255
engine.state.output[Keys.LOSS] = engine.loss_function(engine.state.output[Keys.PRED], targets).mean()
246256
engine.fire_event(IterationEvents.LOSS_COMPLETED)
247257

258+
# Determine gradient accumulation state
259+
acc = engine.accumulation_steps
260+
if acc > 1:
261+
epoch_length = engine.state.epoch_length
262+
if epoch_length is not None:
263+
local_iter = (engine.state.iteration - 1) % epoch_length # 0-indexed within epoch
264+
should_zero_grad = local_iter % acc == 0
265+
should_step = (local_iter + 1) % acc == 0 or (local_iter + 1) == epoch_length
266+
else:
267+
local_iter = engine.state.iteration - 1 # 0-indexed global
268+
should_zero_grad = local_iter % acc == 0
269+
should_step = (local_iter + 1) % acc == 0
270+
else:
271+
should_zero_grad = True
272+
should_step = True
273+
248274
engine.network.train()
249-
engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
275+
if should_zero_grad:
276+
engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
250277

251278
if engine.amp and engine.scaler is not None:
252279
with torch.autocast("cuda", **engine.amp_kwargs):
253280
_compute_pred_loss()
254-
engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
281+
loss = engine.state.output[Keys.LOSS]
282+
engine.scaler.scale(loss / acc if acc > 1 else loss).backward()
255283
engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
256-
engine.scaler.step(engine.optimizer)
257-
engine.scaler.update()
284+
if should_step:
285+
engine.scaler.step(engine.optimizer)
286+
engine.scaler.update()
258287
else:
259288
_compute_pred_loss()
260-
engine.state.output[Keys.LOSS].backward()
289+
loss = engine.state.output[Keys.LOSS]
290+
(loss / acc if acc > 1 else loss).backward()
261291
engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
262-
engine.optimizer.step()
292+
if should_step:
293+
engine.optimizer.step()
263294
# copy back meta info
264295
if self.compile:
265296
if inputs_meta is not None:

monai/losses/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from .giou_loss import BoxGIoULoss, giou
3737
from .hausdorff_loss import HausdorffDTLoss, LogHausdorffDTLoss
3838
from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss
39+
from .mcc_loss import MCCLoss
3940
from .multi_scale import MultiScaleLoss
4041
from .nacl_loss import NACLLoss
4142
from .perceptual import PerceptualLoss

0 commit comments

Comments
 (0)