Skip to content

Commit 47ad6f6

Browse files
yxlllc and KakaruHayate
authored
Stretch embed (#274)
* Acoustic SR_embed (#270)
  * Acoustic SR_embed / Cosine annealing
  * del 'WarmupCosineSchedule' in config
* Fix the precision problem of 'StretchRegulator' in ONNX model
* fix some odds and ends...
* set 'use_stretch_embed' true on default
* fix some odds and ends...
* adjust
* add stretch embed to variance models
* fix
* fix
* fix
* optimize
* using lookup table for optimization
* update

---------

Co-authored-by: Kakaru <97896816+KakaruHayate@users.noreply.github.com>
1 parent 6df0ee9 commit 47ad6f6

11 files changed

Lines changed: 102 additions & 33 deletions

File tree

configs/acoustic.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ timesteps: 1000
6464
max_beta: 0.02
6565
enc_ffn_kernel_size: 3
6666
use_rope: true
67+
use_stretch_embed: true
6768
use_variance_scaling: true
6869
rel_pos: true
6970
sampling_algorithm: euler

configs/templates/config_acoustic.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ augmentation_args:
7171
diffusion_type: reflow
7272
enc_ffn_kernel_size: 3
7373
use_rope: true
74+
use_stretch_embed: true
7475
use_variance_scaling: true
7576
use_shallow_diffusion: true
7677
T_start: 0.4

configs/templates/config_variance.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ tension_logit_max: 10.0
6565

6666
enc_ffn_kernel_size: 3
6767
use_rope: true
68+
use_stretch_embed: false
6869
use_variance_scaling: true
6970
hidden_size: 384
7071
dur_prediction_args:

configs/variance.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ predict_tension: false
3636

3737
enc_ffn_kernel_size: 3
3838
use_rope: true
39+
use_stretch_embed: false
3940
use_variance_scaling: true
4041
rel_pos: true
4142
hidden_size: 384

deployment/modules/fastspeech2.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def forward(
7373
txt_embed = self.txt_embed(tokens)
7474
durations = durations * (tokens > 0)
7575
mel2ph = self.lr(durations)
76+
_mel2ph = mel2ph
7677
f0 = f0 * (mel2ph > 0)
7778
mel2ph = mel2ph[..., None].repeat((1, 1, hparams['hidden_size']))
7879
if self.use_variance_scaling:
@@ -92,6 +93,14 @@ def forward(
9293
encoded = F.pad(encoded, (0, 0, 1, 0))
9394
condition = torch.gather(encoded, 1, mel2ph)
9495

96+
if self.use_stretch_embed:
97+
stretch = torch.round(1000 * self.sr(_mel2ph, durations))
98+
table = self.stretch_embed(torch.arange(0, 1001, device=stretch.device))
99+
stretch_embed = torch.index_select(table, 0, stretch.view(-1).long()).view_as(condition)
100+
condition += stretch_embed
101+
stretch_embed_rnn_out, _ = self.stretch_embed_rnn(condition)
102+
condition += stretch_embed_rnn_out
103+
95104
if self.f0_embed_type == 'discrete':
96105
pitch = f0_to_coarse(f0)
97106
pitch_embed = self.pitch_embed(pitch)
@@ -102,30 +111,27 @@ def forward(
102111

103112
if self.use_variance_embeds:
104113
variance_embeds = torch.stack([
105-
self.variance_embeds[v_name](variances[v_name][:, :, None])
106-
* self.variance_scaling_factor[v_name]
114+
self.variance_embeds[v_name](variances[v_name][:, :, None] * self.variance_scaling_factor[v_name])
107115
for v_name in self.variance_embed_list
108116
], dim=-1).sum(-1)
109117
condition += variance_embeds
110118

111119
if hparams['use_key_shift_embed']:
112120
if hasattr(self, 'frozen_key_shift'):
113-
key_shift_embed = self.key_shift_embed(self.frozen_key_shift[:, None, None])
121+
key_shift_embed = self.key_shift_embed(self.frozen_key_shift[:, None, None] * self.variance_scaling_factor['key_shift'])
114122
else:
115123
gender = torch.clip(gender, min=-1., max=1.)
116124
gender_mask = (gender < 0.).float()
117125
key_shift = gender * ((1. - gender_mask) * self.shift_max + gender_mask * abs(self.shift_min))
118-
key_shift_embed = self.key_shift_embed(key_shift[:, :, None])
119-
key_shift_embed *= self.variance_scaling_factor['key_shift']
126+
key_shift_embed = self.key_shift_embed(key_shift[:, :, None] * self.variance_scaling_factor['key_shift'])
120127
condition += key_shift_embed
121128

122129
if hparams['use_speed_embed']:
123130
if velocity is not None:
124131
velocity = torch.clip(velocity, min=self.speed_min, max=self.speed_max)
125-
speed_embed = self.speed_embed(velocity[:, :, None])
132+
speed_embed = self.speed_embed(velocity[:, :, None] * self.variance_scaling_factor['speed'])
126133
else:
127-
speed_embed = self.speed_embed(torch.FloatTensor([1.]).to(condition.device)[:, None, None])
128-
speed_embed *= self.variance_scaling_factor['speed']
134+
speed_embed = self.speed_embed(torch.FloatTensor([1.]).to(condition.device)[:, None, None] * self.variance_scaling_factor['speed'])
129135
condition += speed_embed
130136

131137
if hparams['use_spk_id']:

deployment/modules/toplevel.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,22 +211,30 @@ def forward_linguistic_encoder_phoneme(self, tokens, ph_dur, languages=None):
211211
def forward_dur_predictor(self, encoder_out, x_masks, ph_midi, spk_embed=None):
212212
return self.fs2.forward_dur_predictor(encoder_out, x_masks, ph_midi, spk_embed=spk_embed)
213213

214-
def forward_mel2x_gather(self, x_src, x_dur, x_dim=None):
214+
def forward_mel2x_gather(self, x_src, x_dur, x_dim=None, check_stretch_embed=False):
215215
mel2x = self.lr(x_dur)
216+
_mel2x = mel2x
216217
if x_dim is not None:
217218
x_src = F.pad(x_src, [0, 0, 1, 0])
218219
mel2x = mel2x[..., None].repeat([1, 1, x_dim])
219220
else:
220221
x_src = F.pad(x_src, [1, 0])
221222
x_cond = torch.gather(x_src, 1, mel2x)
223+
if self.use_stretch_embed and check_stretch_embed:
224+
stretch = torch.round(1000 * self.sr(_mel2x, x_dur))
225+
table = self.stretch_embed(torch.arange(0, 1001, device=stretch.device))
226+
stretch_embed = torch.index_select(table, 0, stretch.view(-1).long()).view_as(x_cond)
227+
x_cond += stretch_embed
228+
stretch_embed_rnn_out, _ = self.stretch_embed_rnn(x_cond)
229+
x_cond += stretch_embed_rnn_out
222230
return x_cond
223231

224232
def forward_pitch_preprocess(
225233
self, encoder_out, ph_dur,
226234
note_midi=None, note_rest=None, note_dur=None, note_glide=None,
227235
pitch=None, expr=None, retake=None, spk_embed=None
228236
):
229-
condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size)
237+
condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size, check_stretch_embed=True)
230238
if self.use_melody_encoder:
231239
if self.melody_encoder.use_glide_embed and note_glide is None:
232240
note_glide = torch.LongTensor([[0]]).to(encoder_out.device)
@@ -280,7 +288,7 @@ def forward_variance_preprocess(
280288
self, encoder_out, ph_dur, pitch,
281289
variances: dict = None, retake=None, spk_embed=None
282290
):
283-
condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size)
291+
condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size, check_stretch_embed=True)
284292
if self.use_variance_scaling:
285293
variance_cond = condition + self.pitch_embed(pitch[:, :, None] / 12)
286294
else:
@@ -290,7 +298,7 @@ def forward_variance_preprocess(
290298
for v_retake in (~retake).split(1, dim=2)
291299
]
292300
variance_embeds = [
293-
self.variance_embeds[v_name](variances[v_name][:, :, None]) * v_masks * self.variance_retake_scaling[v_name]
301+
self.variance_embeds[v_name](variances[v_name][:, :, None] * self.variance_retake_scaling[v_name]) * v_masks
294302
for v_name, v_masks in zip(self.variance_prediction_list, non_retake_masks)
295303
]
296304
variance_cond += torch.stack(variance_embeds, dim=-1).sum(-1)

modules/commons/common_layers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,6 @@ def forward(self, x):
325325
half_dim = self.dim // 2
326326
emb = math.log(10000) / (half_dim - 1)
327327
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
328-
emb = x[:, None] * emb[None, :]
328+
emb = x.unsqueeze(-1) * emb.unsqueeze(0)
329329
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
330330
return emb

modules/fastspeech/acoustic_encoder.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
from modules.commons.common_layers import (
66
NormalInitEmbedding as Embedding,
77
XavierUniformInitLinear as Linear,
8+
SinusoidalPosEmb,
89
)
9-
from modules.fastspeech.tts_modules import FastSpeech2Encoder, mel2ph_to_dur
10+
from modules.fastspeech.tts_modules import FastSpeech2Encoder, mel2ph_to_dur, StretchRegulator
1011
from utils.hparams import hparams
1112
from utils.phoneme_utils import PAD_INDEX
1213

@@ -18,6 +19,19 @@ def __init__(self, vocab_size):
1819
self.use_lang_id = hparams.get('use_lang_id', False)
1920
if self.use_lang_id:
2021
self.lang_embed = Embedding(hparams['num_lang'] + 1, hparams['hidden_size'], padding_idx=0)
22+
23+
self.use_stretch_embed = hparams.get('use_stretch_embed', None)
24+
assert self.use_stretch_embed is not None, "You may be loading an old version of the model checkpoint, which is incompatible with the new version due to some bug fixes. It is recommended to roll back to the old version (commit id: 6df0ee977c3728f14cb79c2db8b19df30b23a0bf)"
25+
if self.use_stretch_embed:
26+
self.sr = StretchRegulator()
27+
self.stretch_embed = nn.Sequential(
28+
SinusoidalPosEmb(hparams['hidden_size']),
29+
nn.Linear(hparams['hidden_size'], hparams['hidden_size'] * 4),
30+
nn.GELU(),
31+
nn.Linear(hparams['hidden_size'] * 4, hparams['hidden_size']),
32+
)
33+
self.stretch_embed_rnn = nn.GRU(hparams['hidden_size'], hparams['hidden_size'], 1, batch_first=True)
34+
2135
self.dur_embed = Linear(1, hparams['hidden_size'])
2236
self.encoder = FastSpeech2Encoder(
2337
hidden_size=hparams['hidden_size'], num_layers=hparams['enc_layers'],
@@ -84,20 +98,17 @@ def __init__(self, vocab_size):
8498
def forward_variance_embedding(self, condition, key_shift=None, speed=None, **variances):
8599
if self.use_variance_embeds:
86100
variance_embeds = torch.stack([
87-
self.variance_embeds[v_name](variances[v_name][:, :, None])
88-
* self.variance_scaling_factor[v_name]
101+
self.variance_embeds[v_name](variances[v_name][:, :, None] * self.variance_scaling_factor[v_name])
89102
for v_name in self.variance_embed_list
90103
], dim=-1).sum(-1)
91104
condition += variance_embeds
92105

93106
if self.use_key_shift_embed:
94-
key_shift_embed = self.key_shift_embed(key_shift[:, :, None])
95-
key_shift_embed *= self.variance_scaling_factor['key_shift']
107+
key_shift_embed = self.key_shift_embed(key_shift[:, :, None] * self.variance_scaling_factor['key_shift'])
96108
condition += key_shift_embed
97109

98110
if self.use_speed_embed:
99-
speed_embed = self.speed_embed(speed[:, :, None])
100-
speed_embed *= self.variance_scaling_factor['speed']
111+
speed_embed = self.speed_embed(speed[:, :, None] * self.variance_scaling_factor['speed'])
101112
condition += speed_embed
102113

103114
return condition
@@ -109,11 +120,11 @@ def forward(
109120
**kwargs
110121
):
111122
txt_embed = self.txt_embed(txt_tokens)
112-
dur = mel2ph_to_dur(mel2ph, txt_tokens.shape[1]).float()
123+
dur = mel2ph_to_dur(mel2ph, txt_tokens.shape[1])
113124
if self.use_variance_scaling:
114-
dur_embed = self.dur_embed(torch.log(1 + dur[:, :, None]))
125+
dur_embed = self.dur_embed(torch.log(1 + dur[:, :, None].float()))
115126
else:
116-
dur_embed = self.dur_embed(dur[:, :, None])
127+
dur_embed = self.dur_embed(dur[:, :, None].float())
117128
if self.use_lang_id:
118129
lang_embed = self.lang_embed(languages)
119130
extra_embed = dur_embed + lang_embed
@@ -125,6 +136,19 @@ def forward(
125136
mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
126137
condition = torch.gather(encoder_out, 1, mel2ph_)
127138

139+
if self.use_stretch_embed:
140+
stretch = torch.round(1000 * self.sr(mel2ph, dur))
141+
if self.training and stretch.numel() > 1000:
142+
# construct a phoneme stretching index lookup table with a total of 1001 indexes (0~1000)
143+
table = self.stretch_embed(torch.arange(0, 1001, device=stretch.device))
144+
stretch_embed = torch.index_select(table, 0, stretch.view(-1).long()).view_as(condition)
145+
else:
146+
stretch_embed = self.stretch_embed(stretch)
147+
condition += stretch_embed
148+
self.stretch_embed_rnn.flatten_parameters()
149+
stretch_embed_rnn_out, _ = self.stretch_embed_rnn(condition)
150+
condition = condition + stretch_embed_rnn_out
151+
128152
if self.use_spk_id:
129153
spk_mix_embed = kwargs.get('spk_mix_embed')
130154
if spk_mix_embed is not None:

modules/fastspeech/tts_modules.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -347,14 +347,13 @@ def forward(self, mel2ph, dur=None):
347347
"""
348348
if dur is None:
349349
dur = mel2ph_to_dur(mel2ph, mel2ph.max())
350-
dur = F.pad(dur, [1, 0], value=1) # Avoid dividing by zero
350+
dur = torch.cat([torch.ones_like(dur[:, :1]), dur], dim=1) # Avoid dividing by zero
351351
mel2dur = torch.gather(dur, 1, mel2ph)
352352
bound_mask = torch.gt(mel2ph[:, 1:], mel2ph[:, :-1])
353-
bound_mask = F.pad(bound_mask, [0, 1], mode='constant', value=True)
354-
stretch_delta = 1 - bound_mask * mel2dur
355-
stretch_delta = F.pad(stretch_delta, [1, -1], mode='constant', value=0)
353+
stretch_delta = 1 - bound_mask * mel2dur[:, :-1]
354+
stretch_delta = F.pad(stretch_delta, [1, 0])
356355
stretch_denorm = torch.cumsum(stretch_delta, dim=1)
357-
stretch = stretch_denorm / mel2dur
356+
stretch = stretch_denorm.float() / mel2dur
358357
return stretch * (mel2ph > 0)
359358

360359

modules/toplevel.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,16 @@
1010
from modules.aux_decoder import AuxDecoderAdaptor
1111
from modules.commons.common_layers import (
1212
XavierUniformInitLinear as Linear,
13-
NormalInitEmbedding as Embedding
13+
NormalInitEmbedding as Embedding,
14+
SinusoidalPosEmb
1415
)
1516
from modules.core import (
1617
GaussianDiffusion, PitchDiffusion, MultiVarianceDiffusion,
1718
RectifiedFlow, PitchRectifiedFlow, MultiVarianceRectifiedFlow
1819
)
1920
from modules.fastspeech.acoustic_encoder import FastSpeech2Acoustic
2021
from modules.fastspeech.param_adaptor import ParameterAdaptorModule
21-
from modules.fastspeech.tts_modules import RhythmRegulator, LengthRegulator
22+
from modules.fastspeech.tts_modules import RhythmRegulator, LengthRegulator, StretchRegulator
2223
from modules.fastspeech.variance_encoder import FastSpeech2Variance, MelodyEncoder
2324
from utils.hparams import hparams
2425

@@ -133,6 +134,18 @@ def __init__(self, vocab_size):
133134
self.predict_dur = hparams['predict_dur']
134135
self.predict_pitch = hparams['predict_pitch']
135136

137+
self.use_stretch_embed = hparams.get('use_stretch_embed', None)
138+
assert self.use_stretch_embed is not None, "You may be loading an old version of the model checkpoint, which is incompatible with the new version due to some bug fixes. It is recommended to roll back to the old version (commit id: 6df0ee977c3728f14cb79c2db8b19df30b23a0bf)"
139+
if self.use_stretch_embed and (self.predict_pitch or self.predict_variances):
140+
self.sr = StretchRegulator()
141+
self.stretch_embed = nn.Sequential(
142+
SinusoidalPosEmb(hparams['hidden_size']),
143+
nn.Linear(hparams['hidden_size'], hparams['hidden_size'] * 4),
144+
nn.GELU(),
145+
nn.Linear(hparams['hidden_size'] * 4, hparams['hidden_size']),
146+
)
147+
self.stretch_embed_rnn = nn.GRU(hparams['hidden_size'], hparams['hidden_size'], 1, batch_first=True)
148+
136149
self.use_spk_id = hparams['use_spk_id']
137150
if self.use_spk_id:
138151
self.spk_embed = Embedding(hparams['num_spk'], hparams['hidden_size'])
@@ -255,6 +268,19 @@ def forward(
255268
mel2ph_ = mel2ph[..., None].repeat([1, 1, hparams['hidden_size']])
256269
condition = torch.gather(encoder_out, 1, mel2ph_)
257270

271+
if self.use_stretch_embed:
272+
stretch = torch.round(1000 * self.sr(mel2ph, ph_dur))
273+
if self.training and stretch.numel() > 1000:
274+
# construct a phoneme stretching index lookup table with a total of 1001 indexes (0~1000)
275+
table = self.stretch_embed(torch.arange(0, 1001, device=stretch.device))
276+
stretch_embed = torch.index_select(table, 0, stretch.view(-1).long()).view_as(condition)
277+
else:
278+
stretch_embed = self.stretch_embed(stretch)
279+
condition += stretch_embed
280+
self.stretch_embed_rnn.flatten_parameters()
281+
stretch_embed_rnn_out, _ = self.stretch_embed_rnn(condition)
282+
condition = condition + stretch_embed_rnn_out
283+
258284
if self.use_spk_id:
259285
condition += spk_embed
260286

@@ -326,7 +352,7 @@ def forward(
326352

327353
if variance_retake is not None:
328354
variance_embeds = [
329-
self.variance_embeds[v_name](v_input[:, :, None]) * ~variance_retake[v_name][:, :, None] * self.variance_retake_scaling[v_name]
355+
self.variance_embeds[v_name](v_input[:, :, None] * self.variance_retake_scaling[v_name]) * ~variance_retake[v_name][:, :, None]
330356
for v_name, v_input in zip(self.variance_prediction_list, variance_inputs)
331357
]
332358
var_cond += torch.stack(variance_embeds, dim=-1).sum(-1)

0 commit comments

Comments (0)