This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit e0e2559

virginiafdez and marksgraham authored
Added SPADE-LDM code (#436)
* Added SPADE-LDM code:
  - Modified diffusion_model_unet to allow SPADE normalisation to be set up as an option.
  - Modified autoencoder_kl to allow SPADE normalisation to be set up as an option.
  - Modified the inferer and latent inferer to allow a label to be passed through forward when SPADE is active.
  - Added tests: test_spade_diffusion.
  - Created a 2D tutorial using an OASIS subset of images.

  Even though tests are implemented, we should check very thoroughly that this works before merging, especially since the presence of SPADE norm requires labels to be passed to the forward method, and any call to forward without a label when SPADE is on will raise an error. Likewise, we should ensure that any call to forward when SPADE is not on is not disrupted (the code must not error out because a label is missing).
* Fetch tutorial from other PR.
* Made sure norm_params for SPADE had a single affine argument.
* Code formatting.

---------

Co-authored-by: virginiafdez <virginia.fernandez@kcl.ac.uk>
Co-authored-by: Mark Graham <markgraham539@gmail.com>
1 parent 3da2673 commit e0e2559

12 files changed

Lines changed: 4882 additions & 171 deletions
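
For context, SPADE (spatially-adaptive denormalisation) replaces the fixed affine parameters of a normalisation layer with per-pixel scale and shift maps predicted from a segmentation map, which is why every forward pass of a SPADE-enabled network needs the label. The following is a minimal, self-contained 2D sketch of the idea in PyTorch, not the code added by this commit; the class name SPADE2D, the hidden width, and the choice of parameter-free norm are illustrative assumptions.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SPADE2D(nn.Module):
    """Normalise x without learned affine terms, then modulate it with per-pixel
    gamma/beta maps predicted from the segmentation map (illustrative sketch)."""

    def __init__(self, norm_nc: int, label_nc: int, hidden_nc: int = 64) -> None:
        super().__init__()
        # Parameter-free normalisation; the exact flavour (batch/instance/group) varies by implementation.
        self.norm = nn.InstanceNorm2d(norm_nc, affine=False)
        # Shared trunk over the label map, then per-channel gamma and beta heads.
        self.shared = nn.Sequential(nn.Conv2d(label_nc, hidden_nc, 3, padding=1), nn.ReLU())
        self.gamma = nn.Conv2d(hidden_nc, norm_nc, 3, padding=1)
        self.beta = nn.Conv2d(hidden_nc, norm_nc, 3, padding=1)

    def forward(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:
        # Resize the label map to the feature resolution before predicting gamma/beta.
        seg = F.interpolate(seg, size=x.shape[2:], mode="nearest")
        actv = self.shared(seg)
        return self.norm(x) * (1 + self.gamma(actv)) + self.beta(actv)

# A 6-channel (one-hot style) label map modulating a 32-channel feature map.
x = torch.randn(2, 32, 64, 64)
seg = torch.randn(2, 6, 64, 64)
print(SPADE2D(32, 6)(x, seg).shape)  # torch.Size([2, 32, 64, 64])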

generative/inferers/inferer.py

Lines changed: 143 additions & 47 deletions
Large diffs are not rendered by default.
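
The inferer diff is not rendered here, but the behaviour described in the commit message, forwarding the label only when SPADE is active, amounts to a dispatch along these lines. This is a hypothetical sketch, not the repository's code: the helper name call_unet, the class-name check, and the seg keyword are assumptions.

def call_unet(unet, noisy_x, timesteps, context=None, seg=None):
    # Hypothetical dispatch: SPADE variants require the label map, while the plain
    # UNet must never receive one, so branch before forwarding the call.
    if "SPADE" in type(unet).__name__:
        if seg is None:
            raise ValueError("A segmentation map is required when SPADE normalisation is active.")
        return unet(x=noisy_x, timesteps=timesteps, context=context, seg=seg)  # 'seg' kwarg assumed
    return unet(x=noisy_x, timesteps=timesteps, context=context)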

generative/networks/nets/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -15,5 +15,8 @@
 from .controlnet import ControlNet
 from .diffusion_model_unet import DiffusionModelUNet
 from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator
+from .spade_autoencoderkl import SPADEAutoencoderKL
+from .spade_diffusion_model_unet import SPADEDiffusionModelUNet
+from .spade_network import SPADENet
 from .transformer import DecoderOnlyTransformer
 from .vqvae import VQVAE
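
With these exports, the SPADE variants become importable from generative.networks.nets. Below is a hedged construction sketch: the constructor arguments are assumed to mirror DiffusionModelUNet plus a label_nc channel count for the segmentation map, and the seg argument to forward is inferred from the commit description, so check the class signatures before relying on this.

import torch
from generative.networks.nets import SPADEDiffusionModelUNet

# Hypothetical configuration; argument names other than the import itself are assumptions.
unet = SPADEDiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    label_nc=6,                # number of semantic channels in the label map (assumed name)
    num_res_blocks=1,
    num_channels=(8, 8),
    attention_levels=(False, True),
    norm_num_groups=8,
)

x = torch.randn(1, 1, 32, 32)
seg = torch.randn(1, 6, 32, 32)    # one-hot style label map
t = torch.randint(0, 1000, (1,))
# When SPADE is active, every forward call must carry the label map.
out = unet(x=x, timesteps=t, seg=seg)  # 'seg' kwarg assumed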

generative/networks/nets/diffusion_model_unet.py

Lines changed: 18 additions & 20 deletions
@@ -931,7 +931,7 @@ def __init__(
 cross_attention_dim: int | None = None,
 upcast_attention: bool = False,
 use_flash_attention: bool = False,
-dropout_cattn: float = 0.0
+dropout_cattn: float = 0.0,
 ) -> None:
 super().__init__()
 self.resblock_updown = resblock_updown
@@ -964,7 +964,7 @@ def __init__(
 cross_attention_dim=cross_attention_dim,
 upcast_attention=upcast_attention,
 use_flash_attention=use_flash_attention,
-dropout=dropout_cattn
+dropout=dropout_cattn,
 )
 )

@@ -1103,7 +1103,7 @@ def __init__(
 cross_attention_dim: int | None = None,
 upcast_attention: bool = False,
 use_flash_attention: bool = False,
-dropout_cattn: float = 0.0
+dropout_cattn: float = 0.0,
 ) -> None:
 super().__init__()
 self.attention = None
@@ -1127,7 +1127,7 @@ def __init__(
 cross_attention_dim=cross_attention_dim,
 upcast_attention=upcast_attention,
 use_flash_attention=use_flash_attention,
-dropout=dropout_cattn
+dropout=dropout_cattn,
 )
 self.resnet_2 = ResnetBlock(
 spatial_dims=spatial_dims,
@@ -1271,7 +1271,7 @@ def __init__(
 add_upsample: bool = True,
 resblock_updown: bool = False,
 num_head_channels: int = 1,
-use_flash_attention: bool = False
+use_flash_attention: bool = False,
 ) -> None:
 super().__init__()
 self.resblock_updown = resblock_updown
@@ -1388,7 +1388,7 @@ def __init__(
 cross_attention_dim: int | None = None,
 upcast_attention: bool = False,
 use_flash_attention: bool = False,
-dropout_cattn: float = 0.0
+dropout_cattn: float = 0.0,
 ) -> None:
 super().__init__()
 self.resblock_updown = resblock_updown
@@ -1422,7 +1422,7 @@ def __init__(
 cross_attention_dim=cross_attention_dim,
 upcast_attention=upcast_attention,
 use_flash_attention=use_flash_attention,
-dropout=dropout_cattn
+dropout=dropout_cattn,
 )
 )

@@ -1486,7 +1486,7 @@ def get_down_block(
 cross_attention_dim: int | None,
 upcast_attention: bool = False,
 use_flash_attention: bool = False,
-dropout_cattn: float = 0.0
+dropout_cattn: float = 0.0,
 ) -> nn.Module:
 if with_attn:
 return AttnDownBlock(
@@ -1518,7 +1518,7 @@ def get_down_block(
 cross_attention_dim=cross_attention_dim,
 upcast_attention=upcast_attention,
 use_flash_attention=use_flash_attention,
-dropout_cattn=dropout_cattn
+dropout_cattn=dropout_cattn,
 )
 else:
 return DownBlock(
@@ -1546,7 +1546,7 @@ def get_mid_block(
 cross_attention_dim: int | None,
 upcast_attention: bool = False,
 use_flash_attention: bool = False,
-dropout_cattn: float = 0.0
+dropout_cattn: float = 0.0,
 ) -> nn.Module:
 if with_conditioning:
 return CrossAttnMidBlock(
@@ -1560,7 +1560,7 @@ def get_mid_block(
 cross_attention_dim=cross_attention_dim,
 upcast_attention=upcast_attention,
 use_flash_attention=use_flash_attention,
-dropout_cattn=dropout_cattn
+dropout_cattn=dropout_cattn,
 )
 else:
 return AttnMidBlock(
@@ -1592,7 +1592,7 @@ def get_up_block(
 cross_attention_dim: int | None,
 upcast_attention: bool = False,
 use_flash_attention: bool = False,
-dropout_cattn: float = 0.0
+dropout_cattn: float = 0.0,
 ) -> nn.Module:
 if with_attn:
 return AttnUpBlock(
@@ -1626,7 +1626,7 @@ def get_up_block(
 cross_attention_dim=cross_attention_dim,
 upcast_attention=upcast_attention,
 use_flash_attention=use_flash_attention,
-dropout_cattn=dropout_cattn
+dropout_cattn=dropout_cattn,
 )
 else:
 return UpBlock(
@@ -1688,7 +1688,7 @@ def __init__(
 num_class_embeds: int | None = None,
 upcast_attention: bool = False,
 use_flash_attention: bool = False,
-dropout_cattn: float = 0.0
+dropout_cattn: float = 0.0,
 ) -> None:
 super().__init__()
 if with_conditioning is True and cross_attention_dim is None:
@@ -1701,9 +1701,7 @@ def __init__(
 "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim."
 )
 if dropout_cattn > 1.0 or dropout_cattn < 0.0:
-raise ValueError(
-"Dropout cannot be negative or >1.0!"
-)
+raise ValueError("Dropout cannot be negative or >1.0!")

 # All number of channels should be multiple of num_groups
 if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
@@ -1793,7 +1791,7 @@ def __init__(
 cross_attention_dim=cross_attention_dim,
 upcast_attention=upcast_attention,
 use_flash_attention=use_flash_attention,
-dropout_cattn=dropout_cattn
+dropout_cattn=dropout_cattn,
 )

 self.down_blocks.append(down_block)
@@ -1811,7 +1809,7 @@ def __init__(
 cross_attention_dim=cross_attention_dim,
 upcast_attention=upcast_attention,
 use_flash_attention=use_flash_attention,
-dropout_cattn=dropout_cattn
+dropout_cattn=dropout_cattn,
 )

 # up
@@ -1846,7 +1844,7 @@ def __init__(
 cross_attention_dim=cross_attention_dim,
 upcast_attention=upcast_attention,
 use_flash_attention=use_flash_attention,
-dropout_cattn=dropout_cattn
+dropout_cattn=dropout_cattn,
 )

 self.up_blocks.append(up_block)
