Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 54ed235

Browse files
authored
Rename Metric and use double in computation (#350)
Signed-off-by: Walter Hugo Lopez Pinaya <ianonimato@hotmail.com>
1 parent 49e5545 commit 54ed235

3 files changed

Lines changed: 7 additions & 7 deletions

File tree

generative/metrics/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from __future__ import annotations
1313

14-
from .fid import FID
14+
from .fid import FIDMetric
1515
from .mmd import MMD
1616
from .ms_ssim import MultiScaleSSIMMetric
1717
from .ssim import SSIMMetric

generative/metrics/fid.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from monai.metrics.metric import Metric
3636

3737

38-
class FID(Metric):
38+
class FIDMetric(Metric):
3939
"""
4040
Frechet Inception Distance (FID). The FID calculates the distance between two distributions of feature vectors.
4141
Based on: Heusel M. et al. "Gans trained by a two time-scale update rule converge to a local nash equilibrium."
@@ -56,8 +56,8 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor):
5656

5757

5858
def get_fid_score(y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
59-
y = y.float()
60-
y_pred = y_pred.float()
59+
y = y.double()
60+
y_pred = y_pred.double()
6161

6262
if y.ndimension() > 2:
6363
raise ValueError("Inputs should have (number images, number of features) shape.")

tests/test_compute_fid_metric.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,19 @@
1616
import numpy as np
1717
import torch
1818

19-
from generative.metrics import FID
19+
from generative.metrics import FIDMetric
2020

2121

2222
class TestMMDMetric(unittest.TestCase):
2323
def test_results(self):
2424
x = torch.Tensor([[1, 2], [1, 2], [1, 2]])
2525
y = torch.Tensor([[2, 2], [1, 2], [1, 2]])
26-
results = FID()(x, y)
26+
results = FIDMetric()(x, y)
2727
np.testing.assert_allclose(results.cpu().numpy(), 0.4433, atol=1e-4)
2828

2929
def test_input_dimensions(self):
3030
with self.assertRaises(ValueError):
31-
FID()(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145]))
31+
FIDMetric()(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145]))
3232

3333

3434
if __name__ == "__main__":

0 commit comments

Comments
 (0)