Skip to content

Commit 62bfa5a

Browse files
authored
fix(freeu): run FFT in float32 for float16 inputs to avoid ComplexHalf (#13503)
* fix(freeu): run FFT in float32 for float16 inputs to avoid ComplexHalf `fourier_filter` already upcasts `bfloat16` inputs to `float32` before calling `torch.fft.fftn`, because PyTorch's FFT does not support bf16. The same is true for `float16`: depending on the PyTorch version, `fftn` either - produces the experimental `torch.complex32` (ComplexHalf) dtype and emits a `UserWarning: ComplexHalf support is experimental`, or - raises `RuntimeError: Unsupported dtype Half` outright. Both paths were reachable from FreeU with half-precision models (e.g. `sd-turbo` + `fp16` + `enable_freeu`) as reported in #12504. Extend the existing upcast branch to cover `float16` too. The function already downcasts the result back to `x_in.dtype` at the end, so the externally observable dtype is unchanged. Closes #12504. * Address review: generalize upcast to non-float32 + fix ruff F821 - Apply @sayakpaul's suggestion: use `elif x.dtype != torch.float32:` so any non-float32 dtype (bf16, fp16, and future half-precision dtypes) is upcast to float32 before the FFT. - Drop the `"torch.Tensor"` return annotation on the test helper that triggered ruff F821 in CI (torch is imported inside the method body, not at module scope).
1 parent c8c8401 commit 62bfa5a

2 files changed

Lines changed: 47 additions & 2 deletions

File tree

src/diffusers/utils/torch_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,10 @@ def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.T
223223
# Non-power of 2 images must be float32
224224
if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
225225
x = x.to(dtype=torch.float32)
226-
# fftn does not support bfloat16
227-
elif x.dtype == torch.bfloat16:
226+
# fftn does not support bfloat16, and produces the experimental ComplexHalf
227+
# dtype (torch.complex32) when given float16, which is numerically unstable
228+
# and triggers a UserWarning. Upcast any non-float32 dtype to float32.
229+
elif x.dtype != torch.float32:
228230
x = x.to(dtype=torch.float32)
229231

230232
# FFT

tests/others/test_utils.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,49 @@ def test_deprecate_testing_utils_module(self):
204204
), f"Expected deprecation message substring not found, got: {messages}"
205205

206206

207+
class FourierFilterTester(unittest.TestCase):
208+
"""Tests for :func:`diffusers.utils.torch_utils.fourier_filter` (FreeU helper)."""
209+
210+
def _run_without_complexhalf_warning(self, dtype):
211+
import torch
212+
213+
from diffusers.utils.torch_utils import fourier_filter
214+
215+
x = torch.randn(1, 4, 32, 32, dtype=dtype)
216+
with warnings.catch_warnings(record=True) as caught:
217+
warnings.simplefilter("always")
218+
out = fourier_filter(x, threshold=1, scale=0.5)
219+
220+
messages = [str(w.message) for w in caught]
221+
assert not any("ComplexHalf" in m for m in messages), (
222+
f"Unexpected ComplexHalf warning emitted by fourier_filter: {messages}"
223+
)
224+
return out
225+
226+
def test_fourier_filter_float16_no_complexhalf_warning(self):
227+
import torch
228+
229+
out = self._run_without_complexhalf_warning(torch.float16)
230+
assert out.dtype == torch.float16
231+
232+
def test_fourier_filter_bfloat16_no_complexhalf_warning(self):
233+
import torch
234+
235+
out = self._run_without_complexhalf_warning(torch.bfloat16)
236+
assert out.dtype == torch.bfloat16
237+
238+
def test_fourier_filter_preserves_dtype_and_shape(self):
239+
import torch
240+
241+
from diffusers.utils.torch_utils import fourier_filter
242+
243+
for dtype in (torch.float32, torch.float16, torch.bfloat16):
244+
x = torch.randn(2, 3, 16, 16, dtype=dtype)
245+
out = fourier_filter(x, threshold=1, scale=0.5)
246+
assert out.dtype == dtype
247+
assert out.shape == x.shape
248+
249+
207250
# Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py
208251
class ExpectationsTester(unittest.TestCase):
209252
def test_expectations(self):

0 commit comments

Comments
 (0)