Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 7896dda

Browse files
virginiafdez
and co-author committed
Added SPADE functionality on the decode call for the sample methods. (#441)
* Added SPADE functionality on the decode call for the sample methods. * Changed the format of the partial statements. --------- Co-authored-by: virginiafdez <virginia.fernandez@kcl.ac.uk>
1 parent 18fef51 commit 7896dda

1 file changed

Lines changed: 65 additions & 66 deletions

File tree

generative/inferers/inferer.py

Lines changed: 65 additions & 66 deletions
Original file line number | Diff line number | Diff line change
@@ -381,15 +381,17 @@ def __call__(
381381
if self.ldm_latent_shape is not None:
382382
latent = self.ldm_resizer(latent)
383383

384-
call = partial(super().__call__, seg = seg) if \
385-
isinstance(diffusion_model, SPADEDiffusionModelUNet) else super().__call__
384+
call = super().__call__
385+
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
386+
call = partial(super().__call__, seg=seg)
387+
386388
prediction = call(
387389
inputs=latent,
388390
diffusion_model=diffusion_model,
389391
noise=noise,
390392
timesteps=timesteps,
391393
condition=condition,
392-
mode=mode
394+
mode=mode,
393395
)
394396
return prediction
395397

@@ -432,9 +434,9 @@ def sample(
432434
"labels for each must be compatible. "
433435
)
434436

435-
sample = (
436-
partial(super().sample, seg=seg) if isinstance(diffusion_model, SPADEDiffusionModelUNet) else super().sample
437-
)
437+
sample = super().sample
438+
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
439+
sample = partial(super().sample, seg=seg)
438440

439441
outputs = sample(
440442
input_noise=input_noise,
@@ -456,19 +458,19 @@ def sample(
456458
latent = self.autoencoder_resizer(latent)
457459
latent_intermediates = [self.autoencoder_resizer(l) for l in latent_intermediates]
458460

459-
image = autoencoder_model.decode_stage_2_outputs(latent / self.scale_factor)
461+
decode = autoencoder_model.decode_stage_2_outputs
462+
if isinstance(autoencoder_model, SPADEAutoencoderKL):
463+
decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
464+
465+
image = decode(latent / self.scale_factor)
460466

461467
if save_intermediates:
462468
intermediates = []
463469
for latent_intermediate in latent_intermediates:
470+
decode = autoencoder_model.decode_stage_2_outputs
464471
if isinstance(autoencoder_model, SPADEAutoencoderKL):
465-
intermediates.append(
466-
autoencoder_model.decode_stage_2_outputs(latent_intermediate / self.scale_factor, seg=seg)
467-
)
468-
else:
469-
intermediates.append(
470-
autoencoder_model.decode_stage_2_outputs(latent_intermediate / self.scale_factor)
471-
)
472+
decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
473+
intermediates.append(decode(latent_intermediate / self.scale_factor))
472474
return image, intermediates
473475

474476
else:
@@ -521,11 +523,9 @@ def get_likelihood(
521523
if self.ldm_latent_shape is not None:
522524
latents = self.ldm_resizer(latents)
523525

524-
get_likelihood = (
525-
partial(super().get_likelihood, seg=seg)
526-
if isinstance(diffusion_model, SPADEDiffusionModelUNet)
527-
else super().get_likelihood
528-
)
526+
get_likelihood = super().get_likelihood
527+
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
528+
get_likelihood = partial(super().get_likelihood, seg=seg)
529529

530530
outputs = get_likelihood(
531531
inputs=latents,
@@ -596,13 +596,11 @@ def __call__(
596596
noisy_image = torch.cat([noisy_image, condition], dim=1)
597597
condition = None
598598

599-
diffusion_model = (
600-
partial(diffusion_model, seg=seg)
601-
if isinstance(diffusion_model, SPADEDiffusionModelUNet)
602-
else diffusion_model
603-
)
599+
diffuse = diffusion_model
600+
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
601+
diffuse = partial(diffusion_model, seg = seg)
604602

605-
prediction = diffusion_model(
603+
prediction = diffuse(
606604
x=noisy_image,
607605
timesteps=timesteps,
608606
context=condition,
@@ -658,22 +656,21 @@ def sample(
658656
x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond
659657
)
660658
# 2. predict noise model_output
661-
diffusion_model = (
662-
partial(diffusion_model, seg=seg)
663-
if isinstance(diffusion_model, SPADEDiffusionModelUNet)
664-
else diffusion_model
665-
)
659+
diffuse = diffusion_model
660+
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
661+
diffuse = partial(diffusion_model, seg=seg)
662+
666663
if mode == "concat":
667664
model_input = torch.cat([image, conditioning], dim=1)
668-
model_output = diffusion_model(
665+
model_output = diffuse(
669666
model_input,
670667
timesteps=torch.Tensor((t,)).to(input_noise.device),
671668
context=None,
672669
down_block_additional_residuals=down_block_res_samples,
673670
mid_block_additional_residual=mid_block_res_sample,
674671
)
675672
else:
676-
model_output = diffusion_model(
673+
model_output = diffuse(
677674
image,
678675
timesteps=torch.Tensor((t,)).to(input_noise.device),
679676
context=conditioning,
@@ -747,22 +744,21 @@ def get_likelihood(
747744
x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond
748745
)
749746

750-
diffusion_model = (
751-
partial(diffusion_model, seg=seg)
752-
if isinstance(diffusion_model, SPADEDiffusionModelUNet)
753-
else diffusion_model
754-
)
747+
diffuse = diffusion_model
748+
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
749+
diffuse = partial(diffusion_model, seg = seg)
750+
755751
if mode == "concat":
756752
noisy_image = torch.cat([noisy_image, conditioning], dim=1)
757-
model_output = diffusion_model(
753+
model_output = diffuse(
758754
noisy_image,
759755
timesteps=timesteps,
760756
context=None,
761757
down_block_additional_residuals=down_block_res_samples,
762758
mid_block_additional_residual=mid_block_res_sample,
763759
)
764760
else:
765-
model_output = diffusion_model(
761+
model_output = diffuse(
766762
x=noisy_image,
767763
timesteps=timesteps,
768764
context=conditioning,
@@ -836,7 +832,6 @@ def get_likelihood(
836832
else:
837833
return total_kl
838834

839-
840835
class ControlNetLatentDiffusionInferer(ControlNetDiffusionInferer):
841836
"""
842837
ControlNetLatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, controlnet,
@@ -905,8 +900,10 @@ def __call__(
905900
if cn_cond.shape[2:] != latent.shape[2:]:
906901
cn_cond = F.interpolate(cn_cond, latent.shape[2:])
907902

908-
call = partial(super().__call__, seg = seg) if \
909-
isinstance(diffusion_model, SPADEDiffusionModelUNet) else super().__call__
903+
call = super().__call__
904+
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
905+
call = partial(super().__call__, seg=seg)
906+
910907
prediction = call(
911908
inputs=latent,
912909
diffusion_model=diffusion_model,
@@ -915,7 +912,7 @@ def __call__(
915912
timesteps=timesteps,
916913
cn_cond=cn_cond,
917914
condition=condition,
918-
mode=mode
915+
mode=mode,
919916
)
920917

921918
return prediction
@@ -966,20 +963,21 @@ def sample(
966963
if cn_cond.shape[2:] != input_noise.shape[2:]:
967964
cn_cond = F.interpolate(cn_cond, input_noise.shape[2:])
968965

969-
sample = partial(super().sample, seg = seg) if \
970-
isinstance(diffusion_model, SPADEDiffusionModelUNet) else super().sample
966+
sample = super().sample
967+
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
968+
sample = partial(super().sample, seg=seg)
971969

972970
outputs = sample(
973-
input_noise=input_noise,
974-
diffusion_model=diffusion_model,
975-
controlnet=controlnet,
976-
cn_cond=cn_cond,
977-
scheduler=scheduler,
978-
save_intermediates=save_intermediates,
979-
intermediate_steps=intermediate_steps,
980-
conditioning=conditioning,
981-
mode=mode,
982-
verbose=verbose,
971+
input_noise=input_noise,
972+
diffusion_model=diffusion_model,
973+
controlnet=controlnet,
974+
cn_cond=cn_cond,
975+
scheduler=scheduler,
976+
save_intermediates=save_intermediates,
977+
intermediate_steps=intermediate_steps,
978+
conditioning=conditioning,
979+
mode=mode,
980+
verbose=verbose,
983981
)
984982

985983
if save_intermediates:
@@ -991,19 +989,19 @@ def sample(
991989
latent = self.autoencoder_resizer(latent)
992990
latent_intermediates = [self.autoencoder_resizer(l) for l in latent_intermediates]
993991

994-
image = autoencoder_model.decode_stage_2_outputs(latent / self.scale_factor)
992+
decode = autoencoder_model.decode_stage_2_outputs
993+
if isinstance(autoencoder_model, SPADEAutoencoderKL):
994+
decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
995+
996+
image = decode(latent / self.scale_factor)
995997

996998
if save_intermediates:
997999
intermediates = []
9981000
for latent_intermediate in latent_intermediates:
1001+
decode = autoencoder_model.decode_stage_2_outputs
9991002
if isinstance(autoencoder_model, SPADEAutoencoderKL):
1000-
intermediates.append(
1001-
autoencoder_model.decode_stage_2_outputs(latent_intermediate / self.scale_factor), seg=seg
1002-
)
1003-
else:
1004-
intermediates.append(
1005-
autoencoder_model.decode_stage_2_outputs(latent_intermediate / self.scale_factor)
1006-
)
1003+
decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
1004+
intermediates.append(decode(latent_intermediate / self.scale_factor))
10071005
return image, intermediates
10081006

10091007
else:
@@ -1064,8 +1062,10 @@ def get_likelihood(
10641062
if self.ldm_latent_shape is not None:
10651063
latents = self.ldm_resizer(latents)
10661064

1067-
get_likelihood = partial(super().get_likelihood, seg = seg) if \
1068-
isinstance(diffusion_model, SPADEDiffusionModelUNet) else super().get_likelihood
1065+
get_likelihood = super().get_likelihood
1066+
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
1067+
get_likelihood = partial(super().get_likelihood, seg=seg)
1068+
10691069
outputs = get_likelihood(
10701070
inputs=latents,
10711071
diffusion_model=diffusion_model,
@@ -1085,7 +1085,6 @@ def get_likelihood(
10851085
outputs = (outputs[0], intermediates)
10861086
return outputs
10871087

1088-
10891088
class VQVAETransformerInferer(Inferer):
10901089
"""
10911090
Class to perform inference with a VQVAE + Transformer model.

0 commit comments

Comments (0)