
Commit b0ae9ca

[DONE] Supplement the Variance Model Scaling / Retake Scaling / Conditioner cache on LYNXNet2 (#259)

* Supplement the Variance Model Scaling / Retake Scaling / Conditioner cache on LYNXNet2
* Update toplevel.py
* del use_retake_scaling
1 parent 2ea898f commit b0ae9ca
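In short, this commit introduces two opt-in switches, both disabled by default so existing configs and checkpoints keep the old code paths: use_conditioner_cache makes LYNXNet2 project the conditioner with a kernel-size-1 Conv1d on the channel-first tensor (so the projected conditioner can, in principle, be computed once and cached across diffusion steps), and use_variance_scaling normalizes pitch and variance conditioning inputs before they are embedded. A minimal sketch of the default behavior, assuming hparams behaves like the plain dict loaded from the YAML config:

hparams = {'hidden_size': 256}  # hypothetical config that predates the new keys

# modules/toplevel.py reads the scaling flag with a False fallback:
use_variance_scaling = hparams.get('use_variance_scaling', False)
print(use_variance_scaling)  # False -> identity scaling factors, i.e. the old behavior

# LYNXNet2.__init__ likewise defaults use_conditioner_cache=False, so the
# conditioner projection stays an nn.Linear unless backbone_args enables it.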

7 files changed: 67 additions & 15 deletions

File tree

configs/acoustic.yaml
configs/templates/config_acoustic.yaml
configs/templates/config_variance.yaml
configs/variance.yaml
deployment/modules/toplevel.py
modules/backbones/lynxnet2.py
modules/toplevel.py

configs/acoustic.yaml

Lines changed: 1 addition & 0 deletions
@@ -77,6 +77,7 @@ backbone_args:
   num_layers: 6
   kernel_size: 31
   dropout_rate: 0.0
+  use_conditioner_cache: true
 main_loss_type: l2
 main_loss_log_norm: false
 schedule_type: 'linear'

configs/templates/config_acoustic.yaml

Lines changed: 1 addition & 0 deletions
@@ -83,6 +83,7 @@ backbone_args:
   num_layers: 6
   kernel_size: 31
   dropout_rate: 0.0
+  use_conditioner_cache: true
 #backbone_type: 'wavenet'
 #backbone_args:
 #  num_channels: 512

configs/templates/config_variance.yaml

Lines changed: 2 additions & 0 deletions
@@ -105,6 +105,7 @@ pitch_prediction_args:
     num_layers: 6
     num_channels: 512
     dropout_rate: 0.0
+    use_conditioner_cache: true
 
 variances_prediction_args:
   total_repeat_bins: 48
@@ -118,6 +119,7 @@ variances_prediction_args:
     num_layers: 6
     num_channels: 384
     dropout_rate: 0.0
+    use_conditioner_cache: true
 
 lambda_dur_loss: 1.0
 lambda_pitch_loss: 1.0

configs/variance.yaml

Lines changed: 2 additions & 0 deletions
@@ -71,6 +71,7 @@ pitch_prediction_args:
     num_layers: 6
     num_channels: 512
     dropout_rate: 0.0
+    use_conditioner_cache: true
 
 energy_db_min: -96.0
 energy_db_max: -12.0
@@ -94,6 +95,7 @@ variances_prediction_args:
     num_layers: 6
     num_channels: 384
     dropout_rate: 0.0
+    use_conditioner_cache: true
 
 lambda_dur_loss: 1.0
 lambda_pitch_loss: 1.0

deployment/modules/toplevel.py

Lines changed: 13 additions & 4 deletions
@@ -252,10 +252,16 @@ def forward_pitch_preprocess(
         base_pitch = self.smooth(frame_midi_pitch)
         if self.use_melody_encoder:
             delta_pitch = (pitch - base_pitch) * ~retake
-            pitch_cond += self.delta_pitch_embed(delta_pitch[:, :, None])
+            if self.use_variance_scaling:
+                pitch_cond += self.delta_pitch_embed(delta_pitch[:, :, None] / 12)
+            else:
+                pitch_cond += self.delta_pitch_embed(delta_pitch[:, :, None])
         else:
             base_pitch = base_pitch * retake + pitch * ~retake
-            pitch_cond += self.base_pitch_embed(base_pitch[:, :, None])
+            if self.use_variance_scaling:
+                pitch_cond += self.base_pitch_embed(base_pitch[:, :, None] / 128)
+            else:
+                pitch_cond += self.base_pitch_embed(base_pitch[:, :, None])
         if hparams['use_spk_id'] and spk_embed is not None:
             pitch_cond += spk_embed
         return pitch_cond, base_pitch
@@ -275,13 +281,16 @@ def forward_variance_preprocess(
             variances: dict = None, retake=None, spk_embed=None
     ):
         condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size)
-        variance_cond = condition + self.pitch_embed(pitch[:, :, None])
+        if self.use_variance_scaling:
+            variance_cond = condition + self.pitch_embed(pitch[:, :, None] / 12)
+        else:
+            variance_cond = condition + self.pitch_embed(pitch[:, :, None])
         non_retake_masks = [
             v_retake.float()  # [B, T, 1]
             for v_retake in (~retake).split(1, dim=2)
         ]
         variance_embeds = [
-            self.variance_embeds[v_name](variances[v_name][:, :, None]) * v_masks
+            self.variance_embeds[v_name](variances[v_name][:, :, None]) * v_masks * self.variance_retake_scaling[v_name]
             for v_name, v_masks in zip(self.variance_prediction_list, non_retake_masks)
         ]
         variance_cond += torch.stack(variance_embeds, dim=-1).sum(-1)
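The divisors mirror the units of the conditioning inputs: pitch values are MIDI note numbers, so a pitch delta divided by 12 is measured in octaves, while an absolute pitch divided by 128 lands roughly in [0, 1). A small standalone illustration (the values are made up, not from the repo):

import torch

delta_pitch = torch.tensor([-12.0, 0.0, 7.0])  # semitone offsets from the base pitch
base_pitch = torch.tensor([60.0, 69.0, 72.0])  # absolute MIDI pitch
print(delta_pitch / 12)  # tensor([-1.0000,  0.0000,  0.5833]), i.e. octaves
print(base_pitch / 128)  # tensor([0.4688, 0.5391, 0.5625]), i.e. roughly [0, 1)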

modules/backbones/lynxnet2.py

Lines changed: 12 additions & 7 deletions
@@ -33,17 +33,20 @@ def forward(self, x):
 
 class LYNXNet2(nn.Module):
     def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansion_factor=1, kernel_size=31,
-                 dropout=0.0):
+                 dropout=0.0, use_conditioner_cache=False):
         """
         LYNXNet2(Linear Gated Depthwise Separable Convolution Network Version 2)
         """
         super().__init__()
         self.in_dims = in_dims
         self.n_feats = n_feats
         self.input_projection = nn.Linear(in_dims * n_feats, num_channels)
-        self.conditioner_projection = nn.Linear(hparams['hidden_size'], num_channels)
-        # It may need to be modified at some point to be compatible with the condition cache
-        # self.conditioner_projection = nn.Conv1d(hparams['hidden_size'], num_channels, 1)
+        self.use_conditioner_cache = use_conditioner_cache
+        if self.use_conditioner_cache:
+            # It may need to be modified at some point to be compatible with the condition cache
+            self.conditioner_projection = nn.Conv1d(hparams['hidden_size'], num_channels, 1)
+        else:
+            self.conditioner_projection = nn.Linear(hparams['hidden_size'], num_channels)
         self.diffusion_embedding = nn.Sequential(
             SinusoidalPosEmb(num_channels),
             nn.Linear(num_channels, num_channels * 4),
@@ -81,9 +84,11 @@ def forward(self, spec, diffusion_step, cond):
         x = spec.flatten(start_dim=1, end_dim=2)  # [B, F x M, T]
 
         x = self.input_projection(x.transpose(1, 2))  # [B, T, F x M]
-        x = x + self.conditioner_projection(cond.transpose(1, 2))
-        # It may need to be modified at some point to be compatible with the condition cache
-        # x = x + self.conditioner_projection(cond.transpose(1, 2))
+        if self.use_conditioner_cache:
+            # It may need to be modified at some point to be compatible with the condition cache
+            x = x + self.conditioner_projection(cond).transpose(1, 2)
+        else:
+            x = x + self.conditioner_projection(cond.transpose(1, 2))
         x = x + self.diffusion_embedding(diffusion_step).unsqueeze(1)
 
         for layer in self.residual_layers:
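The cached branch swaps nn.Linear for a kernel-size-1 nn.Conv1d and moves the transpose after the projection, so conditioner_projection consumes cond in its native [B, hidden_size, T] layout; presumably this is what lets a caller (or an exported graph) run the projection once and reuse the result across diffusion steps. The two layers compute the same affine map, as this standalone check (not from the repo) demonstrates:

import torch
import torch.nn as nn

B, T, H, C = 2, 100, 256, 512
cond = torch.randn(B, H, T)  # [B, hidden_size, T], channel-first

linear = nn.Linear(H, C)
conv = nn.Conv1d(H, C, 1)
with torch.no_grad():  # share the weights so both paths are comparable
    conv.weight.copy_(linear.weight[:, :, None])  # [C, H] -> [C, H, 1]
    conv.bias.copy_(linear.bias)

y_linear = linear(cond.transpose(1, 2))  # old path: transpose, then project
y_conv = conv(cond).transpose(1, 2)      # new path: project, then transpose
print(torch.allclose(y_linear, y_conv, atol=1e-5))  # True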

modules/toplevel.py

Lines changed: 36 additions & 4 deletions
@@ -195,6 +195,28 @@ def __init__(self, vocab_size):
         else:
             raise NotImplementedError(self.diffusion_type)
 
+        self.use_variance_scaling = hparams.get('use_variance_scaling', False)
+        self.custom_variance_scaling_factor = {
+            'energy': 1. / 96,
+            'breathiness': 1. / 96,
+            'voicing': 1. / 96,
+            'tension': 0.1,
+            'key_shift': 1. / 12,
+            'speed': 1.
+        }
+        self.default_variance_scaling_factor = {
+            'energy': 1.,
+            'breathiness': 1.,
+            'voicing': 1.,
+            'tension': 1.,
+            'key_shift': 1.,
+            'speed': 1.
+        }
+        if self.use_variance_scaling:
+            self.variance_retake_scaling = self.custom_variance_scaling_factor
+        else:
+            self.variance_retake_scaling = self.default_variance_scaling_factor
+
     def forward(
             self, txt_tokens, midi, ph2word, ph_dur=None, word_dur=None, mel2ph=None,
             note_midi=None, note_rest=None, note_dur=None, note_glide=None, mel2note=None,
@@ -271,11 +293,17 @@ def forward(
                     delta_pitch_in = torch.zeros_like(base_pitch)
                 else:
                     delta_pitch_in = (pitch - base_pitch) * ~pitch_retake
-                pitch_cond += self.delta_pitch_embed(delta_pitch_in[:, :, None])
+                if self.use_variance_scaling:
+                    pitch_cond += self.delta_pitch_embed(delta_pitch_in[:, :, None] / 12)
+                else:
+                    pitch_cond += self.delta_pitch_embed(delta_pitch_in[:, :, None])
             else:
                 if not retake_unset:  # retake
                     base_pitch = base_pitch * pitch_retake + pitch * ~pitch_retake
-                pitch_cond += self.base_pitch_embed(base_pitch[:, :, None])
+                if self.use_variance_scaling:
+                    pitch_cond += self.base_pitch_embed(base_pitch[:, :, None] / 128)
+                else:
+                    pitch_cond += self.base_pitch_embed(base_pitch[:, :, None])
 
             if infer:
                 pitch_pred_out = self.pitch_predictor(pitch_cond, infer=True)
@@ -289,12 +317,16 @@ def forward(
 
         if pitch is None:
             pitch = base_pitch + pitch_pred_out
-        var_cond = condition + self.pitch_embed(pitch[:, :, None])
+        if self.use_variance_scaling:
+            var_cond = condition + self.pitch_embed(pitch[:, :, None] / 12)
+        else:
+            var_cond = condition + self.pitch_embed(pitch[:, :, None])
 
         variance_inputs = self.collect_variance_inputs(**kwargs)
+
        if variance_retake is not None:
             variance_embeds = [
-                self.variance_embeds[v_name](v_input[:, :, None]) * ~variance_retake[v_name][:, :, None]
+                self.variance_embeds[v_name](v_input[:, :, None]) * ~variance_retake[v_name][:, :, None] * self.variance_retake_scaling[v_name]
                 for v_name, v_input in zip(self.variance_prediction_list, variance_inputs)
             ]
             var_cond += torch.stack(variance_embeds, dim=-1).sum(-1)
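The per-variance factors play the same normalizing role on the retake path: each variance embedding's contribution to var_cond is damped by a fixed constant that offsets the raw parameter range. The values in custom_variance_scaling_factor track the configs (energy, breathiness, and voicing are dB-scaled, compare energy_db_min: -96.0 in configs/variance.yaml, hence 1/96; key_shift is in semitones, hence 1/12), and with use_variance_scaling disabled every factor is 1.0, leaving the sum unchanged. A toy illustration of the damping:

custom_scaling = {'energy': 1. / 96, 'breathiness': 1. / 96, 'voicing': 1. / 96,
                  'tension': 0.1, 'key_shift': 1. / 12, 'speed': 1.}
embed_out = 8.4  # hypothetical magnitude of one energy-embedding activation
print(embed_out * custom_scaling['energy'])  # 0.0875, pulled to sub-unit scale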
