1515from abc import ABC , abstractmethod
1616from collections .abc import Hashable , Mapping
1717from copy import deepcopy
18- from typing import Any
18+ from typing import Any , cast
1919
2020import numpy as np
2121import torch
@@ -470,6 +470,7 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
470470 start = time .time ()
471471 image_tensor = d [self .image_key ]
472472 label_tensor = d [self .label_key ]
473+ # Check if either tensor is on CUDA to determine if we should move both to CUDA for processing
473474 using_cuda = any (
474475 isinstance (t , (torch .Tensor , MetaTensor )) and t .device .type == "cuda" for t in (image_tensor , label_tensor )
475476 )
@@ -480,7 +481,13 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
480481 label_tensor , (MetaTensor , torch .Tensor )
481482 ):
482483 if label_tensor .device != image_tensor .device :
483- label_tensor = label_tensor .to (image_tensor .device ) # type: ignore
484+ if using_cuda :
485+ # Move both tensors to CUDA when mixing devices
486+ cuda_device = image_tensor .device if image_tensor .device .type == "cuda" else label_tensor .device
487+ image_tensor = cast (MetaTensor , image_tensor .to (cuda_device ))
488+ label_tensor = cast (MetaTensor , label_tensor .to (cuda_device ))
489+ else :
490+ label_tensor = cast (MetaTensor , label_tensor .to (image_tensor .device ))
484491
485492 ndas : list [MetaTensor ] = [image_tensor [i ] for i in range (image_tensor .shape [0 ])] # type: ignore
486493 ndas_label : MetaTensor = label_tensor .astype (torch .int16 ) # (H,W,D)
0 commit comments