Skip to content

Commit 954e41c

Browse files
committed
variance scaling
1 parent 14c3609 commit 954e41c

6 files changed

Lines changed: 47 additions & 9 deletions

File tree

configs/acoustic.yaml

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -64,6 +64,7 @@ timesteps: 1000
6464
max_beta: 0.02
6565
enc_ffn_kernel_size: 3
6666
use_rope: true
67+
use_variance_scaling: true
6768
rel_pos: true
6869
sampling_algorithm: euler
6970
sampling_steps: 20

configs/templates/config_acoustic.yaml

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -71,6 +71,7 @@ augmentation_args:
7171
diffusion_type: reflow
7272
enc_ffn_kernel_size: 3
7373
use_rope: true
74+
use_variance_scaling: true
7475
use_shallow_diffusion: true
7576
T_start: 0.4
7677
T_start_infer: 0.4

configs/templates/config_variance.yaml

Lines changed: 2 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -65,6 +65,7 @@ tension_logit_max: 10.0
6565

6666
enc_ffn_kernel_size: 3
6767
use_rope: true
68+
use_variance_scaling: true
6869
hidden_size: 256
6970
dur_prediction_args:
7071
arch: resnet
@@ -78,7 +79,7 @@ dur_prediction_args:
7879
lambda_wdur_loss: 1.0
7980
lambda_sdur_loss: 3.0
8081

81-
use_melody_encoder: false
82+
use_melody_encoder: true
8283
melody_encoder_args:
8384
hidden_size: 128
8485
enc_layers: 4

configs/variance.yaml

Lines changed: 2 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -36,6 +36,7 @@ predict_tension: false
3636

3737
enc_ffn_kernel_size: 3
3838
use_rope: true
39+
use_variance_scaling: true
3940
rel_pos: true
4041
hidden_size: 256
4142

@@ -51,7 +52,7 @@ dur_prediction_args:
5152
lambda_wdur_loss: 1.0
5253
lambda_sdur_loss: 3.0
5354

54-
use_melody_encoder: false
55+
use_melody_encoder: true
5556
melody_encoder_args:
5657
hidden_size: 128
5758
enc_layers: 4

modules/fastspeech/acoustic_encoder.py

Lines changed: 28 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -49,6 +49,26 @@ def __init__(self, vocab_size):
4949
for v_name in self.variance_embed_list
5050
})
5151

52+
self.use_variance_scaling = hparams.get('use_variance_scaling', False)
53+
if self.use_variance_scaling:
54+
self.variance_scaling_factor = {
55+
'energy': 1. / 96,
56+
'breathiness': 1. / 96,
57+
'voicing': 1. / 96,
58+
'tension': 0.1,
59+
'key_shift': 1. / 12,
60+
'speed': 1.
61+
}
62+
else:
63+
self.variance_scaling_factor = {
64+
'energy': 1.,
65+
'breathiness': 1.,
66+
'voicing': 1.,
67+
'tension': 1.,
68+
'key_shift': 1.,
69+
'speed': 1.
70+
}
71+
5272
self.use_key_shift_embed = hparams.get('use_key_shift_embed', False)
5373
if self.use_key_shift_embed:
5474
self.key_shift_embed = Linear(1, hparams['hidden_size'])
@@ -64,17 +84,20 @@ def __init__(self, vocab_size):
6484
def forward_variance_embedding(self, condition, key_shift=None, speed=None, **variances):
6585
if self.use_variance_embeds:
6686
variance_embeds = torch.stack([
67-
self.variance_embeds[v_name](variances[v_name][:, :, None])
87+
self.variance_embeds[v_name](variances[v_name][:, :, None])
88+
* self.variance_scaling_factor[v_name]
6889
for v_name in self.variance_embed_list
6990
], dim=-1).sum(-1)
7091
condition += variance_embeds
7192

7293
if self.use_key_shift_embed:
7394
key_shift_embed = self.key_shift_embed(key_shift[:, :, None])
95+
key_shift_embed *= self.variance_scaling_factor['key_shift']
7496
condition += key_shift_embed
7597

