Skip to content

Commit 2ea898f

Browse files
committed
Variance scaling for ONNX
1 parent 954e41c commit 2ea898f

1 file changed

Lines changed: 12 additions & 3 deletions

File tree

deployment/modules/fastspeech2.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,10 @@ def forward(
7575
mel2ph = self.lr(durations)
7676
f0 = f0 * (mel2ph > 0)
7777
mel2ph = mel2ph[..., None].repeat((1, 1, hparams['hidden_size']))
78-
dur_embed = self.dur_embed(durations.float()[:, :, None])
78+
if self.use_variance_scaling:
79+
dur_embed = self.dur_embed(torch.log(1 + durations.float())[:, :, None])
80+
else:
81+
dur_embed = self.dur_embed(durations.float()[:, :, None])
7982
if self.use_lang_id:
8083
lang_mask = torch.any(
8184
tokens[..., None] == self.cross_lingual_token_idx[None, None],
@@ -99,7 +102,8 @@ def forward(
99102

100103
if self.use_variance_embeds:
101104
variance_embeds = torch.stack([
102-
self.variance_embeds[v_name](variances[v_name][:, :, None])
105+
self.variance_embeds[v_name](variances[v_name][:, :, None])
106+
* self.variance_scaling_factor[v_name]
103107
for v_name in self.variance_embed_list
104108
], dim=-1).sum(-1)
105109
condition += variance_embeds
@@ -112,6 +116,7 @@ def forward(
112116
gender_mask = (gender < 0.).float()
113117
key_shift = gender * ((1. - gender_mask) * self.shift_max + gender_mask * abs(self.shift_min))
114118
key_shift_embed = self.key_shift_embed(key_shift[:, :, None])
119+
key_shift_embed *= self.variance_scaling_factor['key_shift']
115120
condition += key_shift_embed
116121

117122
if hparams['use_speed_embed']:
@@ -120,6 +125,7 @@ def forward(
120125
speed_embed = self.speed_embed(velocity[:, :, None])
121126
else:
122127
speed_embed = self.speed_embed(torch.FloatTensor([1.]).to(condition.device)[:, None, None])
128+
speed_embed *= self.variance_scaling_factor['speed']
123129
condition += speed_embed
124130

125131
if hparams['use_spk_id']:
@@ -162,7 +168,10 @@ def forward_encoder_word(self, tokens, word_div, word_dur, languages=None):
162168

163169
def forward_encoder_phoneme(self, tokens, ph_dur, languages=None):
164170
txt_embed = self.txt_embed(tokens)
165-
ph_dur_embed = self.ph_dur_embed(ph_dur.float()[:, :, None])
171+
if self.use_variance_scaling:
172+
ph_dur_embed = self.ph_dur_embed(torch.log(1 + ph_dur.float())[:, :, None])
173+
else:
174+
ph_dur_embed = self.ph_dur_embed(ph_dur.float()[:, :, None])
166175
if self.use_lang_id:
167176
lang_mask = torch.any(
168177
tokens[..., None] == self.cross_lingual_token_idx[None, None],

0 commit comments

Comments (0)