Skip to content

Commit 14d3705

Browse files
committed
use thread-local storage for bookkeeping patch calls per thread
1 parent 23edc89 commit 14d3705

1 file changed

Lines changed: 38 additions & 6 deletions

File tree

mkl_fft/_patch_numpy.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
"""Define functions for patching NumPy with MKL-based NumPy interface."""
2828

2929
from contextlib import ContextDecorator
30-
from threading import Lock
30+
from threading import Lock, local
3131

3232
import numpy as np
3333

@@ -41,6 +41,7 @@ def __init__(self):
4141
self._restore_dict = {}
4242
# make _patched_functions a tuple (immutable)
4343
self._patched_functions = tuple(_nfft.__all__)
44+
self._tls = local()
4445

4546
def _register_func(self, name, func):
4647
if name not in self._patched_functions:
@@ -65,20 +66,34 @@ def _restore_func(self, name, verbose=False):
6566

6667
def do_patch(self, verbose=False):
6768
with self._lock:
69+
local_count = getattr(self._tls, "local_count", 0)
6870
if self._patch_count == 0:
6971
if verbose:
70-
print("Now patching NumPy FFT submodule with mkl_fft NumPy interface.")
7172
print(
72-
"Please direct bug reports to https://github.com/IntelPython/mkl_fft"
73+
"Now patching NumPy FFT submodule with mkl_fft NumPy "
74+
"interface."
75+
)
76+
print(
77+
"Please direct bug reports to "
78+
"https://github.com/IntelPython/mkl_fft"
7379
)
7480
for f in self._patched_functions:
7581
self._register_func(f, getattr(_nfft, f))
7682
self._patch_count += 1
83+
self._tls.local_count = local_count + 1
7784

7885
def do_restore(self, verbose=False):
7986
with self._lock:
80-
if self._patch_count > 0:
81-
self._patch_count -= 1
87+
local_count = getattr(self._tls, "local_count", 0)
88+
if local_count <= 0:
89+
if verbose:
90+
print(
91+
"Warning: restore_numpy_fft called more times than "
92+
"patch_numpy_fft in this thread."
93+
)
94+
return
95+
self._tls.local_count -= 1
96+
self._patch_count -= 1
8297
if self._patch_count == 0:
8398
if verbose:
8499
print("Now restoring original NumPy FFT submodule.")
@@ -103,13 +118,22 @@ def __exit__(self, *exc):
103118

104119

105120
def patch_numpy_fft(verbose=False):
106-
"""Patch NumPy's fft submodule with mkl_fft's numpy_interface.
121+
"""
122+
Patch NumPy's fft submodule with mkl_fft's numpy_interface.
107123
108124
Parameters
109125
----------
110126
verbose : bool, optional
111127
print message when starting the patching process.
112128
129+
Notes
130+
-----
131+
This function uses reference-counted semantics. Each call increments a
132+
global patch counter. Restoration requires a matching number of calls
133+
between `patch_numpy_fft` and `restore_numpy_fft`.
134+
135+
In multi-threaded programs, prefer the `mkl_fft` context manager.
136+
113137
"""
114138
_patch.do_patch(verbose=verbose)
115139

@@ -123,6 +147,14 @@ def restore_numpy_fft(verbose=False):
123147
verbose : bool, optional
124148
print message when starting restoration process.
125149
150+
Notes
151+
-----
152+
This function uses reference-counted semantics. Each call decrements a
153+
global patch counter. Restoration requires a matching number of calls
154+
between `patch_numpy_fft` and `restore_numpy_fft`.
155+
156+
In multi-threaded programs, prefer the `mkl_fft` context manager.
157+
126158
"""
127159
_patch.do_restore(verbose=verbose)
128160

0 commit comments

Comments
 (0)