99from basics .base_dataset import BaseDataset
1010from basics .base_task import BaseTask
1111from modules .losses import DurationLoss , DiffusionLoss , RectifiedFlowLoss
12- from modules .metrics .curve import RawCurveAccuracy
13- from modules .metrics .duration import RhythmCorrectness , PhonemeDurationAccuracy
12+ from modules .metrics import (
13+ RawCurveAccuracy , RawCurveR2Score , RhythmCorrectness , PhonemeDurationAccuracy
14+ )
1415from modules .toplevel import DiffSingerVariance
1516from utils .hparams import hparams
1617from utils .plot import dur_to_figure , pitch_note_to_figure , curve_to_figure
@@ -144,6 +145,7 @@ def build_losses_and_metrics(self):
144145 raise ValueError (f'Unknown diffusion type: { self .diffusion_type } ' )
145146 self .register_validation_loss ('pitch_loss' )
146147 self .register_validation_metric ('pitch_acc' , RawCurveAccuracy (tolerance = 0.5 ))
148+ self .register_validation_metric ('pitch_r2' , RawCurveR2Score ())
147149 if self .predict_variances :
148150 if self .diffusion_type == 'ddpm' :
149151 self .var_loss = DiffusionLoss (loss_type = hparams ['main_loss_type' ])
@@ -154,6 +156,8 @@ def build_losses_and_metrics(self):
154156 else :
155157 raise ValueError (f'Unknown diffusion type: { self .diffusion_type } ' )
156158 self .register_validation_loss ('var_loss' )
159+ for name in self .variance_prediction_list :
160+ self .register_validation_metric (f'{ name } _r2' , RawCurveR2Score ())
157161
158162 def run_model (self , sample , infer = False ):
159163 spk_ids = sample ['spk_ids' ] if self .use_spk_id else None # [B,]
@@ -289,6 +293,8 @@ def sample_get(key, idx, abs_idx):
289293 variance_len = self .valid_dataset .metadata [name ][data_idx ]
290294 gt_variances = sample [name ][i ][:variance_len ].unsqueeze (0 )
291295 pred_variances = variances_preds [name ][i ][:variance_len ].unsqueeze (0 )
296+ mask = (sample_get ('mel2ph' , i , data_idx ) > 0 ) & ~ sample_get ('uv' , i , data_idx )
297+ self .valid_metrics [f'{ name } _r2' ].update (pred = pred_variances , target = gt_variances , mask = mask )
292298 self .plot_curve (
293299 data_idx ,
294300 gt_curve = gt_variances ,
0 commit comments