|
18 | 18 | import torch |
19 | 19 | import torch.nn as nn |
20 | 20 | import torch.nn.functional as F |
| 21 | +from monai.data import decollate_batch |
21 | 22 | from monai.inferers import Inferer |
22 | 23 | from monai.transforms import CenterSpatialCrop, SpatialPad |
23 | 24 | from monai.utils import optional_import |
@@ -348,8 +349,8 @@ def __init__( |
348 | 349 | self.ldm_latent_shape = ldm_latent_shape |
349 | 350 | self.autoencoder_latent_shape = autoencoder_latent_shape |
350 | 351 | if self.ldm_latent_shape is not None: |
351 | | - self.ldm_resizer = SpatialPad(spatial_size=[-1] + self.ldm_latent_shape) |
352 | | - self.autoencoder_resizer = CenterSpatialCrop(roi_size=[-1] + self.autoencoder_latent_shape) |
| 352 | + self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape) |
| 353 | + self.autoencoder_resizer = CenterSpatialCrop(roi_size=self.autoencoder_latent_shape) |
353 | 354 |
|
354 | 355 | def __call__( |
355 | 356 | self, |
@@ -379,7 +380,7 @@ def __call__( |
379 | 380 | latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor |
380 | 381 |
|
381 | 382 | if self.ldm_latent_shape is not None: |
382 | | - latent = self.ldm_resizer(latent) |
| 383 | + latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0) |
383 | 384 |
|
384 | 385 | call = super().__call__ |
385 | 386 | if isinstance(diffusion_model, SPADEDiffusionModelUNet): |
@@ -454,14 +455,15 @@ def sample( |
454 | 455 | else: |
455 | 456 | latent = outputs |
456 | 457 |
|
457 | | - if self.ldm_latent_shape is not None: |
458 | | - latent = self.autoencoder_resizer(latent) |
459 | | - latent_intermediates = [self.autoencoder_resizer(l) for l in latent_intermediates] |
| 458 | + if self.autoencoder_latent_shape is not None: |
| 459 | + latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0) |
| 460 | + latent_intermediates = [ |
| 461 | + torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates |
| 462 | + ] |
460 | 463 |
|
461 | 464 | decode = autoencoder_model.decode_stage_2_outputs |
462 | 465 | if isinstance(autoencoder_model, SPADEAutoencoderKL): |
463 | 466 | decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) |
464 | | - |
465 | 467 | image = decode(latent / self.scale_factor) |
466 | 468 |
|
467 | 469 | if save_intermediates: |
@@ -521,7 +523,7 @@ def get_likelihood( |
521 | 523 | latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor |
522 | 524 |
|
523 | 525 | if self.ldm_latent_shape is not None: |
524 | | - latents = self.ldm_resizer(latents) |
| 526 | + latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0) |
525 | 527 |
|
526 | 528 | get_likelihood = super().get_likelihood |
527 | 529 | if isinstance(diffusion_model, SPADEDiffusionModelUNet): |
@@ -862,7 +864,7 @@ def __init__( |
862 | 864 | self.ldm_latent_shape = ldm_latent_shape |
863 | 865 | self.autoencoder_latent_shape = autoencoder_latent_shape |
864 | 866 | if self.ldm_latent_shape is not None: |
865 | | - self.ldm_resizer = SpatialPad(spatial_size=[-1] + self.ldm_latent_shape) |
866 | | - self.autoencoder_resizer = CenterSpatialCrop(roi_size=[-1] + self.autoencoder_latent_shape) |
| 867 | + self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape) |
| 868 | + self.autoencoder_resizer = CenterSpatialCrop(roi_size=self.autoencoder_latent_shape) |
867 | 869 |
|
868 | 870 | def __call__( |
@@ -897,7 +899,8 @@ def __call__( |
897 | 899 | latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor |
898 | 900 |
|
899 | 901 | if self.ldm_latent_shape is not None: |
900 | | - latent = self.ldm_resizer(latent) |
| 902 | + latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0) |
| 903 | + |
901 | 904 | if cn_cond.shape[2:] != latent.shape[2:]: |
902 | 905 | cn_cond = F.interpolate(cn_cond, latent.shape[2:]) |
903 | 906 |
|
@@ -986,9 +989,11 @@ def sample( |
986 | 989 | else: |
987 | 990 | latent = outputs |
988 | 991 |
|
989 | | - if self.ldm_latent_shape is not None: |
990 | | - latent = self.autoencoder_resizer(latent) |
991 | | - latent_intermediates = [self.autoencoder_resizer(l) for l in latent_intermediates] |
| 992 | + if self.autoencoder_latent_shape is not None: |
| 993 | + latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0) |
| 994 | + latent_intermediates = [ |
| 995 | + torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates |
| 996 | + ] |
992 | 997 |
|
993 | 998 | decode = autoencoder_model.decode_stage_2_outputs |
994 | 999 | if isinstance(autoencoder_model, SPADEAutoencoderKL): |
@@ -1061,7 +1066,7 @@ def get_likelihood( |
1061 | 1066 | cn_cond = F.interpolate(cn_cond, latents.shape[2:]) |
1062 | 1067 |
|
1063 | 1068 | if self.ldm_latent_shape is not None: |
1064 | | - latents = self.ldm_resizer(latents) |
| 1069 | + latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0) |
1065 | 1070 |
|
1066 | 1071 | get_likelihood = super().get_likelihood |
1067 | 1072 | if isinstance(diffusion_model, SPADEDiffusionModelUNet): |
|
0 commit comments