Skip to content

Commit c315d38

Browse files
authored
NeoX style RoPE (#277)
* refactor RoPE * NeoX style RoPE * fix export ONNX model before RoPE refactor
1 parent 1638ccd commit c315d38

9 files changed

Lines changed: 40 additions & 6 deletions

File tree

configs/acoustic.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ timesteps: 1000
6464
max_beta: 0.02
6565
enc_ffn_kernel_size: 3
6666
use_rope: true
67+
rope_interleaved: false
6768
use_stretch_embed: true
6869
use_variance_scaling: true
6970
rel_pos: true

configs/templates/config_acoustic.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ augmentation_args:
7171
diffusion_type: reflow
7272
enc_ffn_kernel_size: 3
7373
use_rope: true
74+
rope_interleaved: false
7475
use_stretch_embed: true
7576
use_variance_scaling: true
7677
use_shallow_diffusion: true

configs/templates/config_variance.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ tension_logit_max: 10.0
6565

6666
enc_ffn_kernel_size: 3
6767
use_rope: true
68+
rope_interleaved: false
6869
use_stretch_embed: false
6970
use_variance_scaling: true
7071
hidden_size: 384

configs/variance.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ predict_tension: false
3636

3737
enc_ffn_kernel_size: 3
3838
use_rope: true
39+
rope_interleaved: false
3940
use_stretch_embed: false
4041
use_variance_scaling: true
4142
rel_pos: true

deployment/exporters/acoustic_exporter.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
from pathlib import Path
33
from typing import List, Union, Tuple, Dict
4+
import warnings
45

56
import onnx
67
import onnxsim
@@ -78,6 +79,7 @@ def __init__(
7879
self.export_spk = [(name, {name: 1.0}) for name in self.spk_map.keys()]
7980
if self.freeze_spk is not None:
8081
self.model.fs2.register_buffer('frozen_spk_embed', self._perform_spk_mix(self.freeze_spk[1]))
82+
self.rope_interleaved = hparams.get('rope_interleaved', None)
8183

8284
def build_model(self) -> DiffSingerAcousticONNX:
8385
model = DiffSingerAcousticONNX(
@@ -88,8 +90,21 @@ def build_model(self) -> DiffSingerAcousticONNX:
8890
for p in self.phoneme_dictionary.cross_lingual_phonemes
8991
})
9092
).eval().to(self.device)
93+
if self.rope_interleaved is None:
94+
warnings.warn(
95+
"After RoPE is refactored, the checkpoint no longer contains relevant parameters. "
96+
"(https://github.com/openvpi/DiffSinger/pull/276) "
97+
"In order to export ONNX with behavior compatible with past checkpoints, "
98+
"it will be set to 'strict=False', which will no longer check the validity of the checkpoint. "
99+
"Please understand what you are doing.",
100+
UserWarning,
101+
stacklevel=2
102+
)
103+
strict=False
104+
else:
105+
strict=True
91106
load_ckpt(model, hparams['work_dir'], ckpt_steps=self.ckpt_steps,
92-
prefix_in_ckpt='model', strict=True, device=self.device)
107+
prefix_in_ckpt='model', strict=strict, device=self.device)
93108
return model
94109

95110
def export(self, path: Path):

deployment/exporters/variance_exporter.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
from pathlib import Path
33
from typing import Union, List, Tuple, Dict
4+
import warnings
45

