This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 3da2673

virginiafdez and marksgraham authored
Add dropout for conditioning cross-attention blocks. (#407)
* Add dropout for conditioning cross-attention blocks.
* Removed test_dropout. Included tests for this add-on (dropout possibility in cross-attention blocks) in the main test_diffusion_model_unet.
* Update tests/test_diffusion_model_unet.py
  Signed-off-by: Mark Graham <markgraham539@gmail.com>
* Update tests/test_diffusion_model_unet.py
  Signed-off-by: Mark Graham <markgraham539@gmail.com>

---------

Signed-off-by: Mark Graham <markgraham539@gmail.com>
Co-authored-by: virginiafdez <virginia.fernandez@kcl.ac.uk>
Co-authored-by: Mark Graham <markgraham539@gmail.com>
1 parent a503d77 commit 3da2673

2 files changed

Lines changed: 88 additions & 1 deletion
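For orientation, the sketch below shows how the new option is meant to be used. It is an illustration only: the import path `generative.networks.nets.DiffusionModelUNet`, the keyword-style forward call, and the 2D input/context shapes are assumptions based on the tests added in this commit rather than part of the diff itself. `dropout_cattn` is forwarded as the `dropout` of every cross-attention transformer block when conditioning is enabled.

```python
import torch
from generative.networks.nets import DiffusionModelUNet  # assumed import path

# Constructor arguments mirror the DROPOUT_OK test case added in this commit.
net = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_res_blocks=1,
    num_channels=(8, 8, 8),
    attention_levels=(False, False, True),
    num_head_channels=4,
    norm_num_groups=8,
    with_conditioning=True,
    transformer_num_layers=1,
    cross_attention_dim=3,
    dropout_cattn=0.25,  # new: dropout applied inside the cross-attention layers
)

# Conditioned forward pass (shapes are assumptions; the context tensor carries
# cross_attention_dim features per token).
x = torch.rand(1, 1, 16, 16)
timesteps = torch.randint(0, 1000, (1,)).long()
context = torch.rand(1, 1, 3)
out = net(x, timesteps=timesteps, context=context)
print(out.shape)  # expected to match the input shape: (1, 1, 16, 16)
```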


generative/networks/nets/diffusion_model_unet.py

Lines changed: 24 additions & 1 deletion
@@ -911,6 +911,7 @@ class CrossAttnDownBlock(nn.Module):
         cross_attention_dim: number of context dimensions to use.
         upcast_attention: if True, upcast attention operations to full precision.
         use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
+        dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers
     """
 
     def __init__(
@@ -930,6 +931,7 @@ def __init__(
         cross_attention_dim: int | None = None,
         upcast_attention: bool = False,
         use_flash_attention: bool = False,
+        dropout_cattn: float = 0.0
     ) -> None:
         super().__init__()
         self.resblock_updown = resblock_updown
@@ -962,6 +964,7 @@ def __init__(
                     cross_attention_dim=cross_attention_dim,
                     upcast_attention=upcast_attention,
                     use_flash_attention=use_flash_attention,
+                    dropout=dropout_cattn
                 )
             )
 
@@ -1100,6 +1103,7 @@ def __init__(
         cross_attention_dim: int | None = None,
         upcast_attention: bool = False,
         use_flash_attention: bool = False,
+        dropout_cattn: float = 0.0
     ) -> None:
         super().__init__()
         self.attention = None
@@ -1123,6 +1127,7 @@ def __init__(
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
             use_flash_attention=use_flash_attention,
+            dropout=dropout_cattn
         )
         self.resnet_2 = ResnetBlock(
             spatial_dims=spatial_dims,
@@ -1266,7 +1271,7 @@ def __init__(
         add_upsample: bool = True,
         resblock_updown: bool = False,
         num_head_channels: int = 1,
-        use_flash_attention: bool = False,
+        use_flash_attention: bool = False
     ) -> None:
         super().__init__()
         self.resblock_updown = resblock_updown
@@ -1363,6 +1368,7 @@ class CrossAttnUpBlock(nn.Module):
         cross_attention_dim: number of context dimensions to use.
         upcast_attention: if True, upcast attention operations to full precision.
         use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
+        dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers
     """
 
     def __init__(
@@ -1382,6 +1388,7 @@ def __init__(
         cross_attention_dim: int | None = None,
         upcast_attention: bool = False,
         use_flash_attention: bool = False,
+        dropout_cattn: float = 0.0
     ) -> None:
         super().__init__()
         self.resblock_updown = resblock_updown
@@ -1415,6 +1422,7 @@ def __init__(
                     cross_attention_dim=cross_attention_dim,
                     upcast_attention=upcast_attention,
                     use_flash_attention=use_flash_attention,
+                    dropout=dropout_cattn
                 )
             )
 
@@ -1478,6 +1486,7 @@ def get_down_block(
     cross_attention_dim: int | None,
     upcast_attention: bool = False,
     use_flash_attention: bool = False,
+    dropout_cattn: float = 0.0
 ) -> nn.Module:
     if with_attn:
         return AttnDownBlock(
@@ -1509,6 +1518,7 @@ def get_down_block(
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
             use_flash_attention=use_flash_attention,
+            dropout_cattn=dropout_cattn
         )
     else:
         return DownBlock(
@@ -1536,6 +1546,7 @@ def get_mid_block(
     cross_attention_dim: int | None,
     upcast_attention: bool = False,
     use_flash_attention: bool = False,
+    dropout_cattn: float = 0.0
 ) -> nn.Module:
     if with_conditioning:
         return CrossAttnMidBlock(
@@ -1549,6 +1560,7 @@ def get_mid_block(
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
             use_flash_attention=use_flash_attention,
+            dropout_cattn=dropout_cattn
         )
     else:
         return AttnMidBlock(
@@ -1580,6 +1592,7 @@ def get_up_block(
     cross_attention_dim: int | None,
     upcast_attention: bool = False,
     use_flash_attention: bool = False,
+    dropout_cattn: float = 0.0
 ) -> nn.Module:
     if with_attn:
         return AttnUpBlock(
@@ -1613,6 +1626,7 @@ def get_up_block(
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
             use_flash_attention=use_flash_attention,
+            dropout_cattn=dropout_cattn
         )
     else:
         return UpBlock(
@@ -1653,6 +1667,7 @@ class DiffusionModelUNet(nn.Module):
             classes.
         upcast_attention: if True, upcast attention operations to full precision.
         use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
+        dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers
     """
 
     def __init__(
@@ -1673,6 +1688,7 @@ def __init__(
         num_class_embeds: int | None = None,
         upcast_attention: bool = False,
         use_flash_attention: bool = False,
+        dropout_cattn: float = 0.0
     ) -> None:
         super().__init__()
         if with_conditioning is True and cross_attention_dim is None:
@@ -1684,6 +1700,10 @@ def __init__(
             raise ValueError(
                 "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim."
             )
+        if dropout_cattn > 1.0 or dropout_cattn < 0.0:
+            raise ValueError(
+                "Dropout cannot be negative or >1.0!"
+            )
 
         # All number of channels should be multiple of num_groups
         if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
@@ -1773,6 +1793,7 @@ def __init__(
                 cross_attention_dim=cross_attention_dim,
                 upcast_attention=upcast_attention,
                 use_flash_attention=use_flash_attention,
+                dropout_cattn=dropout_cattn
             )
 
             self.down_blocks.append(down_block)
@@ -1790,6 +1811,7 @@ def __init__(
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
             use_flash_attention=use_flash_attention,
+            dropout_cattn=dropout_cattn
         )
 
         # up
@@ -1824,6 +1846,7 @@ def __init__(
                 cross_attention_dim=cross_attention_dim,
                 upcast_attention=upcast_attention,
                 use_flash_attention=use_flash_attention,
+                dropout_cattn=dropout_cattn
             )
 
             self.up_blocks.append(up_block)
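The range check added above means an out-of-range value now fails at construction time rather than deep inside a dropout layer. A minimal sketch of that behaviour follows (same assumed import path as the earlier example; the arguments mirror the DROPOUT_WRONG test case added below):

```python
from generative.networks.nets import DiffusionModelUNet  # assumed import path

# dropout_cattn must lie in [0.0, 1.0]; anything else raises ValueError in __init__.
try:
    DiffusionModelUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        num_res_blocks=1,
        num_channels=(8, 8, 8),
        attention_levels=(False, False, True),
        num_head_channels=4,
        norm_num_groups=8,
        with_conditioning=True,
        transformer_num_layers=1,
        cross_attention_dim=3,
        dropout_cattn=3.0,  # out of range, as in the DROPOUT_WRONG test case
    )
except ValueError as err:
    print(err)  # "Dropout cannot be negative or >1.0!"
```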

tests/test_diffusion_model_unet.py

Lines changed: 64 additions & 0 deletions
@@ -231,6 +231,59 @@
     ],
 ]
 
+DROPOUT_OK = [
+    [
+        {
+            "spatial_dims": 2,
+            "in_channels": 1,
+            "out_channels": 1,
+            "num_res_blocks": 1,
+            "num_channels": (8, 8, 8),
+            "attention_levels": (False, False, True),
+            "num_head_channels": 4,
+            "norm_num_groups": 8,
+            "with_conditioning": True,
+            "transformer_num_layers": 1,
+            "cross_attention_dim": 3,
+            "dropout_cattn": 0.25
+        }
+    ],
+    [
+        {
+            "spatial_dims": 2,
+            "in_channels": 1,
+            "out_channels": 1,
+            "num_res_blocks": 1,
+            "num_channels": (8, 8, 8),
+            "attention_levels": (False, False, True),
+            "num_head_channels": 4,
+            "norm_num_groups": 8,
+            "with_conditioning": True,
+            "transformer_num_layers": 1,
+            "cross_attention_dim": 3
+        }
+    ],
+]
+
+DROPOUT_WRONG = [
+    [
+        {
+            "spatial_dims": 2,
+            "in_channels": 1,
+            "out_channels": 1,
+            "num_res_blocks": 1,
+            "num_channels": (8, 8, 8),
+            "attention_levels": (False, False, True),
+            "num_head_channels": 4,
+            "norm_num_groups": 8,
+            "with_conditioning": True,
+            "transformer_num_layers": 1,
+            "cross_attention_dim": 3,
+            "dropout_cattn": 3.0
+        }
+    ],
+]
+
 
 class TestDiffusionModelUNet2D(unittest.TestCase):
     @parameterized.expand(UNCOND_CASES_2D)
@@ -524,6 +577,17 @@ def test_script_conditioned_3d_models(self):
             net, torch.rand((1, 1, 16, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3))
         )
 
+    # Test dropout specification for cross-attention blocks
+    @parameterized.expand(DROPOUT_WRONG)
+    def test_wrong_dropout(self, input_param):
+        with self.assertRaises(ValueError):
+            _ = DiffusionModelUNet(**input_param)
+
+    @parameterized.expand(DROPOUT_OK)
+    def test_right_dropout(self, input_param):
+        _ = DiffusionModelUNet(**input_param)
+
+
 
 if __name__ == "__main__":
     unittest.main()
