Skip to content

Commit cf122b9

Browse files
Feat: Add condition in the union macro for conditional union of tables (#4337)
1 parent 91a92e4 commit cf122b9

3 files changed

Lines changed: 217 additions & 3 deletions

File tree

docs/concepts/macros/sqlmesh_macros.md

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -855,7 +855,9 @@ FROM foo
855855

856856
`@UNION` returns a `UNION` query that selects all columns with matching names and data types from the tables.
857857

858-
Its first argument is the `UNION` "type", `'DISTINCT'` (removing duplicated rows) or `'ALL'` (returning all rows). Subsequent arguments are the tables to be combined.
858+
Its first argument can be either a condition or the `UNION` "type". If the first argument evaluates to a boolean (`TRUE` or `FALSE`), it's treated as a condition. If the condition is `FALSE`, only the first table is returned. If it's `TRUE`, the union operation is performed.
859+
860+
If the first argument is not a boolean condition, it's treated as the `UNION` "type": either `'DISTINCT'` (removing duplicated rows) or `'ALL'` (returning all rows). Subsequent arguments are the tables to be combined.
859861

860862
Let's assume that:
861863

@@ -882,6 +884,47 @@ SELECT
882884
FROM bar
883885
```
884886

887+
If the union type is omitted, `'ALL'` is used as the default. So the following expression:
888+
889+
```sql linenums="1"
890+
@UNION(foo, bar)
891+
```
892+
893+
would be rendered as:
894+
895+
```sql linenums="1"
896+
SELECT
897+
CAST(a AS INT) AS a,
898+
CAST(c AS TEXT) AS c
899+
FROM foo
900+
UNION ALL
901+
SELECT
902+
CAST(a AS INT) AS a,
903+
CAST(c AS TEXT) AS c
904+
FROM bar
905+
```
906+
907+
You can also use a condition to control whether the union happens:
908+
909+
```sql linenums="1"
910+
@UNION(1 > 0, 'all', foo, bar)
911+
```
912+
913+
This would render the same as above. However, if the condition is `FALSE`:
914+
915+
```sql linenums="1"
916+
@UNION(1 > 2, 'all', foo, bar)
917+
```
918+
919+
Only the first table would be selected:
920+
921+
```sql linenums="1"
922+
SELECT
923+
CAST(a AS INT) AS a,
924+
CAST(c AS TEXT) AS c
925+
FROM foo
926+
```
927+
885928
### @HAVERSINE_DISTANCE
886929

887930
`@HAVERSINE_DISTANCE` returns the [haversine distance](https://en.wikipedia.org/wiki/Haversine_formula) between two geographic points.

sqlmesh/core/macros.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -970,23 +970,54 @@ def safe_div(_: MacroEvaluator, numerator: exp.Expression, denominator: exp.Expr
970970
@macro()
971971
def union(
972972
evaluator: MacroEvaluator,
973-
type_: exp.Literal = exp.Literal.string("ALL"),
974-
*tables: exp.Table,
973+
*args: exp.Expression,
975974
) -> exp.Query:
976975
"""Returns a UNION of the given tables. Only choosing columns that have the same name and type.
977976
977+
Args:
978+
evaluator: MacroEvaluator that invoked the macro
979+
args: Variable arguments that can be:
980+
- First argument can be a condition (exp.Condition)
981+
- A union type ('ALL' or 'DISTINCT') as exp.Literal
982+
- Tables (exp.Table)
983+
978984
Example:
979985
>>> from sqlglot import parse_one
980986
>>> from sqlglot.schema import MappingSchema
981987
>>> from sqlmesh.core.macros import MacroEvaluator
982988
>>> sql = "@UNION('distinct', foo, bar)"
983989
>>> MacroEvaluator(schema=MappingSchema({"foo": {"a": "int", "b": "string", "c": "string"}, "bar": {"c": "string", "a": "int", "b": "int"}})).transform(parse_one(sql)).sql()
984990
'SELECT CAST(a AS INT) AS a, CAST(c AS TEXT) AS c FROM foo UNION SELECT CAST(a AS INT) AS a, CAST(c AS TEXT) AS c FROM bar'
991+
>>> sql = "@UNION(True, 'distinct', foo, bar)"
992+
>>> MacroEvaluator(schema=MappingSchema({"foo": {"a": "int", "b": "string", "c": "string"}, "bar": {"c": "string", "a": "int", "b": "int"}})).transform(parse_one(sql)).sql()
993+
'SELECT CAST(a AS INT) AS a, CAST(c AS TEXT) AS c FROM foo UNION SELECT CAST(a AS INT) AS a, CAST(c AS TEXT) AS c FROM bar'
985994
"""
995+
996+
if not args:
997+
raise SQLMeshError("At least one table is required for the @UNION macro.")
998+
999+
arg_idx = 0
1000+
# Check for condition
1001+
condition = evaluator.eval_expression(args[arg_idx])
1002+
if isinstance(condition, bool):
1003+
arg_idx += 1
1004+
if arg_idx >= len(args):
1005+
raise SQLMeshError("Expected more arguments after the condition of the `@UNION` macro.")
1006+
1007+
# Check for union type
1008+
type_ = exp.Literal.string("ALL")
1009+
if isinstance(args[arg_idx], exp.Literal):
1010+
type_ = args[arg_idx] # type: ignore
1011+
arg_idx += 1
9861012
kind = type_.name.upper()
9871013
if kind not in ("ALL", "DISTINCT"):
9881014
raise SQLMeshError(f"Invalid type '{type_}'. Expected 'ALL' or 'DISTINCT'.")
9891015

1016+
# Remaining args should be tables
1017+
tables = [
1018+
exp.to_table(e.sql(evaluator.dialect), dialect=evaluator.dialect) for e in args[arg_idx:]
1019+
]
1020+
9901021
columns = {
9911022
column
9921023
for column, _ in reduce(
@@ -1001,6 +1032,10 @@ def union(
10011032
if column in columns
10021033
]
10031034

1035+
# Skip the union if condition is False
1036+
if condition == False:
1037+
return exp.select(*projections).from_(tables[0])
1038+
10041039
return reduce(
10051040
lambda a, b: a.union(b, distinct=kind == "DISTINCT"), # type: ignore
10061041
[exp.select(*projections).from_(t) for t in tables],

tests/core/test_model.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,142 @@ def test_model_union_query(sushi_context, assert_exp_eq):
266266
)
267267

268268

269+
@time_machine.travel("1996-02-10 00:00:00 UTC")
270+
@pytest.mark.parametrize(
271+
"test_id, condition, union_type, table_count, expected_result",
272+
[
273+
# Test case 1: Basic conditional union - True condition
274+
(
275+
"test_1",
276+
"@get_date() == '1996-02-10'",
277+
"'all'",
278+
2,
279+
lambda expected_select: f"{expected_select}\nUNION ALL\n{expected_select}\n",
280+
),
281+
# Test case 2: False condition - should return just first table
282+
(
283+
"test_2",
284+
"@get_date() > '1996-02-10'",
285+
"'all'",
286+
2,
287+
lambda expected_select: f"{expected_select}\n",
288+
),
289+
# Test case 3: Multiple tables in union
290+
(
291+
"test_3",
292+
"@get_date() == '1996-02-10'",
293+
"'all'",
294+
3,
295+
lambda expected_select: f"{expected_select}\nUNION ALL\n{expected_select}\nUNION ALL\n{expected_select}\n",
296+
),
297+
# Test case 4: DISTINCT type
298+
(
299+
"test_4",
300+
"@get_date() == '1996-02-10'",
301+
"'distinct'",
302+
2,
303+
lambda expected_select: f"{expected_select}\nUNION\n{expected_select}\n",
304+
),
305+
# Test case 5: Complex condition
306+
(
307+
"test_5",
308+
"@get_date() = '1996-02-10' and 1=1 or @get_date() > '1996-02-10'",
309+
"'distinct'",
310+
2,
311+
lambda expected_select: f"{expected_select}\nUNION\n{expected_select}\n",
312+
),
313+
# Test case 6: Missing union type (defaults to ALL)
314+
(
315+
"test_6",
316+
"@get_date() == '1996-02-10'",
317+
"",
318+
2,
319+
lambda expected_select: f"{expected_select}\nUNION ALL\n{expected_select}\n",
320+
),
321+
# Test case 7: Missing union type AND condition
322+
(
323+
"test_7",
324+
"",
325+
"",
326+
2,
327+
lambda expected_select: f"{expected_select}\nUNION ALL\n{expected_select}\n",
328+
),
329+
# Test case 8: Missing union type AND condition multiple tables
330+
(
331+
"test_8",
332+
"",
333+
"",
334+
3,
335+
lambda expected_select: f"{expected_select}\nUNION ALL\n{expected_select}\n\nUNION ALL\n{expected_select}\n",
336+
),
337+
# Test case 9: Missing union type AND condition one table
338+
(
339+
"test_9",
340+
"",
341+
"",
342+
1,
343+
lambda expected_select: f"{expected_select}",
344+
),
345+
# Test case 10: Union type with one table
346+
(
347+
"test_10",
348+
"",
349+
"'distinct'",
350+
1,
351+
lambda expected_select: f"{expected_select}",
352+
),
353+
# Test case 11: Condition with one table
354+
(
355+
"test_9",
356+
"True",
357+
"",
358+
1,
359+
lambda expected_select: f"{expected_select}",
360+
),
361+
],
362+
)
363+
def test_model_union_conditional(
364+
sushi_context, assert_exp_eq, test_id, condition, union_type, table_count, expected_result
365+
):
366+
@macro()
367+
def get_date(evaluator):
368+
from sqlmesh.utils.date import now
369+
370+
return f"'{now().date()}'"
371+
372+
expected_select = """SELECT
373+
CAST("marketing"."customer_id" AS INT) AS "customer_id",
374+
CAST("marketing"."status" AS TEXT) AS "status",
375+
CAST("marketing"."updated_at" AS TIMESTAMPNTZ) AS "updated_at",
376+
CAST("marketing"."valid_from" AS TIMESTAMP) AS "valid_from",
377+
CAST("marketing"."valid_to" AS TIMESTAMP) AS "valid_to"
378+
FROM "memory"."sushi"."marketing" AS "marketing"
379+
"""
380+
381+
# Create tables argument list based on table_count
382+
tables = ", ".join(["sushi.marketing"] * table_count)
383+
384+
# Handle the missing union_type case
385+
union_type_arg = f", {union_type}" if union_type else ""
386+
387+
expressions = d.parse(
388+
f"""
389+
MODEL (
390+
name sushi.{test_id},
391+
kind FULL,
392+
);
393+
394+
@union({condition}{union_type_arg}, {tables})
395+
"""
396+
)
397+
sushi_context.upsert_model(load_sql_based_model(expressions, default_catalog="memory"))
398+
399+
assert_exp_eq(
400+
sushi_context.get_model(f"sushi.{test_id}").render_query(),
401+
expected_result(expected_select),
402+
)
403+
404+
269405
def test_model_validation_union_query():
270406
expressions = d.parse(
271407
"""

0 commit comments

Comments
 (0)