Skip to content

Commit be3e436

Browse files
address comments
1 parent cf67d85 commit be3e436

7 files changed

Lines changed: 49 additions & 75 deletions

File tree

sqlmesh/core/config/scheduler.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,6 @@ def create_state_sync(self, context: GenericContext) -> StateSync:
4646
The StateSync instance.
4747
"""
4848

49-
@abc.abstractmethod
50-
def get_default_catalog(self, context: GenericContext) -> t.Optional[str]:
51-
"""Returns the default catalog for the Scheduler.
52-
53-
Args:
54-
context: The SQLMesh Context.
55-
"""
56-
5749
@abc.abstractmethod
5850
def get_default_catalog_per_gateway(self, context: GenericContext) -> t.Dict[str, str]:
5951
"""Returns the default catalog for each gateway.
@@ -74,7 +66,7 @@ def state_sync_fingerprint(self, context: GenericContext) -> str:
7466
class _EngineAdapterStateSyncSchedulerConfig(SchedulerConfig):
7567
def create_state_sync(self, context: GenericContext) -> StateSync:
7668
state_connection = (
77-
context.config.get_state_connection(context.gateway) or context._connection_config
69+
context.config.get_state_connection(context.gateway) or context.connection_config
7870
)
7971

8072
warehouse_connection = context.config.get_connection(context.gateway)
@@ -118,7 +110,7 @@ def create_state_sync(self, context: GenericContext) -> StateSync:
118110

