Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit a3762b9

Browse files
virginiafdez and marksgraham
authored
Issue was coming from the definition of SpatialPad (self.ldm_resizer)… (#449)
* Issue was coming from the definition of SpatialPad (self.ldm_resizer) and Crop (self.autoencoder_resizer): the spatial_size passed included a [-1] to cover the channel dimension. The code, as written, assumed this channel dimension was a spatial dimension and that the batch dimension was the channel one, leading to errors related to the affine transform of the MetaTensor being wrong. self.ldm_resizer should operate on an unbatched version of the tensor, so the calls to the resizers were changed to ones that use decollate_batch and then stack the elements of the batch together again.
* Formatting

---------

Co-authored-by: virginiafdez <virginia.fernandez@kcl.ac.uk>
Co-authored-by: Mark Graham <markgraham539@gmail.com>
1 parent 5743fa2 commit a3762b9

1 file changed

Lines changed: 19 additions & 14 deletions

File tree

generative/inferers/inferer.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import torch
1919
import torch.nn as nn
2020
import torch.nn.functional as F
21+
from monai.data import decollate_batch
2122
from monai.inferers import Inferer
2223
from monai.transforms import CenterSpatialCrop, SpatialPad
2324
from monai.utils import optional_import
@@ -348,8 +349,8 @@ def __init__(
348349
self.ldm_latent_shape = ldm_latent_shape
349350
self.autoencoder_latent_shape = autoencoder_latent_shape
350351
if self.ldm_latent_shape is not None:
351-
self.ldm_resizer = SpatialPad(spatial_size=[-1] + self.ldm_latent_shape)
352-
self.autoencoder_resizer = CenterSpatialCrop(roi_size=[-1] + self.autoencoder_latent_shape)
352+
self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape)
353+
self.autoencoder_resizer = CenterSpatialCrop(roi_size=self.autoencoder_latent_shape)
353354

354355
def __call__(
355356
self,
@@ -379,7 +380,7 @@ def __call__(
379380
latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
380381

381382
if self.ldm_latent_shape is not None:
382-
latent = self.ldm_resizer(latent)
383+
latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0)
383384

384385
call = super().__call__
385386
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
@@ -454,14 +455,15 @@ def sample(
454455
else:
455456
latent = outputs
456457

457-
if self.ldm_latent_shape is not None:
458-
latent = self.autoencoder_resizer(latent)
459-
latent_intermediates = [self.autoencoder_resizer(l) for l in latent_intermediates]
458+
if self.autoencoder_latent_shape is not None:
459+
latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
460+
latent_intermediates = [
461+
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
462+
]
460463

461464
decode = autoencoder_model.decode_stage_2_outputs
462465
if isinstance(autoencoder_model, SPADEAutoencoderKL):
463466
decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
464-
465467
image = decode(latent / self.scale_factor)
466468

467469
if save_intermediates:
@@ -521,7 +523,7 @@ def get_likelihood(
521523
latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
522524

523525
if self.ldm_latent_shape is not None:
524-
latents = self.ldm_resizer(latents)
526+
latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0)
525527

526528
get_likelihood = super().get_likelihood
527529
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
@@ -862,7 +864,7 @@ def __init__(
862864
self.ldm_latent_shape = ldm_latent_shape
863865
self.autoencoder_latent_shape = autoencoder_latent_shape
864866
if self.ldm_latent_shape is not None:
865-
self.ldm_resizer = SpatialPad(spatial_size=[-1] + self.ldm_latent_shape)
867+
self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape)
866868
self.autoencoder_resizer = CenterSpatialCrop(roi_size=[-1] + self.autoencoder_latent_shape)
867869

868870
def __call__(
@@ -897,7 +899,8 @@ def __call__(
897899
latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
898900

899901
if self.ldm_latent_shape is not None:
900-
latent = self.ldm_resizer(latent)
902+
latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0)
903+
901904
if cn_cond.shape[2:] != latent.shape[2:]:
902905
cn_cond = F.interpolate(cn_cond, latent.shape[2:])
903906

@@ -986,9 +989,11 @@ def sample(
986989
else:
987990
latent = outputs
988991

989-
if self.ldm_latent_shape is not None:
990-
latent = self.autoencoder_resizer(latent)
991-
latent_intermediates = [self.autoencoder_resizer(l) for l in latent_intermediates]
992+
if self.autoencoder_latent_shape is not None:
993+
latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
994+
latent_intermediates = [
995+
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
996+
]
992997

993998
decode = autoencoder_model.decode_stage_2_outputs
994999
if isinstance(autoencoder_model, SPADEAutoencoderKL):
@@ -1061,7 +1066,7 @@ def get_likelihood(
10611066
cn_cond = F.interpolate(cn_cond, latents.shape[2:])
10621067

10631068
if self.ldm_latent_shape is not None:
1064-
latents = self.ldm_resizer(latents)
1069+
latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0)
10651070

10661071
get_likelihood = super().get_likelihood
10671072
if isinstance(diffusion_model, SPADEDiffusionModelUNet):

0 commit comments

Comments (0)