@@ -2,14 +2,20 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, Transpose
+from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, ATanGLU, Transpose
 from utils.hparams import hparams
 
 
 class LYNXNet2Block(nn.Module):
-    def __init__(self, dim, expansion_factor, kernel_size=31, dropout=0.):
+    def __init__(self, dim, expansion_factor, kernel_size=31, dropout=0., glu_type='swiglu'):
         super().__init__()
         inner_dim = int(dim * expansion_factor)
+        if glu_type == 'swiglu':
+            _glu = SwiGLU()
+        elif glu_type == 'atanglu':
+            _glu = ATanGLU()
+        else:
+            raise ValueError(f'{glu_type} is not a valid activation')
         if float(dropout) > 0.:
             _dropout = nn.Dropout(dropout)
         else:
@@ -20,9 +26,9 @@ def __init__(self, dim, expansion_factor, kernel_size=31, dropout=0.):
             nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim),
             Transpose((1, 2)),
             nn.Linear(dim, inner_dim * 2),
-            SwiGLU(),
+            _glu,
             nn.Linear(inner_dim, inner_dim * 2),
-            SwiGLU(),
+            _glu,
             nn.Linear(inner_dim, dim),
             _dropout
        )
@@ -33,7 +39,7 @@ def forward(self, x):
 
 class LYNXNet2(nn.Module):
     def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansion_factor=1, kernel_size=31,
-                 dropout=0.0, use_conditioner_cache=False):
+                 dropout=0.0, use_conditioner_cache=False, glu_type='swiglu'):
         """
         LYNXNet2(Linear Gated Depthwise Separable Convolution Network Version 2)
         """
@@ -59,7 +65,8 @@ def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansion_factor=1, kernel_size=31,
                     dim=num_channels,
                     expansion_factor=expansion_factor,
                     kernel_size=kernel_size,
-                    dropout=dropout
+                    dropout=dropout,
+                    glu_type=glu_type
                 )
                 for i in range(num_layers)
            ]
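
For reference, a minimal sketch of the gate this diff adds as an option. It assumes ATanGLU mirrors the chunk-and-gate pattern of SwiGLU (split the last dimension into a value half and a gate half, scale the value by the activated gate), with arctan as the gate activation; the actual implementation lives in modules.commons.common_layers and may differ in detail.

import torch
import torch.nn as nn


class ATanGLU(nn.Module):
    """Hypothetical sketch of an arctan-gated linear unit.

    Assumed to expose the same interface as SwiGLU: the preceding
    nn.Linear produces 2 * inner_dim features, which are split into a
    value half and a gate half along the last dimension.
    """

    def forward(self, x):
        out, gate = x.chunk(2, dim=-1)
        # Scale the value half by the arctan-activated gate half.
        return out * torch.atan(gate)

The new glu_type argument threads from the LYNXNet2 constructor down to every LYNXNet2Block, so either gate can be selected at construction time, e.g. (the in_dims/n_feats values here are placeholders):

backbone = LYNXNet2(128, 1, num_layers=6, num_channels=512, glu_type='atanglu')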