Skip to content

Commit 909548b

Browse files
authored
feat: set query label session property in bq session (#4314)
1 parent 76f52e6 commit 909548b

5 files changed

Lines changed: 182 additions & 3 deletions

File tree

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,31 @@ def query_factory() -> Query:
183183
def _begin_session(self, properties: SessionProperties) -> None:
184184
from google.cloud.bigquery import QueryJobConfig
185185

186-
job = self.client.query("SELECT 1;", job_config=QueryJobConfig(create_session=True))
186+
query_label_property = properties.get("query_label")
187+
parsed_query_label: list[tuple[str, str]] = []
188+
if isinstance(query_label_property, (exp.Array, exp.Paren, exp.Tuple)):
189+
label_tuples = (
190+
[query_label_property.unnest()]
191+
if isinstance(query_label_property, exp.Paren)
192+
else query_label_property.expressions
193+
)
194+
195+
# query_label is a Paren, Array or Tuple of 2-tuples and validated at load time
196+
parsed_query_label.extend(
197+
(label_tuple.expressions[0].name, label_tuple.expressions[1].name)
198+
for label_tuple in label_tuples
199+
)
200+
201+
if parsed_query_label:
202+
query_label_str = ",".join([":".join(label) for label in parsed_query_label])
203+
query = f'SET @@query_label = "{query_label_str}";SELECT 1;'
204+
else:
205+
query = "SELECT 1;"
206+
207+
job = self.client.query(
208+
query,
209+
job_config=QueryJobConfig(create_session=True),
210+
)
187211
session_info = job.session_info
188212
session_id = session_info.session_id if session_info else None
189213
self._session_id = session_id

sqlmesh/core/model/common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,6 @@ def _executable_to_str(k: str, v: Executable) -> str:
441441
properties_validator: t.Callable = field_validator(
442442
"physical_properties_",
443443
"virtual_properties_",
444-
"session_properties_",
445444
"materialization_properties_",
446445
mode="before",
447446
check_fields=False,

sqlmesh/core/model/meta.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
default_catalog_validator,
1818
depends_on_validator,
1919
properties_validator,
20+
parse_properties,
2021
)
2122
from sqlmesh.core.model.kind import (
2223
CustomKind,
@@ -310,6 +311,43 @@ def _refs_validator(cls, vs: t.Any, info: ValidationInfo) -> t.List[exp.Expressi
310311
def ignored_rules_validator(cls, vs: t.Any) -> t.Any:
311312
return LinterConfig._validate_rules(vs)
312313

314+
@field_validator("session_properties_", mode="before")
315+
def session_properties_validator(cls, v: t.Any, info: ValidationInfo) -> t.Any:
316+
# use the generic properties validator to parse the session properties
317+
parsed_session_properties = parse_properties(type(cls), v, info)
318+
if not parsed_session_properties:
319+
return parsed_session_properties
320+
321+
for eq in parsed_session_properties:
322+
if eq.name == "query_label":
323+
query_label = eq.right
324+
if not (
325+
isinstance(query_label, exp.Array)
326+
or isinstance(query_label, exp.Tuple)
327+
or isinstance(query_label, exp.Paren)
328+
):
329+
raise ConfigError(
330+
"Invalid value for `session_properties.query_label`. Must be an array or tuple."
331+
)
332+
333+
label_tuples: t.List[exp.Expression] = (
334+
[query_label.unnest()]
335+
if isinstance(query_label, exp.Paren)
336+
else query_label.expressions
337+
)
338+
339+
for label_tuple in label_tuples:
340+
if not (
341+
isinstance(label_tuple, exp.Tuple)
342+
and len(label_tuple.expressions) == 2
343+
and all(isinstance(label, exp.Literal) for label in label_tuple.expressions)
344+
):
345+
raise ConfigError(
346+
"Invalid entry in `session_properties.query_label`. Must be tuples of string literals with length 2."
347+
)
348+
349+
return parsed_session_properties
350+
313351
@model_validator(mode="before")
314352
def _pre_root_validator(cls, data: t.Any) -> t.Any:
315353
if not isinstance(data, dict):

tests/core/engine_adapter/test_bigquery.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,7 @@ def test_begin_end_session(mocker: MockerFixture):
531531

532532
adapter = BigQueryEngineAdapter(lambda: connection_mock, job_retries=0)
533533

534+
# starting a session without session properties
534535
with adapter.session({}):
535536
assert adapter._connection_pool.get_attribute("session_id") is not None
536537
adapter.execute("SELECT 2;")
@@ -551,6 +552,18 @@ def test_begin_end_session(mocker: MockerFixture):
551552
assert execute_b_call[1]["query"] == "SELECT 3;"
552553
assert not execute_b_call[1]["job_config"].connection_properties
553554

555+
# starting a new session with session property query_label and array value
556+
with adapter.session({"query_label": parse_one("[('key1', 'value1'), ('key2', 'value2')]")}):
557+
adapter.execute("SELECT 4;")
558+
begin_new_session_call = connection_mock._client.query.call_args_list[3]
559+
assert begin_new_session_call[0][0] == 'SET @@query_label = "key1:value1,key2:value2";SELECT 1;'
560+
561+
# starting a new session with session property query_label and Paren value
562+
with adapter.session({"query_label": parse_one("(('key1', 'value1'))")}):
563+
adapter.execute("SELECT 5;")
564+
begin_new_session_call = connection_mock._client.query.call_args_list[5]
565+
assert begin_new_session_call[0][0] == 'SET @@query_label = "key1:value1";SELECT 1;'
566+
554567

555568
def _to_sql_calls(execute_mock: t.Any, identify: bool = True) -> t.List[str]:
556569
output = []

tests/core/test_model.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3757,7 +3757,10 @@ def my_model(context, **kwargs):
37573757
"""('key_a' = 'value_a', 'key_b' = 1, 'key_c' = TRUE, 'key_d' = 2.0)"""
37583758
)
37593759

3760-
with pytest.raises(ConfigError, match=r"Invalid property 'invalid'.*"):
3760+
with pytest.raises(
3761+
ConfigError,
3762+
match=r"Invalid property 'invalid'. Properties must be specified as key-value pairs <key> = <value>. ",
3763+
):
37613764
load_sql_based_model(
37623765
d.parse(
37633766
"""
@@ -4418,6 +4421,108 @@ def test_model_session_properties(sushi_context):
44184421
"warehouse": "test_warehouse",
44194422
}
44204423

4424+
model = load_sql_based_model(
4425+
d.parse(
4426+
"""
4427+
MODEL (
4428+
name test_schema.test_model,
4429+
session_properties (
4430+
'query_label' = [
4431+
('key1', 'value1'),
4432+
('key2', 'value2')
4433+
]
4434+
)
4435+
);
4436+
SELECT a FROM tbl;
4437+
""",
4438+
default_dialect="bigquery",
4439+
)
4440+
)
4441+
assert model.session_properties == {
4442+
"query_label": parse_one("[('key1', 'value1'), ('key2', 'value2')]")
4443+
}
4444+
4445+
model = load_sql_based_model(
4446+
d.parse(
4447+
"""
4448+
MODEL (
4449+
name test_schema.test_model,
4450+
session_properties (
4451+
'query_label' = (
4452+
('key1', 'value1')
4453+
)
4454+
)
4455+
);
4456+
SELECT a FROM tbl;
4457+
""",
4458+
default_dialect="bigquery",
4459+
)
4460+
)
4461+
assert model.session_properties == {"query_label": parse_one("(('key1', 'value1'))")}
4462+
4463+
with pytest.raises(
4464+
ConfigError,
4465+
match=r"Invalid value for `session_properties.query_label`. Must be an array or tuple.",
4466+
):
4467+
load_sql_based_model(
4468+
d.parse(
4469+
"""
4470+
MODEL (
4471+
name test_schema.test_model,
4472+
session_properties (
4473+
'query_label' = 'invalid value'
4474+
)
4475+
);
4476+
SELECT a FROM tbl;
4477+
""",
4478+
default_dialect="bigquery",
4479+
)
4480+
)
4481+
4482+
with pytest.raises(
4483+
ConfigError,
4484+
match=r"Invalid entry in `session_properties.query_label`. Must be tuples of string literals with length 2.",
4485+
):
4486+
load_sql_based_model(
4487+
d.parse(
4488+
"""
4489+
MODEL (
4490+
name test_schema.test_model,
4491+
session_properties (
4492+
'query_label' = (
4493+
('key1', 'value1', 'another_value')
4494+
)
4495+
)
4496+
);
4497+
SELECT a FROM tbl;
4498+
""",
4499+
default_dialect="bigquery",
4500+
)
4501+
)
4502+
4503+
with pytest.raises(
4504+
ConfigError,
4505+
match=r"Invalid entry in `session_properties.query_label`. Must be tuples of string literals with length 2.",
4506+
):
4507+
load_sql_based_model(
4508+
d.parse(
4509+
"""
4510+
MODEL (
4511+
name test_schema.test_model,
4512+
session_properties (
4513+
'query_label' = (
4514+
'some value',
4515+
'another value',
4516+
'yet another value',
4517+
)
4518+
)
4519+
);
4520+
SELECT a FROM tbl;
4521+
""",
4522+
default_dialect="bigquery",
4523+
)
4524+
)
4525+
44214526

44224527
def test_model_jinja_macro_rendering():
44234528
expressions = d.parse(

0 commit comments

Comments
 (0)