@@ -111,22 +111,18 @@ def _do_rollback(self) -> None:
111111 self .get ().rollback ()
112112
113113
114- class ThreadLocalConnectionPool (_TransactionManagementMixin ):
114+ class _ThreadLocalBase (_TransactionManagementMixin ):
115115 def __init__ (
116116 self ,
117117 connection_factory : t .Callable [[], t .Any ],
118- shared_connection : bool = False ,
119118 cursor_init : t .Optional [t .Callable [[t .Any ], None ]] = None ,
120119 ):
121120 self ._connection_factory = connection_factory
122- self ._thread_connections : t .Dict [t .Hashable , t .Any ] = {}
123121 self ._thread_cursors : t .Dict [t .Hashable , t .Any ] = {}
124122 self ._thread_transactions : t .Set [t .Hashable ] = set ()
125123 self ._thread_attributes : t .Dict [t .Hashable , t .Dict [str , t .Any ]] = defaultdict (dict )
126- self ._thread_connections_lock = Lock ()
127124 self ._thread_cursors_lock = Lock ()
128125 self ._thread_transactions_lock = Lock ()
129- self ._shared_connection = shared_connection
130126 self ._cursor_init = cursor_init
131127
132128 def get_cursor (self ) -> t .Any :
@@ -138,13 +134,6 @@ def get_cursor(self) -> t.Any:
138134 self ._cursor_init (self ._thread_cursors [thread_id ])
139135 return self ._thread_cursors [thread_id ]
140136
141- def get (self ) -> t .Any :
142- thread_id = get_ident ()
143- with self ._thread_connections_lock :
144- if thread_id not in self ._thread_connections :
145- self ._thread_connections [thread_id ] = self ._connection_factory ()
146- return self ._thread_connections [thread_id ]
147-
148137 def get_attribute (self , key : str ) -> t .Optional [t .Any ]:
149138 thread_id = get_ident ()
150139 return self ._thread_attributes [thread_id ].get (key )
@@ -178,6 +167,28 @@ def close_cursor(self) -> None:
178167 _try_close (self ._thread_cursors [thread_id ], "cursor" )
179168 self ._thread_cursors .pop (thread_id )
180169
170+ def _discard_transaction (self , thread_id : t .Hashable ) -> None :
171+ with self ._thread_transactions_lock :
172+ self ._thread_transactions .discard (thread_id )
173+
174+
175+ class ThreadLocalConnectionPool (_ThreadLocalBase ):
176+ def __init__ (
177+ self ,
178+ connection_factory : t .Callable [[], t .Any ],
179+ cursor_init : t .Optional [t .Callable [[t .Any ], None ]] = None ,
180+ ):
181+ super ().__init__ (connection_factory , cursor_init )
182+ self ._thread_connections : t .Dict [t .Hashable , t .Any ] = {}
183+ self ._thread_connections_lock = Lock ()
184+
185+ def get (self ) -> t .Any :
186+ thread_id = get_ident ()
187+ with self ._thread_connections_lock :
188+ if thread_id not in self ._thread_connections :
189+ self ._thread_connections [thread_id ] = self ._connection_factory ()
190+ return self ._thread_connections [thread_id ]
191+
181192 def close (self ) -> None :
182193 thread_id = get_ident ()
183194 with self ._thread_cursors_lock , self ._thread_connections_lock :
@@ -189,23 +200,55 @@ def close(self) -> None:
189200 self ._thread_attributes .pop (thread_id , None )
190201
191202 def close_all (self , exclude_calling_thread : bool = False ) -> None :
192- if exclude_calling_thread and self ._shared_connection :
193- return
194-
195203 calling_thread_id = get_ident ()
196204 with self ._thread_cursors_lock , self ._thread_connections_lock :
197205 for thread_id , connection in self ._thread_connections .copy ().items ():
198206 if not exclude_calling_thread or thread_id != calling_thread_id :
199- # NOTE: the access to the connection instance itself is not thread-safe here.
200207 _try_close (connection , "connection" )
201208 self ._thread_connections .pop (thread_id )
202209 self ._thread_cursors .pop (thread_id , None )
203210 self ._discard_transaction (thread_id )
204211 self ._thread_attributes .pop (thread_id , None )
205212
206- def _discard_transaction (self , thread_id : t .Hashable ) -> None :
207- with self ._thread_transactions_lock :
208- self ._thread_transactions .discard (thread_id )
213+
214+ class ThreadLocalSharedConnectionPool (_ThreadLocalBase ):
215+ def __init__ (
216+ self ,
217+ connection_factory : t .Callable [[], t .Any ],
218+ cursor_init : t .Optional [t .Callable [[t .Any ], None ]] = None ,
219+ ):
220+ super ().__init__ (connection_factory , cursor_init )
221+ self ._connection : t .Optional [t .Any ] = None
222+ self ._connection_lock = Lock ()
223+
224+ def get (self ) -> t .Any :
225+ with self ._connection_lock :
226+ if self ._connection is None :
227+ self ._connection = self ._connection_factory ()
228+ return self ._connection
229+
230+ def close (self ) -> None :
231+ thread_id = get_ident ()
232+ with self ._thread_cursors_lock , self ._connection_lock :
233+ if thread_id in self ._thread_cursors :
234+ _try_close (self ._thread_cursors [thread_id ], "cursor" )
235+ self ._thread_cursors .pop (thread_id )
236+ self ._discard_transaction (thread_id )
237+ self ._thread_attributes .pop (thread_id , None )
238+
239+ def close_all (self , exclude_calling_thread : bool = False ) -> None :
240+ calling_thread_id = get_ident ()
241+ with self ._thread_cursors_lock , self ._connection_lock :
242+ for thread_id , cursor in self ._thread_cursors .copy ().items ():
243+ if not exclude_calling_thread or thread_id != calling_thread_id :
244+ _try_close (cursor , "cursor" )
245+ self ._thread_cursors .pop (thread_id )
246+ self ._discard_transaction (thread_id )
247+ self ._thread_attributes .pop (thread_id , None )
248+
249+ if not exclude_calling_thread :
250+ _try_close (self ._connection , "connection" )
251+ self ._connection = None
209252
210253
211254class SingletonConnectionPool (_TransactionManagementMixin ):
@@ -277,13 +320,14 @@ def create_connection_pool(
277320 shared_connection : bool = False ,
278321 cursor_init : t .Optional [t .Callable [[t .Any ], None ]] = None ,
279322) -> ConnectionPool :
280- return (
281- ThreadLocalConnectionPool (
282- connection_factory , shared_connection = shared_connection , cursor_init = cursor_init
283- )
323+ pool_class = (
324+ ThreadLocalSharedConnectionPool
325+ if multithreaded and shared_connection
326+ else ThreadLocalConnectionPool
284327 if multithreaded
285- else SingletonConnectionPool ( connection_factory , cursor_init = cursor_init )
328+ else SingletonConnectionPool
286329 )
330+ return pool_class (connection_factory , cursor_init = cursor_init )
287331
288332
289333def _try_close (closeable : t .Any , kind : str ) -> None :
0 commit comments