56
import onnx
67
import onnxsim
@@ -81,6 +82,7 @@ def __init__(
8182
self.export_spk = [(name, {name: 1.0}) for name in self.spk_map.keys()]
8283
if self.freeze_spk is not None:
8384
self.model.register_buffer('frozen_spk_embed', self._perform_spk_mix(self.freeze_spk[1]))
85+
self.rope_interleaved = hparams.get('rope_interleaved', None)
8486

8587
def build_model(self) -> DiffSingerVarianceONNX:
8688
model = DiffSingerVarianceONNX(
@@ -90,6 +92,19 @@ def build_model(self) -> DiffSingerVarianceONNX:
9092
for p in self.phoneme_dictionary.cross_lingual_phonemes
9193
})
9294
).eval().to(self.device)
95+
if self.rope_interleaved is None:
96+
warnings.warn(
97+
"After RoPE is refactored, the checkpoint no longer contains relevant parameters. "
98+
"(https://github.com/openvpi/DiffSinger/pull/276) "
99+
"In order to export ONNX with behavior compatible with past checkpoints, "
100+
"it will be set to 'strict=False', which will no longer check the validity of the checkpoint. "
101+
"Please understand what you are doing.",
102+
UserWarning,
103+
stacklevel=2
104+
)
105+
strict=False
106+
else:
107+
strict=True
93108
load_ckpt(model, hparams['work_dir'], ckpt_steps=self.ckpt_steps,
94109
prefix_in_ckpt='model', strict=strict, device=self.device)
95110
model.build_smooth_op(self.device)

modules/fastspeech/acoustic_encoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(self, vocab_size):
3838
ffn_kernel_size=hparams['enc_ffn_kernel_size'], ffn_act=hparams['ffn_act'],
3939
dropout=hparams['dropout'], num_heads=hparams['num_heads'],
4040
use_pos_embed=hparams['use_pos_embed'], rel_pos=hparams.get('rel_pos', False),
41-
use_rope=hparams.get('use_rope', False)
41+
use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True)
4242
)
4343

4444
self.pitch_embed = Linear(1, hparams['hidden_size'])

modules/fastspeech/tts_modules.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,14 +369,14 @@ def mel2ph_to_dur(mel2ph, T_txt, max_dur=None):
369369
class FastSpeech2Encoder(nn.Module):
370370
def __init__(self, hidden_size, num_layers,
371371
ffn_kernel_size=9, ffn_act='gelu',
372-
dropout=None, num_heads=2, use_pos_embed=True, rel_pos=True, use_rope=False):
372+
dropout=None, num_heads=2, use_pos_embed=True, rel_pos=True, use_rope=False, rope_interleaved=True):
373373
super().__init__()
374374
self.num_layers = num_layers
375375
embed_dim = self.hidden_size = hidden_size
376376
self.dropout = dropout
377377
self.use_pos_embed = use_pos_embed
378378
if use_pos_embed and use_rope:
379-
rotary_embed = RotaryEmbedding(dim = embed_dim // num_heads)
379+
rotary_embed = RotaryEmbedding(dim = embed_dim // num_heads, interleaved = rope_interleaved)
380380
else:
381381
rotary_embed = None
382382
self.layers = nn.ModuleList([

modules/fastspeech/variance_encoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(self, vocab_size):
3333
ffn_kernel_size=hparams['enc_ffn_kernel_size'], ffn_act=hparams['ffn_act'],
3434
dropout=hparams['dropout'], num_heads=hparams['num_heads'],
3535
use_pos_embed=hparams['use_pos_embed'], rel_pos=hparams.get('rel_pos', False),
36-
use_rope=hparams.get('use_rope', False)
36+
use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True)
3737
)
3838

3939
dur_hparams = hparams['dur_prediction_args']
@@ -127,7 +127,7 @@ def get_hparam(key):
127127
ffn_kernel_size=get_hparam('enc_ffn_kernel_size'), ffn_act=get_hparam('ffn_act'),
128128
dropout=get_hparam('dropout'), num_heads=get_hparam('num_heads'),
129129
use_pos_embed=get_hparam('use_pos_embed'), rel_pos=get_hparam('rel_pos'),
130-
use_rope=get_hparam('use_rope')
130+
use_rope=get_hparam('use_rope'), rope_interleaved=hparams.get('rope_interleaved', True)
131131
)
132132
self.out_proj = Linear(hidden_size, hparams['hidden_size'])
133133

0 commit comments

Comments
 (0)