Commit 62bfa5a
fix(freeu): run FFT in float32 for float16 inputs to avoid ComplexHalf (#13503)
* fix(freeu): run FFT in float32 for float16 inputs to avoid ComplexHalf
`fourier_filter` already upcasts `bfloat16` inputs to `float32` before
calling `torch.fft.fftn`, because PyTorch's FFT does not support bf16.
The same is true for `float16`: depending on the PyTorch version,
`fftn` either
- produces the experimental `torch.complex32` (ComplexHalf) dtype and
emits a `UserWarning: ComplexHalf support is experimental`, or
- raises `RuntimeError: Unsupported dtype Half` outright.
Both paths were reachable from FreeU with half-precision models
(e.g. `sd-turbo` + `fp16` + `enable_freeu`), as reported in #12504.
Extend the existing upcast branch to cover `float16` too. The function
already downcasts the result back to `x_in.dtype` at the end, so the
externally observable dtype is unchanged.
Closes #12504.
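
For concreteness, a minimal repro of the failure mode described above (the tensor shape is illustrative):

```python
import torch

# Any float16 input to the FFT hits the same code path.
x = torch.randn(1, 4, 64, 64, dtype=torch.float16)

# Depending on the PyTorch version, this either returns an experimental
# torch.complex32 (ComplexHalf) tensor and emits
# "UserWarning: ComplexHalf support is experimental", or raises
# "RuntimeError: Unsupported dtype Half".
x_freq = torch.fft.fftn(x, dim=(-2, -1))
```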
* Address review: generalize upcast to non-float32 + fix ruff F821
- Apply @sayakpaul's suggestion: use `elif x.dtype != torch.float32:`
so any non-float32 dtype (bf16, fp16, and future half-precision
dtypes) is upcast to float32 before the FFT (see the sketch after
this list).
- Drop the `"torch.Tensor"` return annotation on the test helper
that triggered ruff F821 in CI (torch is imported inside the
method body, not at module scope).
2 files changed: 47 additions & 2 deletions
[Diff 1 of 2 — the upcast branch in `fourier_filter`: lines 226–227 replaced by lines 226–229 (4 additions, 2 deletions).]
[Diff 2 of 2 — the new regression test: lines 207–249 added (43 additions).]
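
The added test lines are not reproduced above; purely as an illustration of the kind of regression test described (the `diffusers.utils.torch_utils` import path and all test names are assumptions, not the actual test code):

```python
import warnings

import pytest


@pytest.mark.parametrize("dtype_name", ["float16", "bfloat16"])
def test_fourier_filter_preserves_half_precision_dtype(dtype_name):
    # torch is imported inside the test body, mirroring the helper style
    # that motivated dropping the return annotation (ruff F821).
    import torch
    from diffusers.utils.torch_utils import fourier_filter

    dtype = getattr(torch, dtype_name)
    x = torch.randn(1, 4, 64, 64, dtype=dtype)

    with warnings.catch_warnings():
        # Turn "ComplexHalf support is experimental" into a hard failure.
        warnings.simplefilter("error")
        out = fourier_filter(x, threshold=1, scale=0.9)

    # The externally observable dtype must be unchanged.
    assert out.dtype == dtype
```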