Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit b1c8f05

Browse files
authored
Adopt original computation of FID (#355)
* Adopt exact computation of FID

Signed-off-by: Walter Hugo Lopez Pinaya <ianonimato@hotmail.com>

* Add verification of the imag part

Signed-off-by: Walter Hugo Lopez Pinaya <ianonimato@hotmail.com>

---------

Signed-off-by: Walter Hugo Lopez Pinaya <ianonimato@hotmail.com>
1 parent 54ed235 commit b1c8f05

2 files changed

Lines changed: 28 additions & 69 deletions

File tree

generative/metrics/fid.py

Lines changed: 27 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -8,31 +8,14 @@
88
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
11-
#
12-
# =========================================================================
13-
# Adapted from https://github.com/photosynthesis-team/piq
14-
# which has the following license:
15-
# https://github.com/photosynthesis-team/piq/blob/master/LICENSE
16-
#
17-
# Copyright 2023 photosynthesis-team. All rights reserved.
18-
#
19-
# Licensed under the Apache License, Version 2.0 (the "License");
20-
# you may not use this file except in compliance with the License.
21-
# You may obtain a copy of the License at
22-
#
23-
# http://www.apache.org/licenses/LICENSE-2.0
24-
#
25-
# Unless required by applicable law or agreed to in writing, software
26-
# distributed under the License is distributed on an "AS IS" BASIS,
27-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28-
# See the License for the specific language governing permissions and
29-
# limitations under the License.
30-
# =========================================================================
11+
3112

3213
from __future__ import annotations
3314

15+
import numpy as np
3416
import torch
3517
from monai.metrics.metric import Metric
18+
from scipy import linalg
3619

3720

3821
class FIDMetric(Metric):
@@ -70,77 +53,53 @@ def get_fid_score(y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
7053
return compute_frechet_distance(mu_y_pred, sigma_y_pred, mu_y, sigma_y)
7154

7255

73-
def _cov(m: torch.Tensor, rowvar: bool = True) -> torch.Tensor:
56+
def _cov(input_data: torch.Tensor, rowvar: bool = True) -> torch.Tensor:
7457
"""
7558
Estimate a covariance matrix of the variables.
7659
7760
Args:
78-
m: A 1-D or 2-D array containing multiple variables and observations. Each row of `m` represents a variable,
61+
input_data: A 1-D or 2-D array containing multiple variables and observations. Each row of `m` represents a variable,
7962
and each column a single observation of all those variables.
8063
rowvar: If rowvar is True (default), then each row represents a variable, with observations in the columns.
8164
Otherwise, the relationship is transposed: each column represents a variable, while the rows contain
8265
observations.
8366
"""
84-
if m.dim() < 2:
85-
m = m.view(1, -1)
67+
if input_data.dim() < 2:
68+
input_data = input_data.view(1, -1)
8669

87-
if not rowvar and m.size(0) != 1:
88-
m = m.t()
70+
if not rowvar and input_data.size(0) != 1:
71+
input_data = input_data.t()
8972

90-
fact = 1.0 / (m.size(1) - 1)
91-
m = m - torch.mean(m, dim=1, keepdim=True)
92-
mt = m.t()
93-
return fact * m.matmul(mt).squeeze()
73+
factor = 1.0 / (input_data.size(1) - 1)
74+
input_data = input_data - torch.mean(input_data, dim=1, keepdim=True)
75+
return factor * input_data.matmul(input_data.t()).squeeze()
9476

9577

96-
def _sqrtm_newton_schulz(matrix: torch.Tensor, num_iters: int = 100) -> tuple[torch.Tensor, torch.Tensor]:
97-
"""
98-
Square root of matrix using Newton-Schulz Iterative method. Based on:
99-
https://github.com/msubhransu/matrix-sqrt/blob/master/matrix_sqrt.py. Bechmark shown in:
100-
https://github.com/photosynthesis-team/piq/issues/190#issuecomment-742039303
101-
102-
Args:
103-
matrix: matrix or batch of matrices
104-
num_iters: Number of iteration of the method
105-
106-
"""
107-
dim = matrix.size(0)
108-
norm_of_matrix = matrix.norm(p="fro")
109-
y_matrix = matrix.div(norm_of_matrix)
110-
i_matrix = torch.eye(dim, dim, device=matrix.device, dtype=matrix.dtype)
111-
z_matrix = torch.eye(dim, dim, device=matrix.device, dtype=matrix.dtype)
112-
113-
s_matrix = torch.empty_like(matrix)
114-
error = torch.empty(1, device=matrix.device, dtype=matrix.dtype)
115-
116-
for _ in range(num_iters):
117-
t = 0.5 * (3.0 * i_matrix - z_matrix.mm(y_matrix))
118-
y_matrix = y_matrix.mm(t)
119-
z_matrix = t.mm(z_matrix)
120-
121-
s_matrix = y_matrix * torch.sqrt(norm_of_matrix)
122-
123-
norm_of_matrix = torch.norm(matrix)
124-
error = matrix - torch.mm(s_matrix, s_matrix)
125-
error = torch.norm(error) / norm_of_matrix
126-
127-
if torch.isclose(error, torch.tensor([0.0], device=error.device, dtype=error.dtype), atol=1e-5):
128-
break
129-
130-
return s_matrix, error
78+
def _sqrtm(input_data: torch.Tensor) -> torch.Tensor:
79+
"""Compute the square root of a matrix."""
80+
scipy_res, _ = linalg.sqrtm(input_data.detach().cpu().numpy().astype(np.float_), disp=False)
81+
return torch.from_numpy(scipy_res)
13182

13283

13384
def compute_frechet_distance(
13485
mu_x: torch.Tensor, sigma_x: torch.Tensor, mu_y: torch.Tensor, sigma_y: torch.Tensor, epsilon: float = 1e-6
13586
) -> torch.Tensor:
13687
"""The Frechet distance between multivariate normal distributions."""
13788
diff = mu_x - mu_y
138-
covmean, _ = _sqrtm_newton_schulz(sigma_x.mm(sigma_y))
13989

140-
# If calculation produces singular product, epsilon is added to diagonal of cov estimates
90+
covmean = _sqrtm(sigma_x.mm(sigma_y))
91+
92+
# Product might be almost singular
14193
if not torch.isfinite(covmean).all():
94+
print(f"FID calculation produces singular product; adding {epsilon} to diagonal of covariance estimates")
14295
offset = torch.eye(sigma_x.size(0), device=mu_x.device, dtype=mu_x.dtype) * epsilon
143-
covmean, _ = _sqrtm_newton_schulz((sigma_x + offset).mm(sigma_y + offset))
96+
covmean = _sqrtm((sigma_x + offset).mm(sigma_y + offset))
97+
98+
# Numerical error might give slight imaginary component
99+
if torch.is_complex(covmean):
100+
if not torch.allclose(torch.diagonal(covmean).imag, torch.tensor(0, dtype=torch.double), atol=1e-3):
101+
raise ValueError(f"Imaginary component {torch.max(torch.abs(covmean.imag))} too high.")
102+
covmean = covmean.real
144103

145104
tr_covmean = torch.trace(covmean)
146105
return diff.dot(diff) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2 * tr_covmean

tests/test_compute_fid_metric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def test_results(self):
2424
x = torch.Tensor([[1, 2], [1, 2], [1, 2]])
2525
y = torch.Tensor([[2, 2], [1, 2], [1, 2]])
2626
results = FIDMetric()(x, y)
27-
np.testing.assert_allclose(results.cpu().numpy(), 0.4433, atol=1e-4)
27+
np.testing.assert_allclose(results.cpu().numpy(), 0.4444, atol=1e-4)
2828

2929
def test_input_dimensions(self):
3030
with self.assertRaises(ValueError):

0 commit comments

Comments (0)