@@ -75,7 +75,10 @@ def forward(
7575 mel2ph = self .lr (durations )
7676 f0 = f0 * (mel2ph > 0 )
7777 mel2ph = mel2ph [..., None ].repeat ((1 , 1 , hparams ['hidden_size' ]))
78- dur_embed = self .dur_embed (durations .float ()[:, :, None ])
78+ if self .use_variance_scaling :
79+ dur_embed = self .dur_embed (torch .log (1 + durations .float ())[:, :, None ])
80+ else :
81+ dur_embed = self .dur_embed (durations .float ()[:, :, None ])
7982 if self .use_lang_id :
8083 lang_mask = torch .any (
8184 tokens [..., None ] == self .cross_lingual_token_idx [None , None ],
@@ -99,7 +102,8 @@ def forward(
99102
100103 if self .use_variance_embeds :
101104 variance_embeds = torch .stack ([
102- self .variance_embeds [v_name ](variances [v_name ][:, :, None ])
105+ self .variance_embeds [v_name ](variances [v_name ][:, :, None ])
106+ * self .variance_scaling_factor [v_name ]
103107 for v_name in self .variance_embed_list
104108 ], dim = - 1 ).sum (- 1 )
105109 condition += variance_embeds
@@ -112,6 +116,7 @@ def forward(
112116 gender_mask = (gender < 0. ).float ()
113117 key_shift = gender * ((1. - gender_mask ) * self .shift_max + gender_mask * abs (self .shift_min ))
114118 key_shift_embed = self .key_shift_embed (key_shift [:, :, None ])
119+ key_shift_embed *= self .variance_scaling_factor ['key_shift' ]
115120 condition += key_shift_embed
116121
117122 if hparams ['use_speed_embed' ]:
@@ -120,6 +125,7 @@ def forward(
120125 speed_embed = self .speed_embed (velocity [:, :, None ])
121126 else :
122127 speed_embed = self .speed_embed (torch .FloatTensor ([1. ]).to (condition .device )[:, None , None ])
128+ speed_embed *= self .variance_scaling_factor ['speed' ]
123129 condition += speed_embed
124130
125131 if hparams ['use_spk_id' ]:
@@ -162,7 +168,10 @@ def forward_encoder_word(self, tokens, word_div, word_dur, languages=None):
162168
163169 def forward_encoder_phoneme (self , tokens , ph_dur , languages = None ):
164170 txt_embed = self .txt_embed (tokens )
165- ph_dur_embed = self .ph_dur_embed (ph_dur .float ()[:, :, None ])
171+ if self .use_variance_scaling :
172+ ph_dur_embed = self .ph_dur_embed (torch .log (1 + ph_dur .float ())[:, :, None ])
173+ else :
174+ ph_dur_embed = self .ph_dur_embed (ph_dur .float ()[:, :, None ])
166175 if self .use_lang_id :
167176 lang_mask = torch .any (
168177 tokens [..., None ] == self .cross_lingual_token_idx [None , None ],
0 commit comments