1- from collections .abc import Callable , Hashable
2- from functools import lru_cache , partial
3- from itertools import tee
4- from typing import TypeVar
1+ from collections .abc import Hashable , Iterator
2+ from functools import lru_cache , partial , update_wrapper
3+ from threading import RLock , local
4+ from typing import Callable , Concatenate , Generic , ParamSpec , TypeVar
55
6- __all__ = ['memoized_meth' , 'memoized_generator' , 'CacheInstances' ]
7-
8-
9- class memoized_meth :
10- """
11- Decorator. Cache the return value of a class method.
126
13- Unlike ``functools.cache``, the return value of a given method invocation
14- will be cached on the instance whose method was invoked. All arguments
15- passed to a method decorated with memoize must be hashable.
7+ __all__ = ['memoized_meth' , 'memoized_generator' , 'CacheInstances' ]
168
17- If a memoized method is invoked directly on its class the result will not
18- be cached. Instead the method will be invoked like a static method: ::
199
20- class Obj:
21- @memoize
22- def add_to(self, arg):
23- return self + arg
24- Obj.add_to(1) # not enough arguments
25- Obj.add_to(1, 2) # returns 3, result is not cached
10+ # Type variables for memoized method decorators
11+ InstanceType = TypeVar ('InstanceType' , contravariant = True )
12+ ParamsType = ParamSpec ('ParamsType' )
13+ ReturnType = TypeVar ('ReturnType' , covariant = True )
2614
27- Adapted from: ::
2815
29- code.activestate.com/recipes/577452-a-memoize-decorator-for-instance-methods/
16+ class memoized_meth (Generic [InstanceType , ParamsType , ReturnType ]):
17+ """
18+ Decorator for a cached instance method. There is one cache per thread stored
19+ on the object instance itself.
3020 """
3121
32- def __init__ (self , func ):
33- self .func = func
22+ def __init__ (self , meth : Callable [Concatenate [InstanceType , ParamsType ],
23+ ReturnType ]) -> None :
24+ self ._meth = meth
25+ self ._lock = RLock () # Lock to safely initialize the thread-local object
26+ update_wrapper (self , self ._meth )
3427
35- def __get__ (self , obj , objtype = None ):
36- if obj is None :
37- return self .func
28+ def __get__ (self , obj : InstanceType , cls : type [InstanceType ] | None = None ) \
29+ -> Callable [ParamsType , ReturnType ]:
30+ """Binds the memoized method to an instance."""
31+ if obj is None :
32+ return self ._meth
3833 return partial (self , obj )
3934
40- def __call__ (self , * args , ** kw ) :
41- if not isinstance ( args , Hashable ):
42- # Uncacheable, a list, for instance.
43- # Better to not cache than blow up .
44- return self . func ( * args )
45- obj = args [ 0 ]
35+ def _get_cache (self , obj : InstanceType ) -> dict [ Hashable , ReturnType ] :
36+ """
37+ Retrieves the thread-local cache for the given object instance, initializing
38+ it if necessary .
39+ """
40+ # Try-catch is theoretically faster on the happy path
4641 try :
47- cache = obj .__cache_meth
42+ # Attempt to access the cache directly
43+ return obj ._memoized_meth__local .cache
44+
45+ # If the cache doesn't exist, initialize it
4846 except AttributeError :
49- cache = obj .__cache_meth = {}
50- key = (self .func , args [1 :], frozenset (kw .items ()))
47+ with self ._lock :
48+ # Check again in case another thread initialized outside the lock
49+ if not hasattr (obj , '_memoized_meth__local' ):
50+ obj ._memoized_meth__local = local ()
51+ if not hasattr (obj ._memoized_meth__local , 'cache' ):
52+ obj ._memoized_meth__local .cache = {}
53+
54+ # Return the cache
55+ return obj ._memoized_meth__local .cache
56+
57+ def __call__ (self , obj : InstanceType ,
58+ * args : ParamsType .args , ** kwargs : ParamsType .kwargs ) -> ReturnType :
59+ """
60+ Invokes the memoized method, caching the result if it hasn't been evaluated yet.
61+ """
62+ # If arguments are not hashable, just evaluate the method directly
63+ if not isinstance (args , Hashable ):
64+ return self ._meth (obj , * args , ** kwargs )
65+
66+ # Get the local cache for the object instance
67+ cache = self ._get_cache (obj )
68+ key = (self ._meth , args , frozenset (kwargs .items ()))
5169 try :
70+ # Try to retrieve the cached value
5271 res = cache [key ]
5372 except KeyError :
54- res = cache [key ] = self .func (* args , ** kw )
73+ # If not cached, compute the value
74+ res = cache [key ] = self ._meth (obj , * args , ** kwargs )
75+
5576 return res
5677
5778
58- class memoized_generator :
79+ # Describes the type of element yielded by a cached iterator
80+ YieldType = TypeVar ('YieldType' , covariant = True )
81+
5982
83+ class SafeTee (Iterator [YieldType ]):
6084 """
61- Decorator. Cache the return value of an instance generator method.
85+ A thread-safe version of `itertools.tee` that allows multiple iterators to safely
86+ share the same buffer.
87+
88+ In theory, this comes at a cost to performance of iterating elements that haven't
89+ yet been generated, as `itertools.tee` is implemented in C (i.e. is fast) but we
90+ need to buffer (and lock) in Python instead.
91+
92+ However, the lock is not needed for elements that have already been buffered,
93+ allowing for concurrent iteration after the generator is initially consumed.
6294 """
95+ def __init__ (self , source_iter : Iterator [YieldType ],
96+ buffer : list [YieldType ] | None = None , lock : RLock | None = None ) \
97+ -> None :
98+ # If no buffer/lock are provided, this is a parent iterator
99+ self ._source_iter = source_iter
100+ self ._buffer = buffer if buffer is not None else []
101+ self ._lock = lock if lock is not None else RLock ()
102+ self ._next = 0
103+
104+ def __iter__ (self ) -> Iterator [YieldType ]:
105+ return self
106+
107+ def __next__ (self ) -> YieldType :
108+ """
109+ Safely retrieves the buffer if available, or generates the next element
110+ from the source iterator if not.
111+ """
112+ # Retry concurrent element access until we can return a value
113+ while True :
114+ if self ._next < len (self ._buffer ):
115+ # If we have another buffered element, return it
116+ result = self ._buffer [self ._next ]
117+ self ._next += 1
118+
119+ return result
120+
121+ # Otherwise, we may need to generate a new element
122+ with self ._lock :
123+ if self ._next < len (self ._buffer ):
124+ # Another thread has already generated the next element; retry
125+ continue
126+
127+ # Generate the next element from the source iterator
128+ try :
129+ # Try to get the next element from the source iterator
130+ result = next (self ._source_iter )
131+ self ._buffer .append (result )
132+ self ._next += 1
133+ return result
134+ except StopIteration :
135+ # The source iterator has been exhausted
136+ raise
137+
138+ def __copy__ (self ) -> 'SafeTee' :
139+ return SafeTee (self ._source_iter , self ._buffer , self ._lock )
140+
141+ def tee (self ) -> Iterator [YieldType ]:
142+ """
143+ Creates a new iterator that shares the same buffer and lock.
144+ """
145+ return self .__copy__ ()
63146
64- def __init__ (self , func ):
65- self .func = func
66147
67- def __repr__ (self ):
68- """Return the function's docstring."""
69- return self .func .__doc__
148+ class memoized_generator (Generic [InstanceType , ParamsType , YieldType ]):
149+ """
150+ Decorator for a cached instance generator method. The initial call to the generator
151+ will block and return a thread-safe version of `itertools.tee` that allows for
152+ concurrent iteration.
153+ """
154+
155+ def __init__ (self , meth : Callable [Concatenate [InstanceType , ParamsType ],
156+ Iterator [YieldType ]]) -> None :
157+ self ._meth = meth
158+ self ._lock = RLock () # Lock for initial generator calls
159+ update_wrapper (self , self ._meth )
70160
71- def __get__ (self , obj , objtype = None ):
72- if obj is None :
73- return self .func
161+ def __get__ (self , obj : InstanceType , cls : type [InstanceType ] | None = None ) \
162+ -> Callable [ParamsType , Iterator [YieldType ]]:
163+ """Binds the memoized method to an instance."""
164+ if obj is None :
165+ return self ._meth
74166 return partial (self , obj )
75167
76- def __call__ (self , * args , ** kwargs ) :
77- if not isinstance ( args , Hashable ):
78- # Uncacheable, a list, for instance.
79- # Better to not cache than blow up .
80- return self . func ( * args )
81- obj = args [ 0 ]
168+ def _get_cache (self , obj : InstanceType ) -> dict [ Hashable , SafeTee [ YieldType ]] :
169+ """
170+ Retrieves the generator cache for the given object instance, initializing
171+ it if necessary .
172+ """
173+ # Try-catch is theoretically faster on the happy path
82174 try :
83- cache = obj .__cache_gen
175+ # Attempt to access the cache directly
176+ return obj ._memoized_generator__cache
177+
178+ # If the cache doesn't exist, initialize it
84179 except AttributeError :
85- cache = obj .__cache_gen = {}
86- key = (self .func , args [1 :], frozenset (kwargs .items ()))
87- it = cache [key ] if key in cache else self .func (* args , ** kwargs )
88- cache [key ], result = tee (it )
89- return result
180+ with self ._lock :
181+ # Check again in case another thread initialized outside the lock
182+ if not hasattr (obj , '_memoized_generator__cache' ):
183+ # Initialize the cache if it doesn't exist
184+ obj ._memoized_generator__cache = {}
185+
186+ # Return the cache
187+ return obj ._memoized_generator__cache
188+
189+ def __call__ (self , obj : InstanceType ,
190+ * args : ParamsType .args , ** kwargs : ParamsType .kwargs ) \
191+ -> Iterator [YieldType ]:
192+ """
193+ Invokes the memoized generator, caching a SafeTee if it hasn't been created yet.
194+ """
195+ # If arguments are not hashable, just evaluate the method directly
196+ if not isinstance (args , Hashable ):
197+ return self ._meth (obj , * args , ** kwargs )
198+
199+ # Get the local cache for the object instance
200+ cache = self ._get_cache (obj )
201+ key = (self ._meth , args , frozenset (kwargs .items ()))
202+ try :
203+ # Try to retrieve the cached value
204+ res = cache [key ]
205+ except KeyError :
206+ # If not cached, compute the value
207+ source_iter = self ._meth (obj , * args , ** kwargs )
208+ res = cache [key ] = SafeTee (source_iter )
209+
210+ return res .tee ()
90211
91212
92213# Describes the type of a subclass of CacheInstances
93- InstanceType = TypeVar ('InstanceType' , bound = 'CacheInstances' , covariant = True )
214+ CachedInstanceType = TypeVar ('CachedInstanceType' ,
215+ bound = 'CacheInstances' , covariant = True )
94216
95217
96218class CacheInstancesMeta (type ):
@@ -100,14 +222,14 @@ class CacheInstancesMeta(type):
100222
101223 _cached_types : set [type ['CacheInstances' ]] = set ()
102224
103- def __init__ (cls : type [InstanceType ], * args ) -> None : # type: ignore
225+ def __init__ (cls : type [CachedInstanceType ], * args ) -> None : # type: ignore
104226 super ().__init__ (* args )
105227
106228 # Register the cached type
107229 CacheInstancesMeta ._cached_types .add (cls )
108230
109- def __call__ (cls : type [InstanceType ], # type: ignore
110- * args , ** kwargs ) -> InstanceType :
231+ def __call__ (cls : type [CachedInstanceType ], # type: ignore
232+ * args , ** kwargs ) -> CachedInstanceType :
111233 if cls ._instance_cache is None :
112234 maxsize = cls ._instance_cache_size
113235 cls ._instance_cache = lru_cache (maxsize = maxsize )(super ().__call__ )
0 commit comments