Commit f591534

support atanglu
1 parent 0db91f2 commit f591534

6 files changed: 35 additions & 6 deletions
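
This commit adds an ATanGLU gated activation to modules/commons/common_layers.py and exposes a glu_type option on the LYNXNet2 backbone and on the pitch/variance predictor configs. The option defaults to 'swiglu', matching the previous hard-coded behaviour, so configs that omit it are unaffected; setting glu_type: 'atanglu' switches the gated layers to the new activation.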

configs/acoustic.yaml

Lines changed: 1 addition & 0 deletions
@@ -78,6 +78,7 @@ backbone_args:
   kernel_size: 31
   dropout_rate: 0.0
   use_conditioner_cache: true
+  glu_type: 'atanglu'
 main_loss_type: l2
 main_loss_log_norm: false
 schedule_type: 'linear'

configs/templates/config_acoustic.yaml

Lines changed: 1 addition & 0 deletions
@@ -84,6 +84,7 @@ backbone_args:
   kernel_size: 31
   dropout_rate: 0.0
   use_conditioner_cache: true
+  glu_type: 'atanglu'
 #backbone_type: 'wavenet'
 #backbone_args:
 #  num_channels: 512

configs/templates/config_variance.yaml

Lines changed: 2 additions & 0 deletions
@@ -106,6 +106,7 @@ pitch_prediction_args:
   num_channels: 512
   dropout_rate: 0.0
   use_conditioner_cache: true
+  glu_type: 'atanglu'
 
 variances_prediction_args:
   total_repeat_bins: 48
@@ -120,6+121,7 @@ variances_prediction_args:
   num_channels: 384
   dropout_rate: 0.0
   use_conditioner_cache: true
+  glu_type: 'atanglu'
 
 lambda_dur_loss: 1.0
 lambda_pitch_loss: 1.0

configs/variance.yaml

Lines changed: 2 additions & 0 deletions
@@ -72,6 +72,7 @@ pitch_prediction_args:
   num_channels: 512
   dropout_rate: 0.0
   use_conditioner_cache: true
+  glu_type: 'atanglu'
 
 energy_db_min: -96.0
 energy_db_max: -12.0
@@ -96,6 +97,7 @@ variances_prediction_args:
   num_channels: 384
   dropout_rate: 0.0
   use_conditioner_cache: true
+  glu_type: 'atanglu'
 
 lambda_dur_loss: 1.0
 lambda_pitch_loss: 1.0

modules/backbones/lynxnet2.py

Lines changed: 13 additions & 6 deletions
@@ -2,14 +2,20 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, Transpose
+from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, ATanGLU, Transpose
 from utils.hparams import hparams
 
 
 class LYNXNet2Block(nn.Module):
-    def __init__(self, dim, expansion_factor, kernel_size=31, dropout=0.):
+    def __init__(self, dim, expansion_factor, kernel_size=31, dropout=0., glu_type='swiglu'):
         super().__init__()
         inner_dim = int(dim * expansion_factor)
+        if glu_type == 'swiglu':
+            _glu = SwiGLU()
+        elif glu_type == 'atanglu':
+            _glu = ATanGLU()
+        else:
+            raise ValueError(f'{glu_type} is not a valid activation')
         if float(dropout) > 0.:
             _dropout = nn.Dropout(dropout)
         else:
@@ -20,9 +26,9 @@ def __init__(self, dim, expansion_factor, kernel_size=31, dropout=0.):
             nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim),
             Transpose((1, 2)),
             nn.Linear(dim, inner_dim * 2),
-            SwiGLU(),
+            _glu,
             nn.Linear(inner_dim, inner_dim * 2),
-            SwiGLU(),
+            _glu,
             nn.Linear(inner_dim, dim),
             _dropout
         )
@@ -33,7 +39,7 @@ def forward(self, x):
 
 class LYNXNet2(nn.Module):
     def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansion_factor=1, kernel_size=31,
-                 dropout=0.0, use_conditioner_cache=False):
+                 dropout=0.0, use_conditioner_cache=False, glu_type='swiglu'):
         """
         LYNXNet2(Linear Gated Depthwise Separable Convolution Network Version 2)
         """
@@ -59,7 +65,8 @@ def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansio
                 dim=num_channels,
                 expansion_factor=expansion_factor,
                 kernel_size=kernel_size,
-                dropout=dropout
+                dropout=dropout,
+                glu_type=glu_type
             )
             for i in range(num_layers)
         ]
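
Both options split the pre-activation in half along the channel dimension (nn.Linear(dim, inner_dim * 2) feeds the GLU, which outputs inner_dim) and multiply one half by a nonlinear function of the other; only the gating nonlinearity differs. A minimal sketch comparing the two gate functions, assuming the repository's SwiGLU uses the usual silu (swish) gate:

import torch
import torch.nn.functional as F

gate = torch.linspace(-6.0, 6.0, steps=7)

# SwiGLU-style gate: silu(g) = g * sigmoid(g), unbounded for large positive inputs.
print(F.silu(gate))

# ATanGLU gate: atan(g), odd-symmetric and bounded in (-pi/2, pi/2).
print(torch.atan(gate))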

modules/commons/common_layers.py

Lines changed: 16 additions & 0 deletions
@@ -128,6 +128,19 @@ def forward(self, x):
         return out * gate
 
 
+class ATanGLU(nn.Module):
+    # ArcTan gated linear unit: gates half of the input channels with atan of the other half.
+    def __init__(self, dim=-1):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        # out, gate = x.chunk(2, dim=self.dim)
+        # Using torch.split instead of chunk for ONNX export compatibility.
+        out, gate = torch.split(x, x.size(self.dim) // 2, dim=self.dim)
+        return out * torch.atan(gate)
+
+
 class KaimingNormalConv1d(torch.nn.Conv1d):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -160,6 +173,9 @@ def __init__(self, hidden_size, filter_size, kernel_size=1, dropout=0., act='gel
         elif self.act == 'swiglu':
             self.act_fn = SwiGLU()
             filter_size_1 = filter_size * 2
+        elif self.act == 'atanglu':
+            self.act_fn = ATanGLU()
+            filter_size_1 = filter_size * 2
         else:
             raise ValueError(f'{act} is not a valid activation')
         self.ffn_1 = nn.Conv1d(hidden_size, filter_size_1, kernel_size, padding=kernel_size // 2)
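
As a sanity check, the new module can be exercised on its own; a self-contained sketch of the class added above (the tensor shape is illustrative, not taken from the commit):

import torch
import torch.nn as nn


class ATanGLU(nn.Module):
    # ArcTan gated linear unit: split the input in half along `dim`
    # and gate one half with atan of the other.
    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        # torch.split instead of chunk for ONNX export compatibility
        out, gate = torch.split(x, x.size(self.dim) // 2, dim=self.dim)
        return out * torch.atan(gate)


x = torch.randn(2, 100, 512)   # (batch, frames, 2 * inner_dim); shape chosen for illustration
y = ATanGLU()(x)
print(y.shape)                 # torch.Size([2, 100, 256]) -- the gated dimension is halved
out, gate = x.chunk(2, dim=-1)
assert torch.allclose(y, out * torch.atan(gate))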
