Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions sqlmesh/core/engine_adapter/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def merge(
if merge_filter:
on = exp.and_(merge_filter, on)

match_expressions = []
if not when_matched:
match_condition = None
unique_key_names = [y.name for y in unique_key]
Expand All @@ -230,23 +231,24 @@ def merge(
)
)

match_expressions = [
exp.When(
matched=True,
source=False,
condition=match_condition,
then=exp.Update(
expressions=[
exp.column(col, MERGE_TARGET_ALIAS).eq(
exp.column(col, MERGE_SOURCE_ALIAS)
)
for col in columns_to_types_no_keys
],
),
if target_columns_no_keys:
match_expressions.append(
exp.When(
matched=True,
source=False,
condition=match_condition,
then=exp.Update(
expressions=[
exp.column(col, MERGE_TARGET_ALIAS).eq(
exp.column(col, MERGE_SOURCE_ALIAS)
)
for col in columns_to_types_no_keys
],
),
)
)
]
else:
match_expressions = when_matched.copy().expressions
match_expressions.extend(when_matched.copy().expressions)

match_expressions.append(
exp.When(
Expand Down
26 changes: 26 additions & 0 deletions tests/core/engine_adapter/test_mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,8 @@ def test_merge_pandas(
temp_table_mock.return_value = make_temp_table_name(table_name, temp_table_id)

df = pd.DataFrame({"id": [1, 2, 3], "ts": [1, 2, 3], "val": [4, 5, 6]})

# 1 key
adapter.merge(
target_table=table_name,
source_table=df,
Expand All @@ -476,6 +478,7 @@ def test_merge_pandas(
f"DROP TABLE IF EXISTS [__temp_target_{temp_table_id}];",
]

# 2 keys
adapter.cursor.reset_mock()
adapter._connection_pool.get().reset_mock()
temp_table_mock.return_value = make_temp_table_name(table_name, temp_table_id)
Expand All @@ -499,6 +502,29 @@ def test_merge_pandas(
f"DROP TABLE IF EXISTS [__temp_target_{temp_table_id}];",
]

# all model columns are keys
adapter.cursor.reset_mock()
adapter._connection_pool.get().reset_mock()
temp_table_mock.return_value = make_temp_table_name(table_name, temp_table_id)
adapter.merge(
target_table=table_name,
source_table=df,
columns_to_types={
"id": exp.DataType.build("int"),
"ts": exp.DataType.build("TIMESTAMP"),
},
unique_key=[exp.to_identifier("id"), exp.to_column("ts")],
)
adapter._connection_pool.get().bulk_copy.assert_called_with(
f"__temp_target_{temp_table_id}", [(1, 1), (2, 2), (3, 3)]
)

assert to_sql_calls(adapter) == [
f"""IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '__temp_target_{temp_table_id}') EXEC('CREATE TABLE [__temp_target_{temp_table_id}] ([id] INTEGER, [ts] DATETIME2)');""",
f"MERGE INTO [target] AS [__MERGE_TARGET__] USING (SELECT CAST([id] AS INTEGER) AS [id], CAST([ts] AS DATETIME2) AS [ts] FROM [__temp_target_{temp_table_id}]) AS [__MERGE_SOURCE__] ON [__MERGE_TARGET__].[id] = [__MERGE_SOURCE__].[id] AND [__MERGE_TARGET__].[ts] = [__MERGE_SOURCE__].[ts] WHEN NOT MATCHED THEN INSERT ([id], [ts]) VALUES ([__MERGE_SOURCE__].[id], [__MERGE_SOURCE__].[ts]);",
f"DROP TABLE IF EXISTS [__temp_target_{temp_table_id}];",
]


def test_replace_query(make_mocked_engine_adapter: t.Callable):
adapter = make_mocked_engine_adapter(MSSQLEngineAdapter)
Expand Down