7698
if self.use_speed_embed:
7799
speed_embed = self.speed_embed(speed[:, :, None])
100+
speed_embed *= self.variance_scaling_factor['speed']
78101
condition += speed_embed
79102

80103
return condition
@@ -87,7 +110,10 @@ def forward(
87110
):
88111
txt_embed = self.txt_embed(txt_tokens)
89112
dur = mel2ph_to_dur(mel2ph, txt_tokens.shape[1]).float()
90-
dur_embed = self.dur_embed(dur[:, :, None])
113+
if self.use_variance_scaling:
114+
dur_embed = self.dur_embed(torch.log(1 + dur[:, :, None]))
115+
else:
116+
dur_embed = self.dur_embed(dur[:, :, None])
91117
if self.use_lang_id:
92118
lang_embed = self.lang_embed(languages)
93119
extra_embed = dur_embed + lang_embed

modules/fastspeech/variance_encoder.py

Lines changed: 13 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -17,7 +17,7 @@ def __init__(self, vocab_size):
1717
self.predict_dur = hparams['predict_dur']
1818
self.linguistic_mode = 'word' if hparams['predict_dur'] else 'phoneme'
1919
self.use_lang_id = hparams['use_lang_id']
20-
20+
self.use_variance_scaling = hparams.get('use_variance_scaling', False)
2121
self.txt_embed = Embedding(vocab_size, hparams['hidden_size'], PAD_INDEX)
2222
if self.use_lang_id:
2323
self.lang_embed = Embedding(hparams['num_lang'] + 1, hparams['hidden_size'], padding_idx=0)
@@ -80,9 +80,11 @@ def forward(
8080
word_dur = torch.gather(F.pad(word_dur, [1, 0], value=0), 1, ph2word) # [B, T_w] => [B, T_ph]
8181
word_dur_embed = self.word_dur_embed(word_dur.float()[:, :, None])
8282
extra_embed = onset_embed + word_dur_embed
83+
elif self.use_variance_scaling:
84+
extra_embed = self.ph_dur_embed(torch.log(1 + ph_dur.float())[:, :, None])
8385
else:
84-
ph_dur_embed = self.ph_dur_embed(ph_dur.float()[:, :, None])
85-
extra_embed = ph_dur_embed
86+
extra_embed = self.ph_dur_embed(ph_dur.float()[:, :, None])
87+
8688
if self.use_lang_id:
8789
lang_embed = self.lang_embed(languages)
8890
extra_embed += lang_embed
@@ -109,6 +111,7 @@ def get_hparam(key):
109111

110112
# MIDI inputs
111113
hidden_size = get_hparam('hidden_size')
114+
self.use_variance_scaling = hparams.get('use_variance_scaling', False)
112115
self.note_midi_embed = Linear(1, hidden_size)
113116
self.note_dur_embed = Linear(1, hidden_size)
114117

@@ -136,8 +139,13 @@ def forward(self, note_midi, note_rest, note_dur, glide=None):
136139
:param glide: int64 [B, T_n]
137140
:return: [B, T_n, H]
138141
"""
139-
midi_embed = self.note_midi_embed(note_midi[:, :, None]) * ~note_rest[:, :, None]
140-
dur_embed = self.note_dur_embed(note_dur.float()[:, :, None])
142+
if self.use_variance_scaling:
143+
midi_embed = self.note_midi_embed(note_midi[:, :, None] / 128)
144+
dur_embed = self.note_dur_embed(torch.log(1 + note_dur.float())[:, :, None])
145+
else:
146+
midi_embed = self.note_midi_embed(note_midi[:, :, None])
147+
dur_embed = self.note_dur_embed(note_dur.float()[:, :, None])
148+
midi_embed *= ~note_rest[:, :, None]
141149
ornament_embed = 0
142150
if self.use_glide_embed:
143151
ornament_embed += self.note_glide_embed(glide) * self.glide_embed_scale

0 commit comments

Comments (0)