@@ -17,7 +17,7 @@ def __init__(self, vocab_size):
1717 self .predict_dur = hparams ['predict_dur' ]
1818 self .linguistic_mode = 'word' if hparams ['predict_dur' ] else 'phoneme'
1919 self .use_lang_id = hparams ['use_lang_id' ]
20-
20+ self . use_variance_scaling = hparams . get ( 'use_variance_scaling' , False )
2121 self .txt_embed = Embedding (vocab_size , hparams ['hidden_size' ], PAD_INDEX )
2222 if self .use_lang_id :
2323 self .lang_embed = Embedding (hparams ['num_lang' ] + 1 , hparams ['hidden_size' ], padding_idx = 0 )
@@ -80,9 +80,11 @@ def forward(
8080 word_dur = torch .gather (F .pad (word_dur , [1 , 0 ], value = 0 ), 1 , ph2word ) # [B, T_w] => [B, T_ph]
8181 word_dur_embed = self .word_dur_embed (word_dur .float ()[:, :, None ])
8282 extra_embed = onset_embed + word_dur_embed
83+ elif self .use_variance_scaling :
84+ extra_embed = self .ph_dur_embed (torch .log (1 + ph_dur .float ())[:, :, None ])
8385 else :
84- ph_dur_embed = self .ph_dur_embed (ph_dur .float ()[:, :, None ])
85- extra_embed = ph_dur_embed
86+ extra_embed = self .ph_dur_embed (ph_dur .float ()[:, :, None ])
87+
8688 if self .use_lang_id :
8789 lang_embed = self .lang_embed (languages )
8890 extra_embed += lang_embed
@@ -109,6 +111,7 @@ def get_hparam(key):
109111
110112 # MIDI inputs
111113 hidden_size = get_hparam ('hidden_size' )
114+ self .use_variance_scaling = hparams .get ('use_variance_scaling' , False )
112115 self .note_midi_embed = Linear (1 , hidden_size )
113116 self .note_dur_embed = Linear (1 , hidden_size )
114117
@@ -136,8 +139,13 @@ def forward(self, note_midi, note_rest, note_dur, glide=None):
136139 :param glide: int64 [B, T_n]
137140 :return: [B, T_n, H]
138141 """
139- midi_embed = self .note_midi_embed (note_midi [:, :, None ]) * ~ note_rest [:, :, None ]
140- dur_embed = self .note_dur_embed (note_dur .float ()[:, :, None ])
142+ if self .use_variance_scaling :
143+ midi_embed = self .note_midi_embed (note_midi [:, :, None ] / 128 )
144+ dur_embed = self .note_dur_embed (torch .log (1 + note_dur .float ())[:, :, None ])
145+ else :
146+ midi_embed = self .note_midi_embed (note_midi [:, :, None ])
147+ dur_embed = self .note_dur_embed (note_dur .float ()[:, :, None ])
148+ midi_embed *= ~ note_rest [:, :, None ]
141149 ornament_embed = 0
142150 if self .use_glide_embed :
143151 ornament_embed += self .note_glide_embed (glide ) * self .glide_embed_scale
0 commit comments