Skip to content

Commit d35f04e

Browse files
committed
PR feedback
1 parent ce73829 commit d35f04e

1 file changed

Lines changed: 10 additions & 18 deletions

File tree

sqlmesh/core/engine_adapter/snowflake.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import contextlib
44
import logging
55
import typing as t
6-
import threading
76

87
import pandas as pd
98
from pandas.api.types import is_datetime64_any_dtype # type: ignore
@@ -69,10 +68,7 @@ class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixi
6968
},
7069
)
7170
MANAGED_TABLE_KIND = "DYNAMIC TABLE"
72-
73-
def __init__(self, *args: t.Any, **kwargs: t.Any):
74-
super().__init__(*args, **kwargs)
75-
self._snowpark_threadlocal = threading.local()
71+
SNOWPARK = "snowpark"
7672

7773
@contextlib.contextmanager
7874
def session(self, properties: SessionProperties) -> t.Iterator[None]:
@@ -109,15 +105,16 @@ def _current_warehouse(self) -> exp.Identifier:
109105
@property
110106
def snowpark(self) -> t.Optional[SnowparkSession]:
111107
if snowpark:
112-
# Snowpark sessions are not thread safe so we create a session per thread to prevent them from interfering with each other
113-
# The sessions are cleaned up when close() is called
114-
if not hasattr(self._snowpark_threadlocal, "session"):
108+
if not self._connection_pool.get_attribute(self.SNOWPARK):
109+
# Snowpark sessions are not thread safe so we create a session per thread to prevent them from interfering with each other
110+
# The sessions are cleaned up when close() is called
115111
new_session = snowpark.Session.builder.configs(
116112
{"connection": self._connection_pool.get()}
117113
).create()
118-
self._snowpark_threadlocal.session = new_session
114+
self._connection_pool.set_attribute(self.SNOWPARK, new_session)
115+
116+
return self._connection_pool.get_attribute(self.SNOWPARK)
119117

120-
return self._snowpark_threadlocal.session
121118
return None
122119

123120
@property
@@ -596,14 +593,9 @@ def _columns_to_types(
596593

597594
return super()._columns_to_types(query_or_df, columns_to_types)
598595

599-
def _cleanup_snowpark(self) -> None:
600-
if hasattr(self._snowpark_threadlocal, "session") and (
601-
session := self._snowpark_threadlocal.session
602-
):
603-
session.close()
604-
delattr(self._snowpark_threadlocal, "session")
605-
606596
def close(self) -> t.Any:
607-
self._cleanup_snowpark()
597+
if snowpark_session := self._connection_pool.get_attribute(self.SNOWPARK):
598+
snowpark_session.close() # type: ignore
599+
self._connection_pool.set_attribute(self.SNOWPARK, None)
608600

609601
return super().close()

0 commit comments

Comments
 (0)