Skip to content

Commit e182032

Browse files
committed
misc: Memoized methods rewrite
1 parent 9d21e0d commit e182032

1 file changed

Lines changed: 186 additions & 64 deletions

File tree

devito/tools/memoization.py

Lines changed: 186 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,96 +1,218 @@
1-
from collections.abc import Callable, Hashable
2-
from functools import lru_cache, partial
3-
from itertools import tee
4-
from typing import TypeVar
1+
from collections.abc import Hashable, Iterator
2+
from functools import lru_cache, partial, update_wrapper
3+
from threading import RLock, local
4+
from typing import Callable, Concatenate, Generic, ParamSpec, TypeVar
55

6-
__all__ = ['memoized_meth', 'memoized_generator', 'CacheInstances']
7-
8-
9-
class memoized_meth:
10-
"""
11-
Decorator. Cache the return value of a class method.
126

13-
Unlike ``functools.cache``, the return value of a given method invocation
14-
will be cached on the instance whose method was invoked. All arguments
15-
passed to a method decorated with memoize must be hashable.
7+
__all__ = ['memoized_meth', 'memoized_generator', 'CacheInstances']
168

17-
If a memoized method is invoked directly on its class the result will not
18-
be cached. Instead the method will be invoked like a static method: ::
199

20-
class Obj:
21-
@memoize
22-
def add_to(self, arg):
23-
return self + arg
24-
Obj.add_to(1) # not enough arguments
25-
Obj.add_to(1, 2) # returns 3, result is not cached
10+
# Type variables for memoized method decorators
11+
InstanceType = TypeVar('InstanceType', contravariant=True)
12+
ParamsType = ParamSpec('ParamsType')
13+
ReturnType = TypeVar('ReturnType', covariant=True)
2614

27-
Adapted from: ::
2815

29-
code.activestate.com/recipes/577452-a-memoize-decorator-for-instance-methods/
16+
class memoized_meth(Generic[InstanceType, ParamsType, ReturnType]):
17+
"""
18+
Decorator for a cached instance method. There is one cache per thread stored
19+
on the object instance itself.
3020
"""
3121

32-
def __init__(self, func):
33-
self.func = func
22+
def __init__(self, meth: Callable[Concatenate[InstanceType, ParamsType],
23+
ReturnType]) -> None:
24+
self._meth = meth
25+
self._lock = RLock() # Lock to safely initialize the thread-local object
26+
update_wrapper(self, self._meth)
3427

35-
def __get__(self, obj, objtype=None):
36-
if obj is None:
37-
return self.func
28+
def __get__(self, obj: InstanceType, cls: type[InstanceType] | None = None) \
29+
-> Callable[ParamsType, ReturnType]:
30+
"""
31+
Binds the memoized method to an instance.
32+
"""
3833
return partial(self, obj)
3934

40-
def __call__(self, *args, **kw):
41-
if not isinstance(args, Hashable):
42-
# Uncacheable, a list, for instance.
43-
# Better to not cache than blow up.
44-
return self.func(*args)
45-
obj = args[0]
35+
def _get_cache(self, obj: InstanceType) -> dict[Hashable, ReturnType]:
36+
"""
37+
Retrieves the thread-local cache for the given object instance, initializing
38+
it if necessary.
39+
"""
40+
# Try-catch is theoretically faster on the happy path
4641
try:
47-
cache = obj.__cache_meth
42+
# Attempt to access the cache directly
43+
return obj._memoized_meth__local.cache
44+
45+
# If the cache doesn't exist, initialize it
4846
except AttributeError:
49-
cache = obj.__cache_meth = {}
50-
key = (self.func, args[1:], frozenset(kw.items()))
47+
with self._lock:
48+
# Check again in case another thread initialized outside the lock
49+
if not hasattr(obj, '_memoized_cache'):
50+
# Initialize the cache if it doesn't exist
51+
obj._memoized_meth__local = local()
52+
obj._memoized_meth__local.cache = {}
53+
54+
# Return the cache
55+
return obj._memoized_meth__local.cache
56+
57+
def __call__(self, obj: InstanceType,
58+
*args: ParamsType.args, **kwargs: ParamsType.kwargs) -> ReturnType:
59+
"""
60+
Invokes the memoized method, caching the result if it hasn't been evaluated yet.
61+
"""
62+
# If arguments are not hashable, just evaluate the method directly
63+
if not isinstance(args, Hashable):
64+
return self._meth(obj, *args, **kwargs)
65+
66+
# Get the local cache for the object instance
67+
cache = self._get_cache(obj)
68+
key = (self._meth, args, frozenset(kwargs.items()))
5169
try:
70+
# Try to retrieve the cached value
5271
res = cache[key]
5372
except KeyError:
54-
res = cache[key] = self.func(*args, **kw)
73+
# If not cached, compute the value
74+
res = cache[key] = self._meth(obj, *args, **kwargs)
75+
5576
return res
5677

5778

58-
class memoized_generator:
79+
# Describes the type of element yielded by a cached iterator
80+
YieldType = TypeVar('YieldType', covariant=True)
81+
5982

83+
class SafeTee(Iterator[YieldType]):
6084
"""
61-
Decorator. Cache the return value of an instance generator method.
85+
A thread-safe version of `itertools.tee` that allows multiple iterators to safely
86+
share the same buffer.
87+
88+
In theory, this comes at a cost to performance of iterating elements that haven't
89+
yet been generated, as `itertools.tee` is implemented in C (i.e. is fast) but we
90+
need to buffer (and lock) in Python instead.
91+
92+
However, the lock is not needed for elements that have already been buffered,
93+
allowing for concurrent iteration after the generator is initially consumed.
6294
"""
95+
def __init__(self, source_iter: Iterator[YieldType],
96+
buffer: list[YieldType] = None, lock: RLock = None) \
97+
-> None:
98+
# If no buffer/lock are provided, this is a parent iterator
99+
self._source_iter = source_iter
100+
self._buffer = buffer if buffer is not None else []
101+
self._lock = lock if lock is not None else RLock()
102+
self._next = 0
103+
104+
def __iter__(self) -> Iterator[YieldType]:
105+
return self
106+
107+
def __next__(self) -> YieldType:
108+
"""
109+
Safely retrieves the buffer if available, or generates the next element
110+
from the source iterator if not.
111+
"""
112+
# Retry concurrent element access until we can return a value
113+
while True:
114+
if self._next < len(self._buffer):
115+
# If we have another buffered element, return it
116+
result = self._buffer[self._next]
117+
self._next += 1
118+
119+
return result
120+
121+
# Otherwise, we may need to generate a new element
122+
with self._lock:
123+
if self._next < len(self._buffer):
124+
# Another thread has already generated the next element; retry
125+
continue
126+
127+
# Generate the next element from the source iterator
128+
try:
129+
# Try to get the next element from the source iterator
130+
result = next(self._source_iter)
131+
self._buffer.append(result)
132+
self._next += 1
133+
return result
134+
except StopIteration:
135+
# The source iterator has been exhausted
136+
raise
137+
138+
def __copy__(self) -> 'SafeTee':
139+
return SafeTee(self._source_iter, self._buffer, self._lock)
140+
141+
def tee(self) -> Iterator[YieldType]:
142+
"""
143+
Creates a new iterator that shares the same buffer and lock.
144+
"""
145+
return self.__copy__()
63146

64-
def __init__(self, func):
65-
self.func = func
66147

67-
def __repr__(self):
68-
"""Return the function's docstring."""
69-
return self.func.__doc__
148+
class memoized_generator(Generic[InstanceType, ParamsType, YieldType]):
    """
    Decorator for a cached instance generator method. The initial call to the
    generator will block and return a thread-safe version of `itertools.tee`
    that allows for concurrent iteration.

    Calls with unhashable arguments are not cached; the generator method is
    simply re-invoked. Class-level access returns the raw, uncached function.
    """

    def __init__(self, meth: Callable[Concatenate[InstanceType, ParamsType],
                                      Iterator[YieldType]]) -> None:
        self._meth = meth
        self._lock = RLock()  # Lock for initial generator calls
        update_wrapper(self, self._meth)

    def __get__(self, obj: InstanceType, cls: type[InstanceType] | None = None) \
            -> Callable[ParamsType, Iterator[YieldType]]:
        """
        Binds the memoized method to an instance. Class-level access returns
        the raw, uncached function.
        """
        if obj is None:
            return self._meth
        return partial(self, obj)

    def _get_cache(self, obj: InstanceType) -> 'dict[Hashable, SafeTee[YieldType]]':
        """
        Retrieves the generator cache for the given object instance,
        initializing it if necessary.
        """
        # Try-except is cheaper than hasattr checks when the cache is warm
        try:
            return obj._memoized_generator__cache
        except AttributeError:
            with self._lock:
                # Double check, on the attribute we actually set, in case
                # another thread initialized the cache outside the lock;
                # this must never clobber an existing cache
                if not hasattr(obj, '_memoized_generator__cache'):
                    obj._memoized_generator__cache = {}

            return obj._memoized_generator__cache

    def __call__(self, obj: InstanceType,
                 *args: ParamsType.args, **kwargs: ParamsType.kwargs) \
            -> Iterator[YieldType]:
        """
        Invokes the memoized generator, caching a SafeTee if it hasn't been
        created yet, and returns a fresh tee over it.
        """
        # Build the cache key eagerly; unhashable arguments raise TypeError
        # here, in which case it is better to re-invoke than to blow up.
        # Note that `args` itself is always a tuple, so the key must actually
        # be hashed to detect unhashable contents.
        try:
            key = (self._meth, args, frozenset(kwargs.items()))
            hash(key)
        except TypeError:
            return self._meth(obj, *args, **kwargs)

        cache = self._get_cache(obj)
        try:
            res = cache[key]
        except KeyError:
            # Create the SafeTee under the lock so two concurrent first calls
            # for the same key cannot invoke the generator method twice and
            # clobber each other's entry
            with self._lock:
                try:
                    res = cache[key]
                except KeyError:
                    source_iter = self._meth(obj, *args, **kwargs)
                    res = cache[key] = SafeTee(source_iter)

        return res.tee()

92213
# Describes the type of a subclass of CacheInstances
93-
InstanceType = TypeVar('InstanceType', bound='CacheInstances', covariant=True)
214+
CachedInstanceType = TypeVar('CachedInstanceType',
215+
bound='CacheInstances', covariant=True)
94216

95217

96218
class CacheInstancesMeta(type):
@@ -100,14 +222,14 @@ class CacheInstancesMeta(type):
100222

101223
_cached_types: set[type['CacheInstances']] = set()
102224

103-
def __init__(cls: type[InstanceType], *args) -> None: # type: ignore
225+
def __init__(cls: type[CachedInstanceType], *args) -> None: # type: ignore
104226
super().__init__(*args)
105227

106228
# Register the cached type
107229
CacheInstancesMeta._cached_types.add(cls)
108230

109-
def __call__(cls: type[InstanceType], # type: ignore
110-
*args, **kwargs) -> InstanceType:
231+
def __call__(cls: type[CachedInstanceType], # type: ignore
232+
*args, **kwargs) -> CachedInstanceType:
111233
if cls._instance_cache is None:
112234
maxsize = cls._instance_cache_size
113235
cls._instance_cache = lru_cache(maxsize=maxsize)(super().__call__)

0 commit comments

Comments
 (0)