@@ -911,6 +911,7 @@ class CrossAttnDownBlock(nn.Module):
         cross_attention_dim: number of context dimensions to use.
         upcast_attention: if True, upcast attention operations to full precision.
         use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
+        dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers.
     """

     def __init__(
@@ -930,6 +931,7 @@ def __init__(
         cross_attention_dim: int | None = None,
         upcast_attention: bool = False,
         use_flash_attention: bool = False,
+        dropout_cattn: float = 0.0
     ) -> None:
         super().__init__()
         self.resblock_updown = resblock_updown
@@ -962,6 +964,7 @@ def __init__(
                     cross_attention_dim=cross_attention_dim,
                     upcast_attention=upcast_attention,
                     use_flash_attention=use_flash_attention,
+                    dropout=dropout_cattn
                 )
             )

@@ -1100,6 +1103,7 @@ def __init__(
         cross_attention_dim: int | None = None,
         upcast_attention: bool = False,
         use_flash_attention: bool = False,
+        dropout_cattn: float = 0.0
     ) -> None:
         super().__init__()
         self.attention = None
@@ -1123,6 +1127,7 @@ def __init__(
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
             use_flash_attention=use_flash_attention,
+            dropout=dropout_cattn
         )
         self.resnet_2 = ResnetBlock(
             spatial_dims=spatial_dims,
@@ -1266,7 +1271,7 @@ def __init__(
         add_upsample: bool = True,
         resblock_updown: bool = False,
         num_head_channels: int = 1,
-        use_flash_attention: bool = False,
+        use_flash_attention: bool = False
     ) -> None:
         super().__init__()
         self.resblock_updown = resblock_updown
@@ -1363,6 +1368,7 @@ class CrossAttnUpBlock(nn.Module):
         cross_attention_dim: number of context dimensions to use.
         upcast_attention: if True, upcast attention operations to full precision.
         use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
+        dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers.
     """

     def __init__(
@@ -1382,6 +1388,7 @@ def __init__(
         cross_attention_dim: int | None = None,
         upcast_attention: bool = False,
         use_flash_attention: bool = False,
+        dropout_cattn: float = 0.0
     ) -> None:
         super().__init__()
         self.resblock_updown = resblock_updown
@@ -1415,6 +1422,7 @@ def __init__(
                     cross_attention_dim=cross_attention_dim,
                     upcast_attention=upcast_attention,
                     use_flash_attention=use_flash_attention,
+                    dropout=dropout_cattn
                 )
             )

@@ -1478,6 +1486,7 @@ def get_down_block(
     cross_attention_dim: int | None,
     upcast_attention: bool = False,
     use_flash_attention: bool = False,
+    dropout_cattn: float = 0.0
 ) -> nn.Module:
     if with_attn:
         return AttnDownBlock(
@@ -1509,6 +1518,7 @@ def get_down_block(
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
             use_flash_attention=use_flash_attention,
+            dropout_cattn=dropout_cattn
         )
     else:
         return DownBlock(
@@ -1536,6 +1546,7 @@ def get_mid_block(
     cross_attention_dim: int | None,
     upcast_attention: bool = False,
     use_flash_attention: bool = False,
+    dropout_cattn: float = 0.0
 ) -> nn.Module:
     if with_conditioning:
         return CrossAttnMidBlock(
@@ -1549,6 +1560,7 @@ def get_mid_block(
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
             use_flash_attention=use_flash_attention,
+            dropout_cattn=dropout_cattn
         )
     else:
         return AttnMidBlock(
@@ -1580,6 +1592,7 @@ def get_up_block(
     cross_attention_dim: int | None,
     upcast_attention: bool = False,
     use_flash_attention: bool = False,
+    dropout_cattn: float = 0.0
 ) -> nn.Module:
     if with_attn:
         return AttnUpBlock(
@@ -1613,6 +1626,7 @@ def get_up_block(
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
             use_flash_attention=use_flash_attention,
+            dropout_cattn=dropout_cattn
         )
     else:
         return UpBlock(
@@ -1653,6 +1667,7 @@ class DiffusionModelUNet(nn.Module):
             classes.
         upcast_attention: if True, upcast attention operations to full precision.
         use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
+        dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers.
     """

     def __init__(
@@ -1673,6 +1688,7 @@ def __init__(
         num_class_embeds: int | None = None,
         upcast_attention: bool = False,
         use_flash_attention: bool = False,
+        dropout_cattn: float = 0.0
     ) -> None:
         super().__init__()
         if with_conditioning is True and cross_attention_dim is None:
@@ -1684,6 +1700,10 @@ def __init__(
             raise ValueError(
                 "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim."
             )
+        if dropout_cattn > 1.0 or dropout_cattn < 0.0:
+            raise ValueError(
+                "Dropout cannot be negative or >1.0!"
+            )

         # All number of channels should be multiple of num_groups
         if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
@@ -1773,6 +1793,7 @@ def __init__(
                 cross_attention_dim=cross_attention_dim,
                 upcast_attention=upcast_attention,
                 use_flash_attention=use_flash_attention,
+                dropout_cattn=dropout_cattn
             )

             self.down_blocks.append(down_block)
@@ -1790,6 +1811,7 @@ def __init__(
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
             use_flash_attention=use_flash_attention,
+            dropout_cattn=dropout_cattn
         )

         # up
@@ -1824,6 +1846,7 @@ def __init__(
                 cross_attention_dim=cross_attention_dim,
                 upcast_attention=upcast_attention,
                 use_flash_attention=use_flash_attention,
+                dropout_cattn=dropout_cattn
             )

             self.up_blocks.append(up_block)
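
The new `dropout_cattn` argument is threaded from `DiffusionModelUNet` through `get_down_block`/`get_mid_block`/`get_up_block` into the cross-attention blocks, where it is passed to `SpatialTransformer` as its `dropout` value. Below is a minimal usage sketch assuming the MONAI Generative Models package layout (`generative.networks.nets`); the concrete hyper-parameter values are illustrative assumptions and are not part of this change.

```python
import torch

from generative.networks.nets import DiffusionModelUNet

# Conditioned UNet with dropout enabled in the cross-attention layers.
# All sizes below are placeholder choices for a small 2D example.
model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_res_blocks=1,
    num_channels=(32, 64, 64),
    attention_levels=(False, True, True),
    num_head_channels=64,
    with_conditioning=True,   # required whenever cross_attention_dim is set
    cross_attention_dim=64,
    dropout_cattn=0.1,        # new argument; values outside [0.0, 1.0] raise ValueError
)

x = torch.randn(2, 1, 64, 64)                    # noisy image batch
timesteps = torch.randint(0, 1000, (2,)).long()  # one diffusion timestep per sample
context = torch.randn(2, 4, 64)                  # (batch, sequence, cross_attention_dim)
noise_pred = model(x, timesteps=timesteps, context=context)
```

Because the value only reaches the `SpatialTransformer` layers, the self-attention and resnet paths are untouched, and the default `dropout_cattn=0.0` reproduces the previous behaviour.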