Skip to content

Commit 920108f

Browse files
committed
Introduce a thread local pool with a shared connection
1 parent 25a1121 commit 920108f

4 files changed

Lines changed: 165 additions & 36 deletions

File tree

sqlmesh/core/config/connection.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import re
99
import typing as t
1010
from enum import Enum
11-
from functools import partial, lru_cache
11+
from functools import partial
1212

1313
import pydantic
1414
from pydantic import Field
@@ -97,21 +97,13 @@ def is_forbidden_for_state_sync(self) -> bool:
9797
@property
9898
def _connection_factory_with_kwargs(self) -> t.Callable[[], t.Any]:
9999
"""A function that is called to return a connection object for the given Engine Adapter"""
100-
factory = partial(
100+
return partial(
101101
self._connection_factory,
102102
**{
103103
**self._static_connection_kwargs,
104104
**{k: v for k, v in self.dict().items() if k in self._connection_kwargs_keys},
105105
},
106106
)
107-
if self.shared_connection:
108-
# Make sure that a single connection is created and returned
109-
@lru_cache
110-
def _cached_connection() -> t.Any:
111-
return factory()
112-
113-
return _cached_connection
114-
return factory
115107

116108
def connection_validator(self) -> t.Callable[[], None]:
117109
"""A function that validates the connection configuration"""

sqlmesh/utils/connection_pool.py

Lines changed: 68 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -111,22 +111,18 @@ def _do_rollback(self) -> None:
111111
self.get().rollback()
112112

113113

114-
class ThreadLocalConnectionPool(_TransactionManagementMixin):
114+
class _ThreadLocalBase(_TransactionManagementMixin):
115115
def __init__(
116116
self,
117117
connection_factory: t.Callable[[], t.Any],
118-
shared_connection: bool = False,
119118
cursor_init: t.Optional[t.Callable[[t.Any], None]] = None,
120119
):
121120
self._connection_factory = connection_factory
122-
self._thread_connections: t.Dict[t.Hashable, t.Any] = {}
123121
self._thread_cursors: t.Dict[t.Hashable, t.Any] = {}
124122
self._thread_transactions: t.Set[t.Hashable] = set()
125123
self._thread_attributes: t.Dict[t.Hashable, t.Dict[str, t.Any]] = defaultdict(dict)
126-
self._thread_connections_lock = Lock()
127124
self._thread_cursors_lock = Lock()
128125
self._thread_transactions_lock = Lock()
129-
self._shared_connection = shared_connection
130126
self._cursor_init = cursor_init
131127

132128
def get_cursor(self) -> t.Any:
@@ -138,13 +134,6 @@ def get_cursor(self) -> t.Any:
138134
self._cursor_init(self._thread_cursors[thread_id])
139135
return self._thread_cursors[thread_id]
140136

141-
def get(self) -> t.Any:
142-
thread_id = get_ident()
143-
with self._thread_connections_lock:
144-
if thread_id not in self._thread_connections:
145-
self._thread_connections[thread_id] = self._connection_factory()
146-
return self._thread_connections[thread_id]
147-
148137
def get_attribute(self, key: str) -> t.Optional[t.Any]:
149138
thread_id = get_ident()
150139
return self._thread_attributes[thread_id].get(key)
@@ -178,6 +167,28 @@ def close_cursor(self) -> None:
178167
_try_close(self._thread_cursors[thread_id], "cursor")
179168
self._thread_cursors.pop(thread_id)
180169

170+
def _discard_transaction(self, thread_id: t.Hashable) -> None:
171+
with self._thread_transactions_lock:
172+
self._thread_transactions.discard(thread_id)
173+
174+
175+
class ThreadLocalConnectionPool(_ThreadLocalBase):
176+
def __init__(
177+
self,
178+
connection_factory: t.Callable[[], t.Any],
179+
cursor_init: t.Optional[t.Callable[[t.Any], None]] = None,
180+
):
181+
super().__init__(connection_factory, cursor_init)
182+
self._thread_connections: t.Dict[t.Hashable, t.Any] = {}
183+
self._thread_connections_lock = Lock()
184+
185+
def get(self) -> t.Any:
186+
thread_id = get_ident()
187+
with self._thread_connections_lock:
188+
if thread_id not in self._thread_connections:
189+
self._thread_connections[thread_id] = self._connection_factory()
190+
return self._thread_connections[thread_id]
191+
181192
def close(self) -> None:
182193
thread_id = get_ident()
183194
with self._thread_cursors_lock, self._thread_connections_lock:
@@ -189,23 +200,55 @@ def close(self) -> None:
189200
self._thread_attributes.pop(thread_id, None)
190201

191202
def close_all(self, exclude_calling_thread: bool = False) -> None:
192-
if exclude_calling_thread and self._shared_connection:
193-
return
194-
195203
calling_thread_id = get_ident()
196204
with self._thread_cursors_lock, self._thread_connections_lock:
197205
for thread_id, connection in self._thread_connections.copy().items():
198206
if not exclude_calling_thread or thread_id != calling_thread_id:
199-
# NOTE: the access to the connection instance itself is not thread-safe here.
200207
_try_close(connection, "connection")
201208
self._thread_connections.pop(thread_id)
202209
self._thread_cursors.pop(thread_id, None)
203210
self._discard_transaction(thread_id)
204211
self._thread_attributes.pop(thread_id, None)
205212

206-
def _discard_transaction(self, thread_id: t.Hashable) -> None:
207-
with self._thread_transactions_lock:
208-
self._thread_transactions.discard(thread_id)
213+
214+
class ThreadLocalSharedConnectionPool(_ThreadLocalBase):
215+
def __init__(
216+
self,
217+
connection_factory: t.Callable[[], t.Any],
218+
cursor_init: t.Optional[t.Callable[[t.Any], None]] = None,
219+
):
220+
super().__init__(connection_factory, cursor_init)
221+
self._connection: t.Optional[t.Any] = None
222+
self._connection_lock = Lock()
223+
224+
def get(self) -> t.Any:
225+
with self._connection_lock:
226+
if self._connection is None:
227+
self._connection = self._connection_factory()
228+
return self._connection
229+
230+
def close(self) -> None:
231+
thread_id = get_ident()
232+
with self._thread_cursors_lock, self._connection_lock:
233+
if thread_id in self._thread_cursors:
234+
_try_close(self._thread_cursors[thread_id], "cursor")
235+
self._thread_cursors.pop(thread_id)
236+
self._discard_transaction(thread_id)
237+
self._thread_attributes.pop(thread_id, None)
238+
239+
def close_all(self, exclude_calling_thread: bool = False) -> None:
240+
calling_thread_id = get_ident()
241+
with self._thread_cursors_lock, self._connection_lock:
242+
for thread_id, cursor in self._thread_cursors.copy().items():
243+
if not exclude_calling_thread or thread_id != calling_thread_id:
244+
_try_close(cursor, "cursor")
245+
self._thread_cursors.pop(thread_id)
246+
self._discard_transaction(thread_id)
247+
self._thread_attributes.pop(thread_id, None)
248+
249+
if not exclude_calling_thread:
250+
_try_close(self._connection, "connection")
251+
self._connection = None
209252

210253

211254
class SingletonConnectionPool(_TransactionManagementMixin):
@@ -277,13 +320,14 @@ def create_connection_pool(
277320
shared_connection: bool = False,
278321
cursor_init: t.Optional[t.Callable[[t.Any], None]] = None,
279322
) -> ConnectionPool:
280-
return (
281-
ThreadLocalConnectionPool(
282-
connection_factory, shared_connection=shared_connection, cursor_init=cursor_init
283-
)
323+
pool_class = (
324+
ThreadLocalSharedConnectionPool
325+
if multithreaded and shared_connection
326+
else ThreadLocalConnectionPool
284327
if multithreaded
285-
else SingletonConnectionPool(connection_factory, cursor_init=cursor_init)
328+
else SingletonConnectionPool
286329
)
330+
return pool_class(connection_factory, cursor_init=cursor_init)
287331

288332

289333
def _try_close(closeable: t.Any, kind: str) -> None:

tests/core/test_connection_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -607,7 +607,7 @@ def test_duckdb_attach_options():
607607

608608
def test_duckdb_multithreaded_connection_factory(make_config):
609609
from sqlmesh.core.engine_adapter import DuckDBEngineAdapter
610-
from sqlmesh.utils.connection_pool import ThreadLocalConnectionPool
610+
from sqlmesh.utils.connection_pool import ThreadLocalSharedConnectionPool
611611
from threading import Thread
612612

613613
config = make_config(type="duckdb")
@@ -620,7 +620,7 @@ def test_duckdb_multithreaded_connection_factory(make_config):
620620
config = make_config(type="duckdb", concurrent_tasks=8)
621621
adapter = config.create_engine_adapter()
622622
assert isinstance(adapter, DuckDBEngineAdapter)
623-
assert isinstance(adapter._connection_pool, ThreadLocalConnectionPool)
623+
assert isinstance(adapter._connection_pool, ThreadLocalSharedConnectionPool)
624624

625625
threads = []
626626
connection_objects = []

tests/utils/test_connection_pool.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from sqlmesh.utils.connection_pool import (
77
SingletonConnectionPool,
88
ThreadLocalConnectionPool,
9+
ThreadLocalSharedConnectionPool,
910
)
1011

1112

@@ -207,3 +208,95 @@ def thread():
207208
assert cursor_mock_thread_one.rollback.call_count == 1
208209

209210
assert cursor_mock_thread_two.begin.call_count == 1
211+
212+
213+
def test_thread_local_shared_connection_pool(mocker: MockerFixture):
214+
cursor_mock_thread_one = mocker.Mock()
215+
cursor_mock_thread_two = mocker.Mock()
216+
connection_mock = mocker.Mock()
217+
connection_mock.cursor.side_effect = [
218+
cursor_mock_thread_one,
219+
cursor_mock_thread_two,
220+
cursor_mock_thread_one,
221+
]
222+
223+
test_thread_id = get_ident()
224+
225+
connection_factory_mock = mocker.Mock(return_value=connection_mock)
226+
pool = ThreadLocalSharedConnectionPool(connection_factory_mock)
227+
228+
assert pool.get_cursor() == cursor_mock_thread_one
229+
assert pool.get_cursor() == cursor_mock_thread_one
230+
assert pool.get() == connection_mock
231+
assert pool.get() == connection_mock
232+
233+
def thread():
234+
assert pool.get_cursor() == cursor_mock_thread_two
235+
assert pool.get_cursor() == cursor_mock_thread_two
236+
assert pool.get() == connection_mock
237+
assert pool.get() == connection_mock
238+
239+
with ThreadPoolExecutor(max_workers=1) as executor:
240+
executor.submit(thread).result()
241+
242+
assert pool._connection is not None
243+
assert len(pool._thread_cursors) == 2
244+
245+
pool.close_all(exclude_calling_thread=True)
246+
247+
assert pool._connection is not None
248+
assert len(pool._thread_cursors) == 1
249+
assert test_thread_id in pool._thread_cursors
250+
251+
pool.close_cursor()
252+
pool.close()
253+
254+
assert pool.get_cursor() == cursor_mock_thread_one
255+
256+
pool.close_all()
257+
258+
assert connection_factory_mock.call_count == 1
259+
260+
assert cursor_mock_thread_one.close.call_count == 2
261+
assert connection_mock.cursor.call_count == 3
262+
assert connection_mock.close.call_count == 1
263+
264+
265+
def test_thread_local_shared_connection_pool_close(mocker: MockerFixture):
266+
connection_mock = mocker.Mock()
267+
cursor_mock = mocker.Mock()
268+
connection_mock.cursor.return_value = cursor_mock
269+
270+
connection_factory_mock = mocker.Mock(return_value=connection_mock)
271+
pool = ThreadLocalSharedConnectionPool(connection_factory_mock)
272+
273+
# First time we get a connection
274+
pool.get()
275+
pool.get()
276+
pool.get_cursor()
277+
pool.get_cursor()
278+
279+
# This shouldn't close the connection, only the cursor
280+
pool.close()
281+
pool.get()
282+
pool.get()
283+
pool.get_cursor()
284+
285+
pool.get_cursor()
286+
# This shouldn't close the connection either
287+
pool.close_all(exclude_calling_thread=True)
288+
289+
pool.get()
290+
pool.get()
291+
# Now this should close the connection
292+
pool.close_all()
293+
294+
# Re-open the connection
295+
pool.get()
296+
pool.get()
297+
# Close it again
298+
pool.close_all()
299+
300+
assert cursor_mock.close.call_count == 2
301+
assert connection_factory_mock.call_count == 2
302+
assert connection_mock.close.call_count == 2

0 commit comments

Comments
 (0)