
Commit 51d3d3d

Merge branch 'lynxnet2' into muon_lynxnet2
2 parents 19a527b + 13406a2

9 files changed: 151 additions & 58 deletions


configs/acoustic.yaml

Lines changed: 1 addition & 2 deletions
@@ -70,13 +70,12 @@ sampling_steps: 20
 diff_accelerator: ddim
 diff_speedup: 10
 hidden_size: 256
-backbone_type: 'lynxnet'
+backbone_type: 'lynxnet2'
 backbone_args:
   num_channels: 1024
   num_layers: 6
   kernel_size: 31
   dropout_rate: 0.0
-  strong_cond: true
 main_loss_type: l2
 main_loss_log_norm: false
 schedule_type: 'linear'

configs/templates/config_acoustic.yaml

Lines changed: 1 addition & 2 deletions
@@ -76,13 +76,12 @@ T_start: 0.4
 T_start_infer: 0.4
 K_step: 300
 K_step_infer: 300
-backbone_type: 'lynxnet'
+backbone_type: 'lynxnet2'
 backbone_args:
   num_channels: 1024
   num_layers: 6
   kernel_size: 31
   dropout_rate: 0.0
-  strong_cond: true
 #backbone_type: 'wavenet'
 #backbone_args:
 #  num_channels: 512

configs/templates/config_variance.yaml

Lines changed: 18 additions & 20 deletions
@@ -94,31 +94,29 @@ pitch_prediction_args:
   pitd_clip_min: -12.0
   pitd_clip_max: 12.0
   repeat_bins: 64
-  backbone_type: 'wavenet'
-  backbone_args:
-    num_layers: 20
-    num_channels: 256
-    dilation_cycle_length: 5
-  # backbone_type: 'lynxnet'
+  # backbone_type: 'wavenet'
   # backbone_args:
-  #   num_layers: 6
-  #   num_channels: 512
-  #   dropout_rate: 0.0
-  #   strong_cond: true
+  #   num_layers: 20
+  #   num_channels: 256
+  #   dilation_cycle_length: 5
+  backbone_type: 'lynxnet2'
+  backbone_args:
+    num_layers: 6
+    num_channels: 512
+    dropout_rate: 0.0
 
 variances_prediction_args:
   total_repeat_bins: 48
-  backbone_type: 'wavenet'
-  backbone_args:
-    num_layers: 10
-    num_channels: 192
-    dilation_cycle_length: 4
-  # backbone_type: 'lynxnet'
+  # backbone_type: 'wavenet'
   # backbone_args:
-  #   num_layers: 6
-  #   num_channels: 384
-  #   dropout_rate: 0.0
-  #   strong_cond: true
+  #   num_layers: 10
+  #   num_channels: 192
+  #   dilation_cycle_length: 4
+  backbone_type: 'lynxnet2'
+  backbone_args:
+    num_layers: 6
+    num_channels: 384
+    dropout_rate: 0.0
 
 lambda_dur_loss: 1.0
 lambda_pitch_loss: 1.0

configs/variance.yaml

Lines changed: 8 additions & 8 deletions
@@ -65,11 +65,11 @@ pitch_prediction_args:
   pitd_clip_min: -12.0
   pitd_clip_max: 12.0
   repeat_bins: 64
-  backbone_type: 'wavenet'
+  backbone_type: 'lynxnet2'
   backbone_args:
-    num_layers: 20
-    num_channels: 256
-    dilation_cycle_length: 5
+    num_layers: 6
+    num_channels: 512
+    dropout_rate: 0.0
 
 energy_db_min: -96.0
 energy_db_max: -12.0
@@ -88,11 +88,11 @@ tension_smooth_width: 0.12
 
 variances_prediction_args:
   total_repeat_bins: 48
-  backbone_type: 'wavenet'
+  backbone_type: 'lynxnet2'
   backbone_args:
-    num_layers: 10
-    num_channels: 192
-    dilation_cycle_length: 4
+    num_layers: 6
+    num_channels: 384
+    dropout_rate: 0.0
 
 lambda_dur_loss: 1.0
 lambda_pitch_loss: 1.0

modules/backbones/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -1,11 +1,13 @@
 import torch.nn
 from modules.backbones.wavenet import WaveNet
 from modules.backbones.lynxnet import LYNXNet
+from modules.backbones.lynxnet2 import LYNXNet2
 from utils import filter_kwargs
 
 BACKBONES = {
     'wavenet': WaveNet,
-    'lynxnet': LYNXNet
+    'lynxnet': LYNXNet,
+    'lynxnet2': LYNXNet2,
 }
 
 
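For context, here is a minimal sketch (not part of this commit) of how the BACKBONES registry is typically consumed: backbone_type and backbone_args come straight from the YAML configs above, and filter_kwargs is assumed to keep only the keys the selected class's __init__ accepts. The build_backbone helper name is hypothetical.

import torch.nn as nn

from modules.backbones import BACKBONES
from utils import filter_kwargs


def build_backbone(in_dims: int, n_feats: int, backbone_type: str, backbone_args: dict) -> nn.Module:
    # e.g. backbone_type == 'lynxnet2' resolves to the new LYNXNet2 class
    cls = BACKBONES[backbone_type]
    # assumed behaviour: drop any YAML keys that cls.__init__ does not declare
    kwargs = filter_kwargs(backbone_args, cls)
    return cls(in_dims, n_feats, **kwargs)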

modules/backbones/lynxnet.py

Lines changed: 2 additions & 18 deletions
@@ -6,26 +6,10 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU
+from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, Conv1d, Transpose
 from utils.hparams import hparams
 
 
-class Conv1d(torch.nn.Conv1d):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        nn.init.kaiming_normal_(self.weight)
-
-
-class Transpose(nn.Module):
-    def __init__(self, dims):
-        super().__init__()
-        assert len(dims) == 2, 'dims must be a tuple of two dimensions'
-        self.dims = dims
-
-    def forward(self, x):
-        return x.transpose(*self.dims)
-
-
 class LYNXConvModule(nn.Module):
     @staticmethod
     def calc_same_padding(kernel_size):
@@ -150,7 +134,7 @@ def forward(self, spec, diffusion_step, cond):
         # post-norm
         x = self.norm(x.transpose(1, 2)).transpose(1, 2)
 
-        # MLP and GLU
+        # output_projection
         x = self.output_projection(x)  # [B, 128, T]
 
         if self.n_feats == 1:

modules/backbones/lynxnet2.py

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, Conv1d, Transpose
+from utils.hparams import hparams
+
+
+class LYNXNet2Block(nn.Module):
+    def __init__(self, dim, expansion_factor, kernel_size=31, dropout=0.):
+        super().__init__()
+        inner_dim = int(dim * expansion_factor)
+        if float(dropout) > 0.:
+            _dropout = nn.Dropout(dropout)
+        else:
+            _dropout = nn.Identity()
+        self.net = nn.Sequential(
+            nn.LayerNorm(dim),
+            Transpose((1, 2)),
+            nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim),
+            Transpose((1, 2)),
+            nn.Linear(dim, inner_dim * 2),
+            SwiGLU(),
+            nn.Linear(inner_dim, inner_dim * 2),
+            SwiGLU(),
+            nn.Linear(inner_dim, dim),
+            _dropout
+        )
+
+    def forward(self, x):
+        return x + self.net(x)
+
+
+class LYNXNet2(nn.Module):
+    def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansion_factor=1, kernel_size=31,
+                 dropout=0.0):
+        """
+        LYNXNet2(Linear Gated Depthwise Separable Convolution Network Version 2)
+        """
+        super().__init__()
+        self.in_dims = in_dims
+        self.n_feats = n_feats
+        self.input_projection = nn.Linear(in_dims * n_feats, num_channels)
+        self.conditioner_projection = nn.Linear(hparams['hidden_size'], num_channels)
+        self.diffusion_embedding = nn.Sequential(
+            SinusoidalPosEmb(num_channels),
+            nn.Linear(num_channels, num_channels * 4),
+            nn.GELU(),
+            nn.Linear(num_channels * 4, num_channels),
+        )
+        self.residual_layers = nn.ModuleList(
+            [
+                LYNXNet2Block(
+                    dim=num_channels,
+                    expansion_factor=expansion_factor,
+                    kernel_size=kernel_size,
+                    dropout=dropout
+                )
+                for i in range(num_layers)
+            ]
+        )
+        self.norm = nn.LayerNorm(num_channels)
+        self.output_projection = nn.Linear(num_channels, in_dims * n_feats)
+        nn.init.kaiming_normal_(self.input_projection.weight)
+        nn.init.kaiming_normal_(self.conditioner_projection.weight)
+        nn.init.zeros_(self.output_projection.weight)
+
+    def forward(self, spec, diffusion_step, cond):
+        """
+        :param spec: [B, F, M, T]
+        :param diffusion_step: [B, 1]
+        :param cond: [B, H, T]
+        :return:
+        """
+
+        if self.n_feats == 1:
+            x = spec[:, 0]  # [B, M, T]
+        else:
+            x = spec.flatten(start_dim=1, end_dim=2)  # [B, F x M, T]
+
+        x = self.input_projection(x.transpose(1, 2))  # [B, T, F x M]
+        x = x + self.conditioner_projection(cond.transpose(1, 2))
+        x = x + self.diffusion_embedding(diffusion_step).unsqueeze(1)
+
+        for layer in self.residual_layers:
+            x = layer(x)
+
+        # post-norm
+        x = self.norm(x)
+
+        # output projection
+        x = self.output_projection(x).transpose(1, 2)  # [B, 128, T]
+
+        if self.n_feats == 1:
+            x = x[:, None, :, :]
+        else:
+            # This is the temporary solution since PyTorch 1.13
+            # does not support exporting aten::unflatten to ONNX
+            # x = x.unflatten(dim=1, sizes=(self.n_feats, self.in_dims))
+            x = x.reshape(-1, self.n_feats, self.in_dims, x.shape[2])
+        return x
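A hedged smoke-test sketch of the new backbone (not from the commit): tensor shapes follow the docstring of LYNXNet2.forward, hparams['hidden_size'] is assumed to be populated by the config loader (256 in configs/acoustic.yaml above), and a 1-D [B] step tensor is assumed for diffusion_step.

import torch

from utils.hparams import hparams
hparams['hidden_size'] = 256               # normally set by the config loader

from modules.backbones.lynxnet2 import LYNXNet2

B, M, T = 2, 128, 100                      # batch, mel bins, frames
net = LYNXNet2(in_dims=M, n_feats=1)       # constructor defaults: 6 layers, 512 channels
spec = torch.randn(B, 1, M, T)             # [B, F, M, T]
step = torch.randint(0, 1000, (B,))        # diffusion step indices, assumed 1-D here
cond = torch.randn(B, 256, T)              # [B, H, T], H == hidden_size

out = net(spec, step, cond)
print(out.shape)                           # expected: torch.Size([2, 1, 128, 100])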

modules/backbones/wavenet.py

Lines changed: 1 addition & 7 deletions
@@ -5,16 +5,10 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from modules.commons.common_layers import SinusoidalPosEmb
+from modules.commons.common_layers import SinusoidalPosEmb, Conv1d
 from utils.hparams import hparams
 
 
-class Conv1d(torch.nn.Conv1d):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        nn.init.kaiming_normal_(self.weight)
-
-
 class ResidualBlock(nn.Module):
     def __init__(self, encoder_hidden, residual_channels, dilation):
         super().__init__()

modules/commons/common_layers.py

Lines changed: 16 additions & 0 deletions
@@ -117,6 +117,22 @@ def forward(self, x):
         return out * F.silu(gate)
 
 
+class Conv1d(torch.nn.Conv1d):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        nn.init.kaiming_normal_(self.weight)
+
+
+class Transpose(nn.Module):
+    def __init__(self, dims):
+        super().__init__()
+        assert len(dims) == 2, 'dims must be a tuple of two dimensions'
+        self.dims = dims
+
+    def forward(self, x):
+        return x.transpose(*self.dims)
+
+
 class TransformerFFNLayer(nn.Module):
     def __init__(self, hidden_size, filter_size, kernel_size=1, dropout=0., act='gelu'):
         super().__init__()
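As an illustration only (not from the commit): Transpose is what lets channel-last modules such as LayerNorm and Linear wrap a channel-first depthwise nn.Conv1d inside a single nn.Sequential, which is the layout LYNXNet2Block uses above. The shapes below assume a [B, T, C] input.

import torch
import torch.nn as nn

from modules.commons.common_layers import Transpose

dim = 8
block = nn.Sequential(
    nn.LayerNorm(dim),                                          # operates on [B, T, C]
    Transpose((1, 2)),                                          # [B, T, C] -> [B, C, T] for the convolution
    nn.Conv1d(dim, dim, kernel_size=3, padding=1, groups=dim),  # depthwise conv over time
    Transpose((1, 2)),                                          # back to [B, T, C] for the linear layer
    nn.Linear(dim, dim),
)
print(block(torch.randn(2, 16, dim)).shape)                     # torch.Size([2, 16, 8])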
