Skip to content

Commit dc0788d

Browse files
committed
Add tests for BigFrame and Snowpark dataframes
1 parent d09a25f commit dc0788d

6 files changed

Lines changed: 108 additions & 11 deletions

File tree

sqlmesh/core/engine_adapter/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def _df_to_source_queries(
249249

250250
# we need to ensure that the order of the columns in columns_to_types matches the order of the values
251251
# they can differ if a user specifies columns() on a python model in a different order than what's in the DataFrames emitted by that model
252-
df = df[list(columns_to_types.keys())]
252+
df = df[list(columns_to_types)]
253253
values = list(df.itertuples(index=False, name=None))
254254

255255
return [

sqlmesh/core/engine_adapter/mssql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def query_factory() -> Query:
219219
if not self.table_exists(temp_table):
220220
columns_to_types_create = columns_to_types.copy()
221221
ordered_df = df[
222-
list(columns_to_types_create.keys())
222+
list(columns_to_types_create)
223223
] # reorder DataFrame so it matches columns_to_types
224224
self._convert_df_datetime(ordered_df, columns_to_types_create)
225225
self.create_table(temp_table, columns_to_types_create)

sqlmesh/core/engine_adapter/spark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,14 +280,14 @@ def _ensure_pyspark_df(
280280
if pyspark_df:
281281
if columns_to_types:
282282
# ensure Spark dataframe column order matches columns_to_types
283-
pyspark_df = pyspark_df.select(*list(columns_to_types.keys()))
283+
pyspark_df = pyspark_df.select(*list(columns_to_types))
284284
return pyspark_df
285285
df = self.try_get_pandas_df(generic_df)
286286
if df is None:
287287
raise SQLMeshError("Ensure PySpark DF can only be run on a PySpark or Pandas DataFrame")
288288
if columns_to_types:
289289
# ensure Pandas dataframe column order matches columns_to_types
290-
df = df[list(columns_to_types.keys())]
290+
df = df[list(columns_to_types)]
291291
kwargs = (
292292
dict(schema=self.sqlglot_to_spark_types(columns_to_types)) if columns_to_types else {}
293293
)

tests/core/engine_adapter/integration/test_integration.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2742,8 +2742,7 @@ def test_python_model_column_order(ctx: TestContext, tmp_path_factory: pytest.Te
27422742
pytest.skip("python model column order test only needs to be run once per db")
27432743

27442744
tmp_path = tmp_path_factory.mktemp(f"column_order_{ctx.test_id}")
2745-
2746-
test_schema = ctx.add_test_suffix("column_order")
2745+
schema = ctx.add_test_suffix(TEST_SCHEMA)
27472746

27482747
(tmp_path / "models").mkdir()
27492748

@@ -2772,7 +2771,7 @@ def execute(
27722771
return context.spark.createDataFrame([
27732772
Row(name="foo", id=1)
27742773
])
2775-
""".replace("TEST_SCHEMA", test_schema)
2774+
""".replace("TEST_SCHEMA", schema)
27762775
)
27772776
else:
27782777
# python model that emits a Pandas DataFrame
@@ -2796,7 +2795,7 @@ def execute(
27962795
return pd.DataFrame([
27972796
{"name": "foo", "id": 1}
27982797
])
2799-
""".replace("TEST_SCHEMA", test_schema)
2798+
""".replace("TEST_SCHEMA", schema)
28002799
)
28012800

28022801
sqlmesh_ctx = ctx.create_context(path=tmp_path)
@@ -2808,6 +2807,9 @@ def execute(
28082807

28092808
engine_adapter = sqlmesh_ctx.engine_adapter
28102809

2811-
df = engine_adapter.fetchdf(f"select * from {test_schema}.model")
2810+
query = exp.select("*").from_(
2811+
exp.to_table(f"{schema}.model", dialect=ctx.dialect), dialect=ctx.dialect
2812+
)
2813+
df = engine_adapter.fetchdf(query, quote_identifiers=True)
28122814
assert len(df) == 1
28132815
assert df.iloc[0].to_dict() == {"id": 1, "name": "foo"}

tests/core/engine_adapter/integration/test_integration_bigquery.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from sqlmesh.core.model import SqlModel, load_sql_based_model
1414
from sqlmesh.core.plan import Plan
1515
from sqlmesh.core.table_diff import TableDiff
16-
from tests.core.engine_adapter.integration import TestContext
16+
from tests.core.engine_adapter.integration import TestContext, TEST_SCHEMA
1717

1818
pytestmark = [pytest.mark.engine, pytest.mark.remote, pytest.mark.bigquery]
1919

@@ -433,3 +433,51 @@ def test_table_diff_table_name_matches_column_name(ctx: TestContext):
433433

434434
assert row_diff.stats["join_count"] == 1
435435
assert row_diff.full_match_count == 1
436+
437+
438+
def test_bigframe_python_model_column_order(ctx: TestContext, tmp_path: Path):
439+
schema = ctx.add_test_suffix(TEST_SCHEMA)
440+
441+
(tmp_path / "models").mkdir()
442+
443+
# note: this model deliberately defines the columns in the @model definition to be in a different order than what
444+
# is returned by the DataFrame within the model
445+
model_path = tmp_path / "models" / "python_model.py"
446+
447+
# python model that emits a BigFrame dataframe
448+
model_path.write_text(
449+
"""
450+
from bigframes.pandas import DataFrame
451+
import typing as t
452+
from sqlmesh import ExecutionContext, model
453+
454+
@model(
455+
"TEST_SCHEMA.model",
456+
columns={
457+
"id": "int",
458+
"name": "varchar"
459+
}
460+
)
461+
def execute(
462+
context: ExecutionContext,
463+
**kwargs: t.Any,
464+
) -> DataFrame:
465+
return DataFrame({'name': ['foo'], 'id': [1]})
466+
""".replace("TEST_SCHEMA", schema)
467+
)
468+
469+
sqlmesh_ctx = ctx.create_context(path=tmp_path)
470+
471+
assert len(sqlmesh_ctx.models) == 1
472+
473+
plan = sqlmesh_ctx.plan(auto_apply=True)
474+
assert len(plan.new_snapshots) == 1
475+
476+
engine_adapter = sqlmesh_ctx.engine_adapter
477+
478+
query = exp.select("*").from_(
479+
exp.to_table(f"{schema}.model", dialect=ctx.dialect), dialect=ctx.dialect
480+
)
481+
df = engine_adapter.fetchdf(query, quote_identifiers=True)
482+
assert len(df) == 1
483+
assert df.iloc[0].to_dict() == {"id": 1, "name": "foo"}

tests/core/engine_adapter/integration/test_integration_snowflake.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import typing as t
22
import pytest
33
from sqlglot import exp
4+
from pathlib import Path
45
from sqlglot.optimizer.qualify_columns import quote_identifiers
56
from sqlglot.helper import seq_get
67
from sqlmesh.core.engine_adapter import SnowflakeEngineAdapter
78
from sqlmesh.core.engine_adapter.shared import DataObject
89
import sqlmesh.core.dialect as d
910
from sqlmesh.core.model import SqlModel, load_sql_based_model
1011
from sqlmesh.core.plan import Plan
11-
from tests.core.engine_adapter.integration import TestContext
12+
from tests.core.engine_adapter.integration import TestContext, TEST_SCHEMA
1213

1314
pytestmark = [pytest.mark.engine, pytest.mark.remote, pytest.mark.snowflake]
1415

@@ -210,3 +211,49 @@ def test_create_iceberg_table(ctx: TestContext, engine_adapter: SnowflakeEngineA
210211
result = sqlmesh.plan(auto_apply=True)
211212

212213
assert len(result.new_snapshots) == 2
214+
215+
216+
def test_snowpark_python_model_column_order(ctx: TestContext, tmp_path: Path):
217+
schema = ctx.add_test_suffix(TEST_SCHEMA)
218+
219+
(tmp_path / "models").mkdir()
220+
221+
# note: this model deliberately defines the columns in the @model definition to be in a different order than what
222+
# is returned by the DataFrame within the model
223+
model_path = tmp_path / "models" / "python_model.py"
224+
225+
# python model that emits a Snowpark DataFrame
226+
model_path.write_text(
227+
"""
228+
from snowflake.snowpark.dataframe import DataFrame
229+
import typing as t
230+
from sqlmesh import ExecutionContext, model
231+
232+
@model(
233+
"TEST_SCHEMA.model",
234+
columns={
235+
"id": "int",
236+
"name": "varchar"
237+
}
238+
)
239+
def execute(
240+
context: ExecutionContext,
241+
**kwargs: t.Any,
242+
) -> DataFrame:
243+
return context.snowpark.create_dataframe([["foo", 1]], schema=["name", "id"])
244+
""".replace("TEST_SCHEMA", schema)
245+
)
246+
247+
sqlmesh_ctx = ctx.create_context(path=tmp_path)
248+
249+
assert len(sqlmesh_ctx.models) == 1
250+
251+
plan = sqlmesh_ctx.plan(auto_apply=True)
252+
assert len(plan.new_snapshots) == 1
253+
254+
engine_adapter = sqlmesh_ctx.engine_adapter
255+
256+
query = exp.select("*").from_(exp.to_table(f"{schema}.model", dialect=ctx.dialect))
257+
df = engine_adapter.fetchdf(query, quote_identifiers=True)
258+
assert len(df) == 1
259+
assert df.iloc[0].to_dict() == {"id": 1, "name": "foo"}

0 commit comments

Comments
 (0)