Skip to content

Commit cd2d0a5

Browse files
committed
Add R2Score metrics
1 parent 44ce312 commit cd2d0a5

3 files changed

Lines changed: 50 additions & 2 deletions

File tree

modules/metrics/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .curve import RawCurveAccuracy, RawCurveR2Score
2+
from .duration import RhythmCorrectness, PhonemeDurationAccuracy

modules/metrics/curve.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,43 @@ def update(self, pred: Tensor, target: Tensor, mask=None) -> None:
3131

3232
def compute(self) -> Tensor:
3333
return self.close / self.total
34+
35+
36+
class RawCurveR2Score(torchmetrics.Metric):
    """
    Coefficient of determination (R^2) between predicted and reference curves,
    accumulated over every ``update`` call.

    The score is ``1 - RSS / TSS`` where ``RSS = sum((target - pred)^2)`` and
    ``TSS = sum(target^2) - sum(target)^2 / n`` is the total sum of squares of
    the reference around its mean. Only running sums are stored, so the metric
    can be reduced across processes with a plain sum.

    Degenerate cases handled by ``compute``:

    * no valid element was ever accumulated -> ``nan`` (score is undefined);
    * the reference has no variance (``TSS <= 0``, e.g. a constant curve or
      floating point cancellation) -> ``1.0`` if the prediction is exact,
      otherwise ``0.0`` instead of ``-inf`` / ``nan``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # NOTE: despite their names, the first two states hold statistics of the
        # *reference* curve, not of the error. The names are kept unchanged so
        # that existing attribute access keeps working.
        # sum_squared_error: sum of target^2
        self.add_state('sum_squared_error', default=torch.tensor(0.0), dist_reduce_fx='sum')
        # sum_error: sum of target
        self.add_state('sum_error', default=torch.tensor(0.0), dist_reduce_fx='sum')
        # residual: sum of (target - pred)^2
        self.add_state('residual', default=torch.tensor(0.0), dist_reduce_fx='sum')
        # total: number of accumulated (valid) elements
        self.add_state('total', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, pred: Tensor, target: Tensor, mask=None) -> None:
        """
        Accumulate the statistics of one batch.

        :param pred: predicted curve
        :param target: reference curve
        :param mask: valid or non-padding mask
        """
        if mask is None:
            assert pred.shape == target.shape, f'shapes of pred and target mismatch: {pred.shape}, {target.shape}'
        else:
            assert pred.shape == target.shape == mask.shape, \
                f'shapes of pred, target and mask mismatch: {pred.shape}, {target.shape}, {mask.shape}'
            # Indexing with a non-bool tensor would gather by index instead of
            # selecting the valid positions, so force boolean semantics.
            mask = mask.bool()
            pred = pred[mask]
            target = target[mask]

        # Do the partial sums in the accumulator precision so that low precision
        # inputs cannot overflow or lose digits before they reach the states.
        acc_dtype = self.residual.dtype
        pred = pred.flatten().to(acc_dtype)
        target = target.flatten().to(acc_dtype)

        diff = target - pred
        self.sum_squared_error += torch.sum(target * target)
        self.sum_error += torch.sum(target)
        self.residual += torch.sum(diff * diff)
        # After masking, the number of remaining elements is the valid count.
        self.total += target.numel()

    def compute(self) -> Tensor:
        """
        Return the R^2 score over everything accumulated so far.

        :return: scalar tensor; see the class docstring for the degenerate cases
        """
        rss = self.residual
        # Clamp only to keep the division defined when nothing was accumulated;
        # that case is mapped to nan below.
        count = self.total.clamp(min=1)
        tss = self.sum_squared_error - self.sum_error ** 2 / count

        # TSS is mathematically >= 0; rounding may push it to or below zero.
        # ``<=`` (rather than ``not >``) lets a nan TSS propagate to the result.
        degenerate = tss <= 0
        safe_tss = torch.where(degenerate, torch.ones_like(tss), tss)
        score = torch.where(
            degenerate,
            (rss == 0).to(tss.dtype),
            1 - rss / safe_tss,
        )
        return torch.where(self.total > 0, score, torch.full_like(score, float('nan')))

training/variance_task.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99
from basics.base_dataset import BaseDataset
1010
from basics.base_task import BaseTask
1111
from modules.losses import DurationLoss, DiffusionLoss, RectifiedFlowLoss
12-
from modules.metrics.curve import RawCurveAccuracy
13-
from modules.metrics.duration import RhythmCorrectness, PhonemeDurationAccuracy
12+
from modules.metrics import (
13+
RawCurveAccuracy, RawCurveR2Score, RhythmCorrectness, PhonemeDurationAccuracy
14+
)
1415
from modules.toplevel import DiffSingerVariance
1516
from utils.hparams import hparams
1617
from utils.plot import dur_to_figure, pitch_note_to_figure, curve_to_figure
@@ -144,6 +145,7 @@ def build_losses_and_metrics(self):
144145
raise ValueError(f'Unknown diffusion type: {self.diffusion_type}')
145146
self.register_validation_loss('pitch_loss')
146147
self.register_validation_metric('pitch_acc', RawCurveAccuracy(tolerance=0.5))
148+
self.register_validation_metric('pitch_r2', RawCurveR2Score())
147149
if self.predict_variances:
148150
if self.diffusion_type == 'ddpm':
149151
self.var_loss = DiffusionLoss(loss_type=hparams['main_loss_type'])
@@ -154,6 +156,8 @@ def build_losses_and_metrics(self):
154156
else:
155157
raise ValueError(f'Unknown diffusion type: {self.diffusion_type}')
156158
self.register_validation_loss('var_loss')
159+
for name in self.variance_prediction_list:
160+
self.register_validation_metric(f'{name}_r2', RawCurveR2Score())
157161

158162
def run_model(self, sample, infer=False):
159163
spk_ids = sample['spk_ids'] if self.use_spk_id else None # [B,]
@@ -289,6 +293,8 @@ def sample_get(key, idx, abs_idx):
289293
variance_len = self.valid_dataset.metadata[name][data_idx]
290294
gt_variances = sample[name][i][:variance_len].unsqueeze(0)
291295
pred_variances = variances_preds[name][i][:variance_len].unsqueeze(0)
296+
mask = (sample_get('mel2ph', i, data_idx) > 0) & ~sample_get('uv', i, data_idx)
297+
self.valid_metrics[f'{name}_r2'].update(pred=pred_variances, target=gt_variances, mask=mask)
292298
self.plot_curve(
293299
data_idx,
294300
gt_curve=gt_variances,

0 commit comments

Comments
 (0)