Skip to content

Commit 5ddc43e

Browse files
Merge branch 'dev' into patch-4
2 parents 94a465b + af361cc commit 5ddc43e

3 files changed

Lines changed: 15 additions & 7 deletions

File tree

monai/auto3dseg/analyzer.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from abc import ABC, abstractmethod
1616
from collections.abc import Hashable, Mapping
1717
from copy import deepcopy
18-
from typing import Any
18+
from typing import Any, cast
1919

2020
import numpy as np
2121
import torch
@@ -470,6 +470,7 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
470470
start = time.time()
471471
image_tensor = d[self.image_key]
472472
label_tensor = d[self.label_key]
473+
# Check if either tensor is on CUDA to determine if we should move both to CUDA for processing
473474
using_cuda = any(
474475
isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda" for t in (image_tensor, label_tensor)
475476
)
@@ -480,7 +481,13 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
480481
label_tensor, (MetaTensor, torch.Tensor)
481482
):
482483
if label_tensor.device != image_tensor.device:
483-
label_tensor = label_tensor.to(image_tensor.device) # type: ignore
484+
if using_cuda:
485+
# Move both tensors to CUDA when mixing devices
486+
cuda_device = image_tensor.device if image_tensor.device.type == "cuda" else label_tensor.device
487+
image_tensor = cast(MetaTensor, image_tensor.to(cuda_device))
488+
label_tensor = cast(MetaTensor, label_tensor.to(cuda_device))
489+
else:
490+
label_tensor = cast(MetaTensor, label_tensor.to(image_tensor.device))
484491

485492
ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])] # type: ignore
486493
ndas_label: MetaTensor = label_tensor.astype(torch.int16) # (H,W,D)

tests/apps/test_auto3dseg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,7 @@ def test_label_stats_mixed_device_analyzer(self, input_params):
393393
result = analyzer({"image": image_tensor, "label": label_tensor})
394394
report = result["label_stats"]
395395

396+
# Verify report format and computation succeeded despite mixed/unified devices
396397
assert verify_report_format(report, analyzer.get_report_format())
397398
assert report[LabelStatsKeys.LABEL_UID] == [0, 1]
398399

tests/metrics/test_ssim_metric.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
class TestSSIMMetric(unittest.TestCase):
2323

24-
def test2d_gaussian(self):
24+
def test_2d_gaussian(self):
2525
set_determinism(0)
2626
preds = torch.abs(torch.randn(2, 3, 16, 16))
2727
target = torch.abs(torch.randn(2, 3, 16, 16))
@@ -32,9 +32,9 @@ def test2d_gaussian(self):
3232
metric(preds, target)
3333
result = metric.aggregate()
3434
expected_value = 0.045415
35-
self.assertTrue(expected_value - result.item() < 0.000001)
35+
self.assertTrue(abs(expected_value - result.item()) < 0.000001)
3636

37-
def test2d_uniform(self):
37+
def test_2d_uniform(self):
3838
set_determinism(0)
3939
preds = torch.abs(torch.randn(2, 3, 16, 16))
4040
target = torch.abs(torch.randn(2, 3, 16, 16))
@@ -45,9 +45,9 @@ def test2d_uniform(self):
4545
metric(preds, target)
4646
result = metric.aggregate()
4747
expected_value = 0.050103
48-
self.assertTrue(expected_value - result.item() < 0.000001)
48+
self.assertTrue(abs(expected_value - result.item()) < 0.000001)
4949

50-
def test3d_gaussian(self):
50+
def test_3d_gaussian(self):
5151
set_determinism(0)
5252
preds = torch.abs(torch.randn(2, 3, 16, 16, 16))
5353
target = torch.abs(torch.randn(2, 3, 16, 16, 16))

0 commit comments

Comments
 (0)