@@ -381,15 +381,17 @@ def __call__(
381381 if self .ldm_latent_shape is not None :
382382 latent = self .ldm_resizer (latent )
383383
384- call = partial (super ().__call__ , seg = seg ) if \
385- isinstance (diffusion_model , SPADEDiffusionModelUNet ) else super ().__call__
384+ call = super ().__call__
385+ if isinstance (diffusion_model , SPADEDiffusionModelUNet ):
386+ call = partial (super ().__call__ , seg = seg )
387+
386388 prediction = call (
387389 inputs = latent ,
388390 diffusion_model = diffusion_model ,
389391 noise = noise ,
390392 timesteps = timesteps ,
391393 condition = condition ,
392- mode = mode
394+ mode = mode ,
393395 )
394396 return prediction
395397
@@ -432,9 +434,9 @@ def sample(
432434 "labels for each must be compatible. "
433435 )
434436
435- sample = (
436- partial ( super (). sample , seg = seg ) if isinstance (diffusion_model , SPADEDiffusionModelUNet ) else super (). sample
437- )
437+ sample = super (). sample
438+ if isinstance (diffusion_model , SPADEDiffusionModelUNet ):
439+ sample = partial ( super (). sample , seg = seg )
438440
439441 outputs = sample (
440442 input_noise = input_noise ,
@@ -456,19 +458,19 @@ def sample(
456458 latent = self .autoencoder_resizer (latent )
457459 latent_intermediates = [self .autoencoder_resizer (l ) for l in latent_intermediates ]
458460
459- image = autoencoder_model .decode_stage_2_outputs (latent / self .scale_factor )
461+ decode = autoencoder_model .decode_stage_2_outputs
462+ if isinstance (autoencoder_model , SPADEAutoencoderKL ):
463+ decode = partial (autoencoder_model .decode_stage_2_outputs , seg = seg )
464+
465+ image = decode (latent / self .scale_factor )
460466
461467 if save_intermediates :
462468 intermediates = []
463469 for latent_intermediate in latent_intermediates :
470+ decode = autoencoder_model .decode_stage_2_outputs
464471 if isinstance (autoencoder_model , SPADEAutoencoderKL ):
465- intermediates .append (
466- autoencoder_model .decode_stage_2_outputs (latent_intermediate / self .scale_factor , seg = seg )
467- )
468- else :
469- intermediates .append (
470- autoencoder_model .decode_stage_2_outputs (latent_intermediate / self .scale_factor )
471- )
472+ decode = partial (autoencoder_model .decode_stage_2_outputs , seg = seg )
473+ intermediates .append (decode (latent_intermediate / self .scale_factor ))
472474 return image , intermediates
473475
474476 else :
@@ -521,11 +523,9 @@ def get_likelihood(
521523 if self .ldm_latent_shape is not None :
522524 latents = self .ldm_resizer (latents )
523525
524- get_likelihood = (
525- partial (super ().get_likelihood , seg = seg )
526- if isinstance (diffusion_model , SPADEDiffusionModelUNet )
527- else super ().get_likelihood
528- )
526+ get_likelihood = super ().get_likelihood
527+ if isinstance (diffusion_model , SPADEDiffusionModelUNet ):
528+ get_likelihood = partial (super ().get_likelihood , seg = seg )
529529
530530 outputs = get_likelihood (
531531 inputs = latents ,
@@ -596,13 +596,11 @@ def __call__(
596596 noisy_image = torch .cat ([noisy_image , condition ], dim = 1 )
597597 condition = None
598598
599- diffusion_model = (
600- partial (diffusion_model , seg = seg )
601- if isinstance (diffusion_model , SPADEDiffusionModelUNet )
602- else diffusion_model
603- )
599+ diffuse = diffusion_model
600+ if isinstance (diffusion_model , SPADEDiffusionModelUNet ):
601+ diffuse = partial (diffusion_model , seg = seg )
604602
605- prediction = diffusion_model (
603+ prediction = diffuse (
606604 x = noisy_image ,
607605 timesteps = timesteps ,
608606 context = condition ,
@@ -658,22 +656,21 @@ def sample(
658656 x = image , timesteps = torch .Tensor ((t ,)).to (input_noise .device ), controlnet_cond = cn_cond
659657 )
660658 # 2. predict noise model_output
661- diffusion_model = (
662- partial (diffusion_model , seg = seg )
663- if isinstance (diffusion_model , SPADEDiffusionModelUNet )
664- else diffusion_model
665- )
659+ diffuse = diffusion_model
660+ if isinstance (diffusion_model , SPADEDiffusionModelUNet ):
661+ diffuse = partial (diffusion_model , seg = seg )
662+
666663 if mode == "concat" :
667664 model_input = torch .cat ([image , conditioning ], dim = 1 )
668- model_output = diffusion_model (
665+ model_output = diffuse (
669666 model_input ,
670667 timesteps = torch .Tensor ((t ,)).to (input_noise .device ),
671668 context = None ,
672669 down_block_additional_residuals = down_block_res_samples ,
673670 mid_block_additional_residual = mid_block_res_sample ,
674671 )
675672 else :
676- model_output = diffusion_model (
673+ model_output = diffuse (
677674 image ,
678675 timesteps = torch .Tensor ((t ,)).to (input_noise .device ),
679676 context = conditioning ,
@@ -747,22 +744,21 @@ def get_likelihood(
747744 x = noisy_image , timesteps = torch .Tensor ((t ,)).to (inputs .device ), controlnet_cond = cn_cond
748745 )
749746
750- diffusion_model = (
751- partial (diffusion_model , seg = seg )
752- if isinstance (diffusion_model , SPADEDiffusionModelUNet )
753- else diffusion_model
754- )
747+ diffuse = diffusion_model
748+ if isinstance (diffusion_model , SPADEDiffusionModelUNet ):
749+ diffuse = partial (diffusion_model , seg = seg )
750+
755751 if mode == "concat" :
756752 noisy_image = torch .cat ([noisy_image , conditioning ], dim = 1 )
757- model_output = diffusion_model (
753+ model_output = diffuse (
758754 noisy_image ,
759755 timesteps = timesteps ,
760756 context = None ,
761757 down_block_additional_residuals = down_block_res_samples ,
762758 mid_block_additional_residual = mid_block_res_sample ,
763759 )
764760 else :
765- model_output = diffusion_model (
761+ model_output = diffuse (
766762 x = noisy_image ,
767763 timesteps = timesteps ,
768764 context = conditioning ,
@@ -836,7 +832,6 @@ def get_likelihood(
836832 else :
837833 return total_kl
838834
839-
840835class ControlNetLatentDiffusionInferer (ControlNetDiffusionInferer ):
841836 """
842837 ControlNetLatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, controlnet,
@@ -905,8 +900,10 @@ def __call__(
905900 if cn_cond .shape [2 :] != latent .shape [2 :]:
906901 cn_cond = F .interpolate (cn_cond , latent .shape [2 :])
907902
908- call = partial (super ().__call__ , seg = seg ) if \
909- isinstance (diffusion_model , SPADEDiffusionModelUNet ) else super ().__call__
903+ call = super ().__call__
904+ if isinstance (diffusion_model , SPADEDiffusionModelUNet ):
905+ call = partial (super ().__call__ , seg = seg )
906+
910907 prediction = call (
911908 inputs = latent ,
912909 diffusion_model = diffusion_model ,
@@ -915,7 +912,7 @@ def __call__(
915912 timesteps = timesteps ,
916913 cn_cond = cn_cond ,
917914 condition = condition ,
918- mode = mode
915+ mode = mode ,
919916 )
920917
921918 return prediction
@@ -966,20 +963,21 @@ def sample(
966963 if cn_cond .shape [2 :] != input_noise .shape [2 :]:
967964 cn_cond = F .interpolate (cn_cond , input_noise .shape [2 :])
968965
969- sample = partial (super ().sample , seg = seg ) if \
970- isinstance (diffusion_model , SPADEDiffusionModelUNet ) else super ().sample
966+ sample = super ().sample
967+ if isinstance (diffusion_model , SPADEDiffusionModelUNet ):
968+ sample = partial (super ().sample , seg = seg )
971969
972970 outputs = sample (
973- input_noise = input_noise ,
974- diffusion_model = diffusion_model ,
975- controlnet = controlnet ,
976- cn_cond = cn_cond ,
977- scheduler = scheduler ,
978- save_intermediates = save_intermediates ,
979- intermediate_steps = intermediate_steps ,
980- conditioning = conditioning ,
981- mode = mode ,
982- verbose = verbose ,
971+ input_noise = input_noise ,
972+ diffusion_model = diffusion_model ,
973+ controlnet = controlnet ,
974+ cn_cond = cn_cond ,
975+ scheduler = scheduler ,
976+ save_intermediates = save_intermediates ,
977+ intermediate_steps = intermediate_steps ,
978+ conditioning = conditioning ,
979+ mode = mode ,
980+ verbose = verbose ,
983981 )
984982
985983 if save_intermediates :
@@ -991,19 +989,19 @@ def sample(
991989 latent = self .autoencoder_resizer (latent )
992990 latent_intermediates = [self .autoencoder_resizer (l ) for l in latent_intermediates ]
993991
994- image = autoencoder_model .decode_stage_2_outputs (latent / self .scale_factor )
992+ decode = autoencoder_model .decode_stage_2_outputs
993+ if isinstance (autoencoder_model , SPADEAutoencoderKL ):
994+ decode = partial (autoencoder_model .decode_stage_2_outputs , seg = seg )
995+
996+ image = decode (latent / self .scale_factor )
995997
996998 if save_intermediates :
997999 intermediates = []
9981000 for latent_intermediate in latent_intermediates :
1001+ decode = autoencoder_model .decode_stage_2_outputs
9991002 if isinstance (autoencoder_model , SPADEAutoencoderKL ):
1000- intermediates .append (
1001- autoencoder_model .decode_stage_2_outputs (latent_intermediate / self .scale_factor ), seg = seg
1002- )
1003- else :
1004- intermediates .append (
1005- autoencoder_model .decode_stage_2_outputs (latent_intermediate / self .scale_factor )
1006- )
1003+ decode = partial (autoencoder_model .decode_stage_2_outputs , seg = seg )
1004+ intermediates .append (decode (latent_intermediate / self .scale_factor ))
10071005 return image , intermediates
10081006
10091007 else :
@@ -1064,8 +1062,10 @@ def get_likelihood(
10641062 if self .ldm_latent_shape is not None :
10651063 latents = self .ldm_resizer (latents )
10661064
1067- get_likelihood = partial (super ().get_likelihood , seg = seg ) if \
1068- isinstance (diffusion_model , SPADEDiffusionModelUNet ) else super ().get_likelihood
1065+ get_likelihood = super ().get_likelihood
1066+ if isinstance (diffusion_model , SPADEDiffusionModelUNet ):
1067+ get_likelihood = partial (super ().get_likelihood , seg = seg )
1068+
10691069 outputs = get_likelihood (
10701070 inputs = latents ,
10711071 diffusion_model = diffusion_model ,
@@ -1085,7 +1085,6 @@ def get_likelihood(
10851085 outputs = (outputs [0 ], intermediates )
10861086 return outputs
10871087
1088-
10891088class VQVAETransformerInferer (Inferer ):
10901089 """
10911090 Class to perform inference with a VQVAE + Transformer model.
[page-scrape residue — not part of the patch: "0 commit comments"]