Skip to content

Commit e18fcc9

Browse files
authored
Merge pull request #224 from IntelPython/add-mkl-fft-patching
add manual NumPy patching to `mkl_fft`
2 parents 62dc6b8 + a162685 commit e18fcc9

15 files changed

Lines changed: 278 additions & 13 deletions

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
66

77
## [dev] - YYYY-MM-DD
88

9+
### Added
10+
* Added `mkl_fft` patching for NumPy, with `mkl_fft` context manager, `is_patched` query, and `patch_numpy_fft` and `restore_numpy_fft` calls to replace `numpy.fft` calls with calls from `mkl_fft.interfaces.numpy_fft` [gh-224](https://github.com/IntelPython/mkl_fft/pull/224)
11+
912
### Removed
1013
* Dropped support for Python 3.9 [gh-243](https://github.com/IntelPython/mkl_fft/pull/243)
1114

mkl_fft/__init__.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#!/usr/bin/env python
21
# Copyright (c) 2017, Intel Corporation
32
#
43
# Redistribution and use in source and binary forms, with or without
@@ -39,9 +38,15 @@
3938
rfft2,
4039
rfftn,
4140
)
41+
from ._patch_numpy import (
42+
is_patched,
43+
mkl_fft,
44+
patch_numpy_fft,
45+
restore_numpy_fft,
46+
)
4247
from ._version import __version__
4348

44-
import mkl_fft.interfaces # isort: skip
49+
from mkl_fft import interfaces # isort: skip
4550

4651
__all__ = [
4752
"fft",
@@ -57,6 +62,10 @@
5762
"rfftn",
5863
"irfftn",
5964
"interfaces",
65+
"mkl_fft",
66+
"patch_numpy_fft",
67+
"restore_numpy_fft",
68+
"is_patched",
6069
]
6170

6271
del _init_helper

mkl_fft/_fft_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#!/usr/bin/env python
21
# Copyright (c) 2025, Intel Corporation
32
#
43
# Redistribution and use in source and binary forms, with or without

mkl_fft/_mkl_fft.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#!/usr/bin/env python
21
# Copyright (c) 2025, Intel Corporation
32
#
43
# Redistribution and use in source and binary forms, with or without

mkl_fft/_patch_numpy.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
# Copyright (c) 2017, Intel Corporation
2+
#
3+
# Redistribution and use in source and binary forms, with or without
4+
# modification, are permitted provided that the following conditions are met:
5+
#
6+
# * Redistributions of source code must retain the above copyright notice,
7+
# this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of Intel Corporation nor the names of its contributors
12+
# may be used to endorse or promote products derived from this software
13+
# without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
19+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25+
26+
"""Define functions for patching NumPy with MKL-based NumPy interface."""
27+
28+
from contextlib import ContextDecorator
29+
from threading import Lock, local
30+
31+
import numpy as np
32+
33+
import mkl_fft.interfaces.numpy_fft as _nfft
34+
35+
36+
class _GlobalPatch:
37+
def __init__(self):
38+
self._lock = Lock()
39+
self._patch_count = 0
40+
self._restore_dict = {}
41+
# make _patched_functions a tuple (immutable)
42+
self._patched_functions = tuple(_nfft.__all__)
43+
self._tls = local()
44+
45+
def _register_func(self, name, func):
46+
if name not in self._patched_functions:
47+
raise ValueError(f"{name} not an mkl_fft function.")
48+
if name not in self._restore_dict:
49+
self._restore_dict[name] = getattr(np.fft, name)
50+
setattr(np.fft, name, func)
51+
52+
def _restore_func(self, name, verbose=False):
53+
if name not in self._patched_functions:
54+
raise ValueError(f"{name} not an mkl_fft function.")
55+
try:
56+
val = self._restore_dict[name]
57+
except KeyError:
58+
if verbose:
59+
print(f"failed to restore {name}")
60+
return
61+
else:
62+
if verbose:
63+
print(f"found and restoring {name}...")
64+
setattr(np.fft, name, val)
65+
66+
def do_patch(self, verbose=False):
67+
with self._lock:
68+
local_count = getattr(self._tls, "local_count", 0)
69+
if self._patch_count == 0:
70+
if verbose:
71+
print(
72+
"Now patching NumPy FFT submodule with mkl_fft NumPy "
73+
"interface."
74+
)
75+
print(
76+
"Please direct bug reports to "
77+
"https://github.com/IntelPython/mkl_fft"
78+
)
79+
for f in self._patched_functions:
80+
self._register_func(f, getattr(_nfft, f))
81+
self._patch_count += 1
82+
self._tls.local_count = local_count + 1
83+
84+
def do_restore(self, verbose=False):
85+
with self._lock:
86+
local_count = getattr(self._tls, "local_count", 0)
87+
if local_count <= 0:
88+
if verbose:
89+
print(
90+
"Warning: restore_numpy_fft called more times than "
91+
"patch_numpy_fft in this thread."
92+
)
93+
return
94+
self._tls.local_count -= 1
95+
self._patch_count -= 1
96+
if self._patch_count == 0:
97+
if verbose:
98+
print("Now restoring original NumPy FFT submodule.")
99+
for name in tuple(self._restore_dict):
100+
self._restore_func(name, verbose=verbose)
101+
self._restore_dict.clear()
102+
103+
def is_patched(self):
104+
with self._lock:
105+
return self._patch_count > 0
106+
107+
108+
_patch = _GlobalPatch()
109+
110+
111+
def patch_numpy_fft(verbose=False):
112+
"""
113+
Patch NumPy's fft submodule with mkl_fft's numpy_interface.
114+
115+
Parameters
116+
----------
117+
verbose : bool, optional
118+
print message when starting the patching process.
119+
120+
Notes
121+
-----
122+
This function uses reference-counted semantics. Each call increments a
123+
global patch counter. Restoration requires a matching number of calls
124+
between `patch_numpy_fft` and `restore_numpy_fft`.
125+
126+
In multi-threaded programs, prefer the `mkl_fft` context manager.
127+
128+
"""
129+
_patch.do_patch(verbose=verbose)
130+
131+
132+
def restore_numpy_fft(verbose=False):
133+
"""
134+
Restore NumPy's fft submodule to its original implementations.
135+
136+
Parameters
137+
----------
138+
verbose : bool, optional
139+
print message when starting restoration process.
140+
141+
Notes
142+
-----
143+
This function uses reference-counted semantics. Each call decrements a
144+
global patch counter. Restoration requires a matching number of calls
145+
between `patch_numpy_fft` and `restore_numpy_fft`.
146+
147+
In multi-threaded programs, prefer the `mkl_fft` context manager.
148+
149+
"""
150+
_patch.do_restore(verbose=verbose)
151+
152+
153+
def is_patched():
154+
"""Return True if NumPy's fft submodule is currently patched by mkl_fft."""
155+
return _patch.is_patched()
156+
157+
158+
class mkl_fft(ContextDecorator):
159+
"""
160+
Context manager and decorator to temporarily patch NumPy fft submodule
161+
with MKL-based implementations.
162+
163+
Examples
164+
--------
165+
>>> import mkl_fft
166+
>>> mkl_fft.is_patched()
167+
# False
168+
169+
>>> with mkl_fft.mkl_fft(): # Enable mkl_fft in Numpy
170+
>>> print(mkl_fft.is_patched())
171+
# True
172+
173+
>>> mkl_fft.is_patched()
174+
# False
175+
176+
"""
177+
178+
def __enter__(self):
179+
patch_numpy_fft()
180+
return self
181+
182+
def __exit__(self, *exc):
183+
restore_numpy_fft()
184+
return False

mkl_fft/interfaces/_float_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#!/usr/bin/env python
21
# Copyright (c) 2017, Intel Corporation
32
#
43
# Redistribution and use in source and binary forms, with or without

mkl_fft/interfaces/_numpy_fft.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#!/usr/bin/env python
21
# Copyright (c) 2017, Intel Corporation
32
#
43
# Redistribution and use in source and binary forms, with or without

mkl_fft/interfaces/_numpy_helper.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#!/usr/bin/env python
21
# Copyright (c) 2017, Intel Corporation
32
#
43
# Redistribution and use in source and binary forms, with or without

mkl_fft/interfaces/_scipy_fft.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#!/usr/bin/env python
21
# Copyright (c) 2017, Intel Corporation
32
#
43
# Redistribution and use in source and binary forms, with or without

mkl_fft/interfaces/numpy_fft.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#!/usr/bin/env python
21
# Copyright (c) 2017, Intel Corporation
32
#
43
# Redistribution and use in source and binary forms, with or without

0 commit comments

Comments
 (0)