@@ -195,6 +195,28 @@ def __init__(self, vocab_size):
195195 else :
196196 raise NotImplementedError (self .diffusion_type )
197197
198+ self .use_variance_scaling = hparams .get ('use_variance_scaling' , False )
199+ self .custom_variance_scaling_factor = {
200+ 'energy' : 1. / 96 ,
201+ 'breathiness' : 1. / 96 ,
202+ 'voicing' : 1. / 96 ,
203+ 'tension' : 0.1 ,
204+ 'key_shift' : 1. / 12 ,
205+ 'speed' : 1.
206+ }
207+ self .default_variance_scaling_factor = {
208+ 'energy' : 1. ,
209+ 'breathiness' : 1. ,
210+ 'voicing' : 1. ,
211+ 'tension' : 1. ,
212+ 'key_shift' : 1. ,
213+ 'speed' : 1.
214+ }
215+ if self .use_variance_scaling :
216+ self .variance_retake_scaling = self .custom_variance_scaling_factor
217+ else :
218+ self .variance_retake_scaling = self .default_variance_scaling_factor
219+
198220 def forward (
199221 self , txt_tokens , midi , ph2word , ph_dur = None , word_dur = None , mel2ph = None ,
200222 note_midi = None , note_rest = None , note_dur = None , note_glide = None , mel2note = None ,
@@ -271,11 +293,17 @@ def forward(
271293 delta_pitch_in = torch .zeros_like (base_pitch )
272294 else :
273295 delta_pitch_in = (pitch - base_pitch ) * ~ pitch_retake
274- pitch_cond += self .delta_pitch_embed (delta_pitch_in [:, :, None ])
296+ if self .use_variance_scaling :
297+ pitch_cond += self .delta_pitch_embed (delta_pitch_in [:, :, None ] / 12 )
298+ else :
299+ pitch_cond += self .delta_pitch_embed (delta_pitch_in [:, :, None ])
275300 else :
276301 if not retake_unset : # retake
277302 base_pitch = base_pitch * pitch_retake + pitch * ~ pitch_retake
278- pitch_cond += self .base_pitch_embed (base_pitch [:, :, None ])
303+ if self .use_variance_scaling :
304+ pitch_cond += self .base_pitch_embed (base_pitch [:, :, None ] / 128 )
305+ else :
306+ pitch_cond += self .base_pitch_embed (base_pitch [:, :, None ])
279307
280308 if infer :
281309 pitch_pred_out = self .pitch_predictor (pitch_cond , infer = True )
@@ -289,12 +317,16 @@ def forward(
289317
290318 if pitch is None :
291319 pitch = base_pitch + pitch_pred_out
292- var_cond = condition + self .pitch_embed (pitch [:, :, None ])
320+ if self .use_variance_scaling :
321+ var_cond = condition + self .pitch_embed (pitch [:, :, None ] / 12 )
322+ else :
323+ var_cond = condition + self .pitch_embed (pitch [:, :, None ])
293324
294325 variance_inputs = self .collect_variance_inputs (** kwargs )
326+
295327 if variance_retake is not None :
296328 variance_embeds = [
297- self .variance_embeds [v_name ](v_input [:, :, None ]) * ~ variance_retake [v_name ][:, :, None ]
329+ self .variance_embeds [v_name ](v_input [:, :, None ]) * ~ variance_retake [v_name ][:, :, None ] * self . variance_retake_scaling [ v_name ]
298330 for v_name , v_input in zip (self .variance_prediction_list , variance_inputs )
299331 ]
300332 var_cond += torch .stack (variance_embeds , dim = - 1 ).sum (- 1 )
0 commit comments