11import json
22from pathlib import Path
33from typing import List , Union , Tuple , Dict
4+ import warnings
45
56import onnx
67import onnxsim
@@ -78,6 +79,7 @@ def __init__(
7879 self .export_spk = [(name , {name : 1.0 }) for name in self .spk_map .keys ()]
7980 if self .freeze_spk is not None :
8081 self .model .fs2 .register_buffer ('frozen_spk_embed' , self ._perform_spk_mix (self .freeze_spk [1 ]))
82+ self .rope_interleaved = hparams .get ('rope_interleaved' , None )
8183
8284 def build_model (self ) -> DiffSingerAcousticONNX :
8385 model = DiffSingerAcousticONNX (
@@ -88,8 +90,21 @@ def build_model(self) -> DiffSingerAcousticONNX:
8890 for p in self .phoneme_dictionary .cross_lingual_phonemes
8991 })
9092 ).eval ().to (self .device )
93+ if self .rope_interleaved is None :
94+ warnings .warn (
95+ "After RoPE is refactored, the checkpoint no longer contains relevant parameters. "
96+ "(https://github.com/openvpi/DiffSinger/pull/276)"
97+ "In order to export ONNX with behavior compatible with past checkpoints, "
98+ "it will be set to 'strict=False', which will no longer check the validity of the checkpoint. "
99+ "Please understand what you are doing." ,
100+ UserWarning ,
101+ stacklevel = 2
102+ )
103+ strict = False
104+ else :
105+ strict = True
91106 load_ckpt (model , hparams ['work_dir' ], ckpt_steps = self .ckpt_steps ,
92- prefix_in_ckpt = 'model' , strict = True , device = self .device )
107+ prefix_in_ckpt = 'model' , strict = strict , device = self .device )
93108 return model
94109
95110 def export (self , path : Path ):
0 commit comments