Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sqlmesh/core/config/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ class BuiltInSchedulerConfig(_EngineAdapterStateSyncSchedulerConfig, BaseConfig)
def create_plan_evaluator(self, context: GenericContext) -> PlanEvaluator:
return BuiltInPlanEvaluator(
state_sync=context.state_sync,
snapshot_evaluator=context.snapshot_evaluator,
create_scheduler=context.create_scheduler,
default_catalog=context.default_catalog,
console=context.console,
Expand Down
55 changes: 19 additions & 36 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@
run_tests,
)
from sqlmesh.core.user import User
from sqlmesh.utils import UniqueKeyDict, Verbosity, CorrelationId
from sqlmesh.utils import UniqueKeyDict, Verbosity
from sqlmesh.utils.concurrency import concurrent_apply_to_values
from sqlmesh.utils.dag import DAG
from sqlmesh.utils.date import (
Expand Down Expand Up @@ -417,7 +417,7 @@ def __init__(
self.config.get_state_connection(self.gateway) or self.connection_config
)

self._snapshot_evaluators: t.Dict[t.Optional[CorrelationId], SnapshotEvaluator] = {}
self._snapshot_evaluator: t.Optional[SnapshotEvaluator] = None

self.console = get_console()
setattr(self.console, "dialect", self.config.dialect)
Expand Down Expand Up @@ -445,22 +445,18 @@ def engine_adapter(self) -> EngineAdapter:
self._engine_adapter = self.connection_config.create_engine_adapter()
return self._engine_adapter

def snapshot_evaluator(
self, correlation_id: t.Optional[CorrelationId] = None
) -> SnapshotEvaluator:
# Cache snapshot evaluators by correlation_id to avoid old correlation_ids being attached to future Context operations
if correlation_id not in self._snapshot_evaluators:
self._snapshot_evaluators[correlation_id] = SnapshotEvaluator(
@property
def snapshot_evaluator(self) -> SnapshotEvaluator:
if not self._snapshot_evaluator:
self._snapshot_evaluator = SnapshotEvaluator(
{
gateway: adapter.with_settings(
log_level=logging.INFO, correlation_id=correlation_id
)
gateway: adapter.with_settings(log_level=logging.INFO)
for gateway, adapter in self.engine_adapters.items()
},
ddl_concurrent_tasks=self.concurrent_tasks,
selected_gateway=self.selected_gateway,
)
return self._snapshot_evaluators[correlation_id]
return self._snapshot_evaluator

def execution_context(
self,
Expand Down Expand Up @@ -541,9 +537,7 @@ def scheduler(self, environment: t.Optional[str] = None) -> Scheduler:

return self.create_scheduler(snapshots)

def create_scheduler(
self, snapshots: t.Iterable[Snapshot], correlation_id: t.Optional[CorrelationId] = None
) -> Scheduler:
def create_scheduler(self, snapshots: t.Iterable[Snapshot]) -> Scheduler:
"""Creates the built-in scheduler.

Args:
Expand All @@ -554,7 +548,7 @@ def create_scheduler(
"""
return Scheduler(
snapshots,
self.snapshot_evaluator(correlation_id),
self.snapshot_evaluator,
self.state_sync,
default_catalog=self.default_catalog,
max_workers=self.concurrent_tasks,
Expand Down Expand Up @@ -719,7 +713,7 @@ def run(
NotificationEvent.RUN_START, environment=environment
)
analytics_run_id = analytics.collector.on_run_start(
engine_type=self.snapshot_evaluator().adapter.dialect,
engine_type=self.snapshot_evaluator.adapter.dialect,
state_sync_type=self.state_sync.state_type(),
)
self._load_materializations()
Expand Down Expand Up @@ -1081,7 +1075,7 @@ def evaluate(
and not parent_snapshot.categorized
]

df = self.snapshot_evaluator().evaluate_and_fetch(
df = self.snapshot_evaluator.evaluate_and_fetch(
snapshot,
start=start,
end=end,
Expand Down Expand Up @@ -1593,12 +1587,7 @@ def apply(
default_catalog=self.default_catalog,
console=self.console,
)
explainer.evaluate(
plan.to_evaluatable(),
snapshot_evaluator=self.snapshot_evaluator(
correlation_id=CorrelationId.from_plan_id(plan.plan_id)
),
)
explainer.evaluate(plan.to_evaluatable())
return

self.notification_target_manager.notify(
Expand Down Expand Up @@ -2121,7 +2110,7 @@ def audit(
errors = []
skipped_count = 0
for snapshot in snapshots:
for audit_result in self.snapshot_evaluator().audit(
for audit_result in self.snapshot_evaluator.audit(
snapshot=snapshot,
start=start,
end=end,
Expand Down Expand Up @@ -2153,7 +2142,7 @@ def audit(
self.console.log_status_update(f"Got {error.count} results, expected 0.")
if error.query:
self.console.show_sql(
f"{error.query.sql(dialect=self.snapshot_evaluator().adapter.dialect)}"
f"{error.query.sql(dialect=self.snapshot_evaluator.adapter.dialect)}"
)

self.console.log_status_update("Done.")
Expand Down Expand Up @@ -2345,14 +2334,12 @@ def print_environment_names(self) -> None:

def close(self) -> None:
"""Releases all resources allocated by this context."""
for evaluator in self._snapshot_evaluators.values():
evaluator.close()
if self._snapshot_evaluator:
self._snapshot_evaluator.close()

if self._state_sync:
self._state_sync.close()

self._snapshot_evaluators.clear()

def _run(
self,
environment: str,
Expand Down Expand Up @@ -2403,11 +2390,7 @@ def _run(

def _apply(self, plan: Plan, circuit_breaker: t.Optional[t.Callable[[], bool]]) -> None:
self._scheduler.create_plan_evaluator(self).evaluate(
plan.to_evaluatable(),
snapshot_evaluator=self.snapshot_evaluator(
correlation_id=CorrelationId.from_plan_id(plan.plan_id)
),
circuit_breaker=circuit_breaker,
plan.to_evaluatable(), circuit_breaker=circuit_breaker
)

@python_api_analytics
Expand Down Expand Up @@ -2700,7 +2683,7 @@ def _run_janitor(self, ignore_ttl: bool = False) -> None:
)

# Remove the expired snapshots tables
self.snapshot_evaluator().cleanup(
self.snapshot_evaluator.cleanup(
target_snapshots=cleanup_targets,
on_complete=self.console.update_cleanup_progress,
)
Expand Down
6 changes: 2 additions & 4 deletions sqlmesh/core/plan/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ class PlanEvaluator(abc.ABC):
def evaluate(
self,
plan: EvaluatablePlan,
snapshot_evaluator: SnapshotEvaluator,
circuit_breaker: t.Optional[t.Callable[[], bool]] = None,
) -> None:
"""Evaluates a plan by pushing snapshots and backfilling data.
Expand All @@ -63,7 +62,6 @@ def evaluate(

Args:
plan: The plan to evaluate.
snapshot_evaluator: The snapshot evaluator to use.
circuit_breaker: The circuit breaker to use.
"""

Expand All @@ -72,11 +70,13 @@ class BuiltInPlanEvaluator(PlanEvaluator):
def __init__(
self,
state_sync: StateSync,
snapshot_evaluator: SnapshotEvaluator,
create_scheduler: t.Callable[[t.Iterable[Snapshot]], Scheduler],
default_catalog: t.Optional[str],
console: t.Optional[Console] = None,
):
self.state_sync = state_sync
self.snapshot_evaluator = snapshot_evaluator
self.create_scheduler = create_scheduler
self.default_catalog = default_catalog
self.console = console or get_console()
Expand All @@ -85,11 +85,9 @@ def __init__(
def evaluate(
self,
plan: EvaluatablePlan,
snapshot_evaluator: SnapshotEvaluator,
circuit_breaker: t.Optional[t.Callable[[], bool]] = None,
) -> None:
self._circuit_breaker = circuit_breaker
self.snapshot_evaluator = snapshot_evaluator

self.console.start_plan_evaluation(plan)
analytics.collector.on_plan_apply_start(
Expand Down
2 changes: 0 additions & 2 deletions sqlmesh/core/plan/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from sqlmesh.utils import Verbosity, rich as srich, to_snake_case
from sqlmesh.utils.date import to_ts
from sqlmesh.utils.errors import SQLMeshError
from sqlmesh.core.snapshot.evaluator import SnapshotEvaluator


logger = logging.getLogger(__name__)
Expand All @@ -40,7 +39,6 @@ def __init__(
def evaluate(
self,
plan: EvaluatablePlan,
snapshot_evaluator: SnapshotEvaluator,
circuit_breaker: t.Optional[t.Callable[[], bool]] = None,
) -> None:
plan_stages = stages.build_plan_stages(plan, self.state_reader, self.default_catalog)
Expand Down
6 changes: 2 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
SnapshotDataVersion,
SnapshotFingerprint,
)
from sqlmesh.utils import random_id, CorrelationId
from sqlmesh.utils import random_id
from sqlmesh.utils.date import TimeLike, to_date
from sqlmesh.utils.windows import IS_WINDOWS, fix_windows_path
from sqlmesh.core.engine_adapter.shared import CatalogSupport
Expand Down Expand Up @@ -266,12 +266,10 @@ def duck_conn() -> duckdb.DuckDBPyConnection:
def push_plan(context: Context, plan: Plan) -> None:
plan_evaluator = BuiltInPlanEvaluator(
context.state_sync,
context.snapshot_evaluator,
context.create_scheduler,
context.default_catalog,
)
plan_evaluator.snapshot_evaluator = context.snapshot_evaluator(
CorrelationId.from_plan_id(plan.plan_id)
)
deployability_index = DeployabilityIndex.create(context.snapshots.values())
evaluatable_plan = plan.to_evaluatable()
stages = plan_stages.build_plan_stages(
Expand Down
29 changes: 1 addition & 28 deletions tests/core/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
SnapshotInfoLike,
SnapshotTableInfo,
)
from sqlmesh.utils import CorrelationId
from sqlmesh.utils.date import TimeLike, now, to_date, to_datetime, to_timestamp
from sqlmesh.utils.errors import NoChangesPlanError, SQLMeshError, PlanError, ConfigError
from sqlmesh.utils.pydantic import validate_string
Expand Down Expand Up @@ -1138,7 +1137,7 @@ def test_non_breaking_change_after_forward_only_in_dev(
init_and_plan_context: t.Callable, has_view_binding: bool
):
context, plan = init_and_plan_context("examples/sushi")
context.snapshot_evaluator().adapter.HAS_VIEW_BINDING = has_view_binding
context.snapshot_evaluator.adapter.HAS_VIEW_BINDING = has_view_binding
context.apply(plan)

model = context.get_model("sushi.waiter_revenue_by_day")
Expand Down Expand Up @@ -6794,29 +6793,3 @@ def test_scd_type_2_full_restatement_no_start_date(init_and_plan_context: t.Call
# valid_from should be the epoch, valid_to should be NaT
assert str(row["valid_from"]) == "1970-01-01 00:00:00"
assert pd.isna(row["valid_to"])


def test_plan_evaluator_correlation_id(tmp_path: Path):
def _correlation_id_in_sqls(correlation_id: CorrelationId, mock_logger):
sqls = [call[0][0] for call in mock_logger.call_args_list]
return any(f"/* {correlation_id} */" in sql for sql in sqls)

create_temp_file(
tmp_path, Path("models") / "test.sql", "MODEL (name test.a, kind FULL); SELECT 1 AS col"
)

# Case 1: Ensure that the correlation id (plan_id) is included in the SQL
with mock.patch("sqlmesh.core.engine_adapter.base.EngineAdapter._log_sql") as mock_logger:
ctx = Context(paths=[tmp_path], config=Config())
plan = ctx.plan(auto_apply=True, no_prompts=True)

correlation_id = CorrelationId.from_plan_id(plan.plan_id)
assert str(correlation_id) == f"SQLMESH_PLAN: {plan.plan_id}"

assert _correlation_id_in_sqls(correlation_id, mock_logger)

# Case 2: Ensure that the previous correlation id is not included in the SQL for other operations
with mock.patch("sqlmesh.core.engine_adapter.base.EngineAdapter._log_sql") as mock_logger:
ctx.snapshot_evaluator().adapter.execute("SELECT 1")

assert not _correlation_id_in_sqls(correlation_id, mock_logger)
5 changes: 1 addition & 4 deletions tests/core/test_plan_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
stages as plan_stages,
)
from sqlmesh.core.snapshot import SnapshotChangeCategory
from sqlmesh.utils import CorrelationId


@pytest.fixture
Expand Down Expand Up @@ -60,13 +59,11 @@ def test_builtin_evaluator_push(sushi_context: Context, make_snapshot):

evaluator = BuiltInPlanEvaluator(
sushi_context.state_sync,
sushi_context.snapshot_evaluator,
sushi_context.create_scheduler,
sushi_context.default_catalog,
console=sushi_context.console,
)
evaluator.snapshot_evaluator = sushi_context.snapshot_evaluator(
CorrelationId.from_plan_id(plan.plan_id)
)

evaluatable_plan = plan.to_evaluatable()
stages = plan_stages.build_plan_stages(
Expand Down