@@ -114,42 +114,47 @@ cdef class _patch_impl:
114114 free(self .functions[i].signature)
115115 free(self .functions)
116116
117- def do_patch (self ):
117+ cdef int _replace_loop(
118+ self ,
119+ object func,
120+ cnp.PyUFuncGenericFunction function,
121+ ) except - 1 :
118122 cdef int res
119123 cdef cnp.PyUFuncGenericFunction temp
120- cdef cnp.PyUFuncGenericFunction function
121124 cdef int * signature
122125
126+ np_umath = getattr (np, func[0 ])
127+ index = self .functions_dict[func]
128+ signature = self .functions[index].signature
129+ res = cnp.PyUFunc_ReplaceLoopBySignature(
130+ < cnp.ufunc> np_umath, function, signature, & temp
131+ )
132+ return res
133+
134+ def do_patch (self ):
135+ cdef int index
136+
123137 for func in self .functions_dict:
124- np_umath = getattr (np, func[0 ])
125138 index = self .functions_dict[func]
126- function = self .functions[index].patch_function
127- signature = self .functions[index].signature
128- res = cnp.PyUFunc_ReplaceLoopBySignature(
129- < cnp.ufunc> np_umath, function, signature, & temp
130- )
131- if res != 0 :
139+ if self ._replace_loop(
140+ func, self .functions[index].patch_function
141+ ) != 0 :
132142 raise RuntimeError (
133- f" Failed to patch {func[0]} with signature {func[1]}"
143+ f" Failed to patch {func[0]} with signature {func[1]}. "
144+ " NumPy may be partially restored or in an invalid state."
134145 )
135146
136147 def do_unpatch (self ):
137- cdef int res
138- cdef cnp.PyUFuncGenericFunction temp
139- cdef cnp.PyUFuncGenericFunction function
140- cdef int * signature
148+ cdef int index
141149
142150 for func in self .functions_dict:
143- np_umath = getattr (np, func[0 ])
144151 index = self .functions_dict[func]
145- function = self .functions[index].original_function
146- signature = self .functions[index].signature
147- res = cnp.PyUFunc_ReplaceLoopBySignature(
148- < cnp.ufunc> np_umath, function, signature, & temp
149- )
150- if res != 0 :
152+ if self ._replace_loop(
153+ func, self .functions[index].original_function
154+ ) != 0 :
151155 raise RuntimeError (
152- f" Failed to restore {func[0]} with signature {func[1]}"
156+ f" Failed to restore {func[0]} with signature {func[1]}. "
157+ " NumPy may be partially restored or in an invalid state."
153158 )
154159
155160
@@ -190,13 +195,16 @@ class _GlobalPatch:
190195 " patch_numpy_umath in this thread."
191196 )
192197 return
193- self ._tls.local_count -= 1
194- self ._patch_count -= 1
195- if self ._patch_count == 0 :
198+
199+ next_patch_count = self ._patch_count - 1
200+ if next_patch_count == 0 :
196201 if verbose:
197202 print (" Now restoring original NumPy loops." )
198203 self ._patcher.do_unpatch()
199204
205+ self ._tls.local_count -= 1
206+ self ._patch_count = next_patch_count
207+
200208 def is_patched (self ):
201209 with self ._lock:
202210 return self ._patch_count > 0
0 commit comments