2727"""Define functions for patching NumPy with MKL-based NumPy interface."""
2828
2929from contextlib import ContextDecorator
30- from threading import Lock
30+ from threading import Lock , local
3131
3232import numpy as np
3333
@@ -41,6 +41,7 @@ def __init__(self):
4141 self ._restore_dict = {}
4242 # make _patched_functions a tuple (immutable)
4343 self ._patched_functions = tuple (_nfft .__all__ )
44+ self ._tls = local ()
4445
4546 def _register_func (self , name , func ):
4647 if name not in self ._patched_functions :
@@ -65,20 +66,34 @@ def _restore_func(self, name, verbose=False):
6566
6667 def do_patch (self , verbose = False ):
6768 with self ._lock :
69+ local_count = getattr (self ._tls , "local_count" , 0 )
6870 if self ._patch_count == 0 :
6971 if verbose :
70- print ("Now patching NumPy FFT submodule with mkl_fft NumPy interface." )
7172 print (
72- "Please direct bug reports to https://github.com/IntelPython/mkl_fft"
73+ "Now patching NumPy FFT submodule with mkl_fft NumPy "
74+ "interface."
75+ )
76+ print (
77+ "Please direct bug reports to "
78+ "https://github.com/IntelPython/mkl_fft"
7379 )
7480 for f in self ._patched_functions :
7581 self ._register_func (f , getattr (_nfft , f ))
7682 self ._patch_count += 1
83+ self ._tls .local_count = local_count + 1
7784
7885 def do_restore (self , verbose = False ):
7986 with self ._lock :
80- if self ._patch_count > 0 :
81- self ._patch_count -= 1
87+ local_count = getattr (self ._tls , "local_count" , 0 )
88+ if local_count <= 0 :
89+ if verbose :
90+ print (
91+ "Warning: restore_numpy_fft called more times than "
92+ "patch_numpy_fft in this thread."
93+ )
94+ return
95+ self ._tls .local_count -= 1
96+ self ._patch_count -= 1
8297 if self ._patch_count == 0 :
8398 if verbose :
8499 print ("Now restoring original NumPy FFT submodule." )
@@ -103,13 +118,22 @@ def __exit__(self, *exc):
103118
104119
105120def patch_numpy_fft (verbose = False ):
106- """Patch NumPy's fft submodule with mkl_fft's numpy_interface.
121+ """
122+ Patch NumPy's fft submodule with mkl_fft's numpy_interface.
107123
108124 Parameters
109125 ----------
110126 verbose : bool, optional
111127 print message when starting the patching process.
112128
129+ Notes
130+ -----
131+ This function uses reference-counted semantics. Each call increments a
132+ global patch counter. Restoration requires a matching number of calls
133+ between `patch_numpy_fft` and `restore_numpy_fft`.
134+
135+ In multi-threaded programs, prefer the `mkl_fft` context manager.
136+
113137 """
114138 _patch .do_patch (verbose = verbose )
115139
@@ -123,6 +147,14 @@ def restore_numpy_fft(verbose=False):
123147 verbose : bool, optional
124148 print message when starting restoration process.
125149
150+ Notes
151+ -----
152+ This function uses reference-counted semantics. Each call decrements a
153+ global patch counter. Restoration requires a matching number of calls
154+ between `patch_numpy_fft` and `restore_numpy_fft`.
155+
156+ In multi-threaded programs, prefer the `mkl_fft` context manager.
157+
126158 """
127159 _patch .do_restore (verbose = verbose )
128160
0 commit comments