diff --git a/pyproject.toml b/pyproject.toml index 83432b4..121b9cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ds-platform-utils" -version = "0.4.1" +version = "0.4.2" description = "Utility library for Pattern Data Science." readme = "README.md" authors = [ diff --git a/src/ds_platform_utils/metaflow/snowflake_connection.py b/src/ds_platform_utils/metaflow/snowflake_connection.py index 2b41107..e9828a5 100644 --- a/src/ds_platform_utils/metaflow/snowflake_connection.py +++ b/src/ds_platform_utils/metaflow/snowflake_connection.py @@ -80,11 +80,17 @@ def _create_snowflake_connection( conn: SnowflakeConnection = Snowflake( integration=SNOWFLAKE_INTEGRATION, client_session_keep_alive=True, - warehouse=warehouse, timezone="UTC" if use_utc else None, session_parameters={"QUERY_TAG": query_tag}, ).cn # type: ignore[attr-defined] + # Doing this in the connection parameters result in silently failing to set the warehouse, + # so we have to execute a raw query to set it. + try: + conn.execute_string("USE WAREHOUSE {}".format(warehouse)) + except Exception as e: + raise RuntimeError(f"Failed to set Snowflake warehouse to {warehouse}: {e}") from e + return conn diff --git a/tests/unit_tests/snowflake/test__execute_sql.py b/tests/unit_tests/snowflake/test__execute_sql.py index b328988..1c8ce7b 100644 --- a/tests/unit_tests/snowflake/test__execute_sql.py +++ b/tests/unit_tests/snowflake/test__execute_sql.py @@ -1,36 +1,21 @@ """Functional test for _execute_sql.""" from typing import Generator -from unittest.mock import MagicMock import pytest from snowflake.connector import SnowflakeConnection from ds_platform_utils._snowflake.run_query import _execute_sql -from ds_platform_utils.metaflow.snowflake_connection import _create_snowflake_connection +from ds_platform_utils.metaflow.snowflake_connection import get_snowflake_connection @pytest.fixture(scope="module") -def patched_current() -> Generator[MagicMock, None, None]: - """Patch Metaflow `current` object for modules used in this test file.""" - mock_current = MagicMock("metaflow.current") - mock_current.tags = ["ds.domain:testing", "ds.project:unit-tests"] - mock_current.flow_name = "DummyFlow" - mock_current.project_name = "dummy-project" - mock_current.step_name = "dummy-step" - mock_current.run_id = "123" - mock_current.username = "tester" - mock_current.is_production = False - mock_current.namespace = "user:tester" - mock_current.is_running_flow = True - mock_current.card = [] - yield mock_current - - -@pytest.fixture(scope="module") -def snowflake_conn(patched_current) -> Generator[SnowflakeConnection, None, None]: +def snowflake_conn() -> Generator[SnowflakeConnection, None, None]: """Get a Snowflake connection for testing.""" - yield _create_snowflake_connection(warehouse=None, use_utc=True) + from metaflow import current + + current.is_production = False # Ensure we're in non-prod for testing + yield get_snowflake_connection(warehouse=None, use_utc=True) def test_execute_sql_empty_string(snowflake_conn): diff --git a/uv.lock b/uv.lock index c15589d..a099e72 100644 --- a/uv.lock +++ b/uv.lock @@ -479,7 +479,7 @@ wheels = [ [[package]] name = "ds-platform-utils" -version = "0.4.1" +version = "0.4.2" source = { editable = "." } dependencies = [ { name = "jinja2" },