Skip to content

Commit 558ee9d

Browse files
committed
improve patch safety when failing to patch or restore numpy functions
1 parent 0ef0e87 commit 558ee9d

1 file changed

Lines changed: 33 additions & 25 deletions

File tree

mkl_umath/src/_patch_numpy.pyx

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -114,42 +114,47 @@ cdef class _patch_impl:
114114
free(self.functions[i].signature)
115115
free(self.functions)
116116

117-
def do_patch(self):
117+
cdef int _replace_loop(
118+
self,
119+
object func,
120+
cnp.PyUFuncGenericFunction function,
121+
) except -1:
118122
cdef int res
119123
cdef cnp.PyUFuncGenericFunction temp
120-
cdef cnp.PyUFuncGenericFunction function
121124
cdef int* signature
122125

126+
np_umath = getattr(np, func[0])
127+
index = self.functions_dict[func]
128+
signature = self.functions[index].signature
129+
res = cnp.PyUFunc_ReplaceLoopBySignature(
130+
<cnp.ufunc>np_umath, function, signature, &temp
131+
)
132+
return res
133+
134+
def do_patch(self):
135+
cdef int index
136+
123137
for func in self.functions_dict:
124-
np_umath = getattr(np, func[0])
125138
index = self.functions_dict[func]
126-
function = self.functions[index].patch_function
127-
signature = self.functions[index].signature
128-
res = cnp.PyUFunc_ReplaceLoopBySignature(
129-
<cnp.ufunc>np_umath, function, signature, &temp
130-
)
131-
if res != 0:
139+
if self._replace_loop(
140+
func, self.functions[index].patch_function
141+
) != 0:
132142
raise RuntimeError(
133-
f"Failed to patch {func[0]} with signature {func[1]}"
143+
f"Failed to patch {func[0]} with signature {func[1]}. "
144+
"NumPy may be partially restored or in an invalid state."
134145
)
135146

136147
def do_unpatch(self):
137-
cdef int res
138-
cdef cnp.PyUFuncGenericFunction temp
139-
cdef cnp.PyUFuncGenericFunction function
140-
cdef int* signature
148+
cdef int index
141149

142150
for func in self.functions_dict:
143-
np_umath = getattr(np, func[0])
144151
index = self.functions_dict[func]
145-
function = self.functions[index].original_function
146-
signature = self.functions[index].signature
147-
res = cnp.PyUFunc_ReplaceLoopBySignature(
148-
<cnp.ufunc>np_umath, function, signature, &temp
149-
)
150-
if res != 0:
152+
if self._replace_loop(
153+
func, self.functions[index].original_function
154+
) != 0:
151155
raise RuntimeError(
152-
f"Failed to restore {func[0]} with signature {func[1]}"
156+
f"Failed to restore {func[0]} with signature {func[1]}. "
157+
"NumPy may be partially restored or in an invalid state."
153158
)
154159

155160

@@ -190,13 +195,16 @@ class _GlobalPatch:
190195
"patch_numpy_umath in this thread."
191196
)
192197
return
193-
self._tls.local_count -= 1
194-
self._patch_count -= 1
195-
if self._patch_count == 0:
198+
199+
next_patch_count = self._patch_count - 1
200+
if next_patch_count == 0:
196201
if verbose:
197202
print("Now restoring original NumPy loops.")
198203
self._patcher.do_unpatch()
199204

205+
self._tls.local_count -= 1
206+
self._patch_count = next_patch_count
207+
200208
def is_patched(self):
201209
with self._lock:
202210
return self._patch_count > 0

0 commit comments

Comments
 (0)