119111
def state_sync_fingerprint(self, context: GenericContext) -> str:
120112
state_connection = (
121-
context.config.get_state_connection(context.gateway) or context._connection_config
113+
context.config.get_state_connection(context.gateway) or context.connection_config
122114
)
123115
return md5(
124116
[
@@ -140,19 +132,16 @@ def create_plan_evaluator(self, context: GenericContext) -> PlanEvaluator:
140132
state_sync=context.state_sync,
141133
snapshot_evaluator=context.snapshot_evaluator,
142134
create_scheduler=context.create_scheduler,
143-
default_catalog=self.get_default_catalog(context),
135+
default_catalog=context.default_catalog,
144136
console=context.console,
145137
)
146138

147-
def get_default_catalog(self, context: GenericContext) -> t.Optional[str]:
148-
return context.engine_adapter.default_catalog
149-
150139
def get_default_catalog_per_gateway(self, context: GenericContext) -> t.Dict[str, str]:
151-
return {
152-
name: adapter.default_catalog
153-
for name, adapter in context.engine_adapters.items()
154-
if adapter.default_catalog
155-
}
140+
default_catalogs_per_gateway: t.Dict[str, str] = {}
141+
for gateway, adapter in context.engine_adapters.items():
142+
if catalog := adapter.default_catalog:
143+
default_catalogs_per_gateway[gateway] = catalog
144+
return default_catalogs_per_gateway
156145

157146

158147
SCHEDULER_CONFIG_TO_TYPE = {

sqlmesh/core/context.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,10 @@ def __init__(
369369
self._environment_statements: t.List[EnvironmentStatements] = []
370370
self._excluded_requirements: t.Set[str] = set()
371371
self._default_catalog: t.Optional[str] = None
372+
self._default_catalog_per_gateway: t.Optional[t.Dict[str, str]] = None
372373
self._engine_adapter: t.Optional[EngineAdapter] = None
374+
self._connection_config: t.Optional[ConnectionConfig] = None
375+
self._test_connection_config: t.Optional[ConnectionConfig] = None
373376
self._linters: t.Dict[str, Linter] = {}
374377
self._loaded: bool = False
375378

@@ -412,7 +415,7 @@ def __init__(
412415

413416
self._concurrent_tasks = concurrent_tasks
414417
self._state_connection_config = (
415-
self.config.get_state_connection(self.gateway) or self._connection_config
418+
self.config.get_state_connection(self.gateway) or self.connection_config
416419
)
417420

418421
self._snapshot_evaluator: t.Optional[SnapshotEvaluator] = None
@@ -440,7 +443,7 @@ def default_dialect(self) -> t.Optional[str]:
440443
def engine_adapter(self) -> EngineAdapter:
441444
"""Returns the default engine adapter."""
442445
if self._engine_adapter is None:
443-
self._engine_adapter = self._connection_config.create_engine_adapter()
446+
self._engine_adapter = self.connection_config.create_engine_adapter()
444447
return self._engine_adapter
445448

446449
@property
@@ -965,8 +968,8 @@ def requirements(self) -> t.Dict[str, str]:
965968

966969
@property
967970
def default_catalog(self) -> t.Optional[str]:
968-
if self._default_catalog is None:
969-
self._default_catalog = self._scheduler.get_default_catalog(self)
971+
if self._default_catalog is None and self.default_catalog_per_gateway:
972+
self._default_catalog = self.default_catalog_per_gateway[self.selected_gateway]
970973
return self._default_catalog
971974

972975
@python_api_analytics
@@ -1978,7 +1981,7 @@ def create_test(
19781981

19791982
try:
19801983
model_to_test = self.get_model(model, raise_if_missing=True)
1981-
test_adapter = self._test_connection_config.create_engine_adapter(
1984+
test_adapter = self.test_connection_config.create_engine_adapter(
19821985
register_comments_override=False
19831986
)
19841987

@@ -2463,7 +2466,7 @@ def _run_plan_tests(
24632466
self.console.log_test_results(
24642467
result,
24652468
test_output,
2466-
self._test_connection_config._engine_adapter.DIALECT,
2469+
self.test_connection_config._engine_adapter.DIALECT,
24672470
)
24682471
if not result.wasSuccessful():
24692472
raise PlanError(
@@ -2507,30 +2510,45 @@ def engine_adapters(self) -> t.Dict[str, EngineAdapter]:
25072510
@cached_property
25082511
def default_catalog_per_gateway(self) -> t.Dict[str, str]:
25092512
"""Returns the default catalogs for each engine adapter."""
2510-
if self.gateway_managed_virtual_layer:
2511-
return self._scheduler.get_default_catalog_per_gateway(self)
2512-
return {}
2513+
if self._default_catalog_per_gateway is None:
2514+
self._default_catalog_per_gateway = self._scheduler.get_default_catalog_per_gateway(
2515+
self
2516+
)
2517+
return self._default_catalog_per_gateway
25132518

25142519
@cached_property
25152520
def concurrent_tasks(self) -> int:
25162521
if self._concurrent_tasks is None:
2517-
self._concurrent_tasks = self._connection_config.concurrent_tasks
2522+
self._concurrent_tasks = self.connection_config.concurrent_tasks
25182523
return self._concurrent_tasks
25192524

25202525
@cached_property
2521-
def _connection_config(self) -> ConnectionConfig:
2522-
return self.config.get_connection(self.gateway)
2526+
def connection_config(self) -> ConnectionConfig:
2527+
if self._connection_config is None:
2528+
self._connection_config = self.config.get_connection(self.selected_gateway)
2529+
return self._connection_config
25232530

25242531
@cached_property
2525-
def _test_connection_config(self) -> ConnectionConfig:
2526-
return self.config.get_test_connection(
2527-
self.gateway, self.default_catalog, default_catalog_dialect=self.engine_adapter.DIALECT
2528-
)
2532+
def test_connection_config(self) -> ConnectionConfig:
2533+
if self._test_connection_config is None:
2534+
self._test_connection_config = self.config.get_test_connection(
2535+
self.gateway,
2536+
self.default_catalog,
2537+
default_catalog_dialect=self.engine_adapter.DIALECT,
2538+
)
2539+
return self._test_connection_config
25292540

25302541
@cached_property
25312542
def environment_catalog_mapping(self) -> RegexKeyDict:
2543+
engine_adapter = None
2544+
try:
2545+
engine_adapter = self.engine_adapter
2546+
except Exception:
2547+
pass
2548+
25322549
if (
25332550
self.config.environment_catalog_mapping
2551+
and engine_adapter
25342552
and not self.engine_adapter.catalog_support.is_multi_catalog_supported
25352553
):
25362554
raise SQLMeshError(

sqlmesh/integrations/github/cicd/controller.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -726,7 +726,7 @@ def conclusion_handler(
726726
self._console.log_test_results(
727727
result,
728728
output,
729-
self._context._test_connection_config._engine_adapter.DIALECT,
729+
self._context.test_connection_config._engine_adapter.DIALECT,
730730
)
731731
test_summary = self._console.consume_captured_output()
732732
test_title = "Tests Passed" if result.wasSuccessful() else "Tests Failed"

tests/core/test_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -742,7 +742,7 @@ def test_multi_gateway_config(tmp_path, mocker: MockerFixture):
742742

743743
ctx = Context(paths=tmp_path, config=config)
744744

745-
assert isinstance(ctx._connection_config, RedshiftConnectionConfig)
745+
assert isinstance(ctx.connection_config, RedshiftConnectionConfig)
746746
assert len(ctx.engine_adapters) == 3
747747
assert isinstance(ctx.engine_adapters["athena"], AthenaEngineAdapter)
748748
assert isinstance(ctx.engine_adapters["redshift"], RedshiftEngineAdapter)
@@ -782,7 +782,7 @@ def test_multi_gateway_single_threaded_config(tmp_path):
782782
)
783783

784784
ctx = Context(paths=tmp_path, config=config)
785-
assert isinstance(ctx._connection_config, DuckDBConnectionConfig)
785+
assert isinstance(ctx.connection_config, DuckDBConnectionConfig)
786786
assert len(ctx.engine_adapters) == 2
787787
assert ctx.engine_adapter == ctx._get_engine_adapter("duckdb")
788788
assert isinstance(ctx.engine_adapters["athena"], AthenaEngineAdapter)

tests/core/test_integration.py

Lines changed: 2 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4872,13 +4872,11 @@ def test_multi(mocker):
48724872
assert context.fetchdf("select * from after_1").to_dict()["repo_1"][0] == "repo_1"
48734873
assert context.fetchdf("select * from after_2").to_dict()["repo_2"][0] == "repo_2"
48744874

4875-
adapter = context.engine_adapter
48764875
context = Context(
48774876
paths=["examples/multi/repo_1"],
48784877
state_sync=context.state_sync,
48794878
gateway="memory",
48804879
)
4881-
context._engine_adapter = adapter
48824880

48834881
model = context.get_model("bronze.a")
48844882
assert model.project == "repo_1"
@@ -4935,6 +4933,8 @@ def test_multi_virtual_layer(copy_to_temp_path):
49354933
)
49364934

49374935
context = Context(paths=paths, config=config)
4936+
assert context.default_catalog_per_gateway == {"first": "db_1", "second": "db_2"}
4937+
assert len(context.engine_adapters) == 2
49384938

49394939
# For the model without gateway the default should be used and the gateway variable should overide the global
49404940
assert (
@@ -5064,39 +5064,6 @@ def test_multi_virtual_layer(copy_to_temp_path):
50645064
context.apply(plan)
50655065

50665066

5067-
def test_multi_virtual_layer_catalogs(copy_to_temp_path):
5068-
paths = copy_to_temp_path("tests/fixtures/multi_virtual_layer")
5069-
path = Path(paths[0])
5070-
first_db_path = str(path / "db_1.db")
5071-
second_db_path = str(path / "db_2.db")
5072-
5073-
config = Config(
5074-
gateways={
5075-
"first": GatewayConfig(
5076-
connection=DuckDBConnectionConfig(database=first_db_path),
5077-
variables={"overriden_var": "gateway_1"},
5078-
),
5079-
"second": GatewayConfig(
5080-
connection=DuckDBConnectionConfig(database=second_db_path),
5081-
variables={"overriden_var": "gateway_2"},
5082-
),
5083-
},
5084-
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
5085-
model_naming=NameInferenceConfig(infer_names=True),
5086-
default_gateway="first",
5087-
variables={"overriden_var": "global", "global_one": 88},
5088-
)
5089-
5090-
# With gateway_managed_virtual_layer to False the catalogs won't be retrieved
5091-
context = Context(paths=paths, config=config)
5092-
assert context.default_catalog_per_gateway == {}
5093-
5094-
config.gateway_managed_virtual_layer = True
5095-
context = Context(paths=paths, config=config)
5096-
assert context.default_catalog_per_gateway == {"first": "db_1", "second": "db_2"}
5097-
assert len(context.engine_adapters) == 2
5098-
5099-
51005067
def test_multi_dbt(mocker):
51015068
context = Context(paths=["examples/multi_dbt/bronze", "examples/multi_dbt/silver"])
51025069
context._new_state_sync().reset(default_catalog=context.default_catalog)

tests/core/test_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def _create_test(
5454
test_name=test_name,
5555
model=model,
5656
models=context._models,
57-
engine_adapter=context._test_connection_config.create_engine_adapter(
57+
engine_adapter=context.test_connection_config.create_engine_adapter(
5858
register_comments_override=False
5959
),
6060
dialect=context.config.dialect,

web/server/api/endpoints/commands.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ async def test(
159159
context.console.log_test_results(
160160
result,
161161
test_output.getvalue(),
162-
context._test_connection_config._engine_adapter.DIALECT,
162+
context.test_connection_config._engine_adapter.DIALECT,
163163
)
164164

165165
def _test_path(test: ModelTest) -> t.Optional[str]:

0 commit comments

Comments
 (0)