diff --git a/sqlmesh/core/config/scheduler.py b/sqlmesh/core/config/scheduler.py index fc44d8f356..5cbfc6a71c 100644 --- a/sqlmesh/core/config/scheduler.py +++ b/sqlmesh/core/config/scheduler.py @@ -130,6 +130,7 @@ class BuiltInSchedulerConfig(_EngineAdapterStateSyncSchedulerConfig, BaseConfig) def create_plan_evaluator(self, context: GenericContext) -> PlanEvaluator: return BuiltInPlanEvaluator( state_sync=context.state_sync, + snapshot_evaluator=context.snapshot_evaluator, create_scheduler=context.create_scheduler, default_catalog=context.default_catalog, console=context.console, diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 1be6ca1dac..f7f068d6f9 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -116,7 +116,7 @@ run_tests, ) from sqlmesh.core.user import User -from sqlmesh.utils import UniqueKeyDict, Verbosity, CorrelationId +from sqlmesh.utils import UniqueKeyDict, Verbosity from sqlmesh.utils.concurrency import concurrent_apply_to_values from sqlmesh.utils.dag import DAG from sqlmesh.utils.date import ( @@ -417,7 +417,7 @@ def __init__( self.config.get_state_connection(self.gateway) or self.connection_config ) - self._snapshot_evaluators: t.Dict[t.Optional[CorrelationId], SnapshotEvaluator] = {} + self._snapshot_evaluator: t.Optional[SnapshotEvaluator] = None self.console = get_console() setattr(self.console, "dialect", self.config.dialect) @@ -445,22 +445,18 @@ def engine_adapter(self) -> EngineAdapter: self._engine_adapter = self.connection_config.create_engine_adapter() return self._engine_adapter - def snapshot_evaluator( - self, correlation_id: t.Optional[CorrelationId] = None - ) -> SnapshotEvaluator: - # Cache snapshot evaluators by correlation_id to avoid old correlation_ids being attached to future Context operations - if correlation_id not in self._snapshot_evaluators: - self._snapshot_evaluators[correlation_id] = SnapshotEvaluator( + @property + def snapshot_evaluator(self) -> SnapshotEvaluator: + if not self._snapshot_evaluator: + self._snapshot_evaluator = SnapshotEvaluator( { - gateway: adapter.with_settings( - log_level=logging.INFO, correlation_id=correlation_id - ) + gateway: adapter.with_settings(log_level=logging.INFO) for gateway, adapter in self.engine_adapters.items() }, ddl_concurrent_tasks=self.concurrent_tasks, selected_gateway=self.selected_gateway, ) - return self._snapshot_evaluators[correlation_id] + return self._snapshot_evaluator def execution_context( self, @@ -541,9 +537,7 @@ def scheduler(self, environment: t.Optional[str] = None) -> Scheduler: return self.create_scheduler(snapshots) - def create_scheduler( - self, snapshots: t.Iterable[Snapshot], correlation_id: t.Optional[CorrelationId] = None - ) -> Scheduler: + def create_scheduler(self, snapshots: t.Iterable[Snapshot]) -> Scheduler: """Creates the built-in scheduler. Args: @@ -554,7 +548,7 @@ def create_scheduler( """ return Scheduler( snapshots, - self.snapshot_evaluator(correlation_id), + self.snapshot_evaluator, self.state_sync, default_catalog=self.default_catalog, max_workers=self.concurrent_tasks, @@ -719,7 +713,7 @@ def run( NotificationEvent.RUN_START, environment=environment ) analytics_run_id = analytics.collector.on_run_start( - engine_type=self.snapshot_evaluator().adapter.dialect, + engine_type=self.snapshot_evaluator.adapter.dialect, state_sync_type=self.state_sync.state_type(), ) self._load_materializations() @@ -1081,7 +1075,7 @@ def evaluate( and not parent_snapshot.categorized ] - df = self.snapshot_evaluator().evaluate_and_fetch( + df = self.snapshot_evaluator.evaluate_and_fetch( snapshot, start=start, end=end, @@ -1593,12 +1587,7 @@ def apply( default_catalog=self.default_catalog, console=self.console, ) - explainer.evaluate( - plan.to_evaluatable(), - snapshot_evaluator=self.snapshot_evaluator( - correlation_id=CorrelationId.from_plan_id(plan.plan_id) - ), - ) + explainer.evaluate(plan.to_evaluatable()) return self.notification_target_manager.notify( @@ -2121,7 +2110,7 @@ def audit( errors = [] skipped_count = 0 for snapshot in snapshots: - for audit_result in self.snapshot_evaluator().audit( + for audit_result in self.snapshot_evaluator.audit( snapshot=snapshot, start=start, end=end, @@ -2153,7 +2142,7 @@ def audit( self.console.log_status_update(f"Got {error.count} results, expected 0.") if error.query: self.console.show_sql( - f"{error.query.sql(dialect=self.snapshot_evaluator().adapter.dialect)}" + f"{error.query.sql(dialect=self.snapshot_evaluator.adapter.dialect)}" ) self.console.log_status_update("Done.") @@ -2345,14 +2334,12 @@ def print_environment_names(self) -> None: def close(self) -> None: """Releases all resources allocated by this context.""" - for evaluator in self._snapshot_evaluators.values(): - evaluator.close() + if self._snapshot_evaluator: + self._snapshot_evaluator.close() if self._state_sync: self._state_sync.close() - self._snapshot_evaluators.clear() - def _run( self, environment: str, @@ -2403,11 +2390,7 @@ def _run( def _apply(self, plan: Plan, circuit_breaker: t.Optional[t.Callable[[], bool]]) -> None: self._scheduler.create_plan_evaluator(self).evaluate( - plan.to_evaluatable(), - snapshot_evaluator=self.snapshot_evaluator( - correlation_id=CorrelationId.from_plan_id(plan.plan_id) - ), - circuit_breaker=circuit_breaker, + plan.to_evaluatable(), circuit_breaker=circuit_breaker ) @python_api_analytics @@ -2700,7 +2683,7 @@ def _run_janitor(self, ignore_ttl: bool = False) -> None: ) # Remove the expired snapshots tables - self.snapshot_evaluator().cleanup( + self.snapshot_evaluator.cleanup( target_snapshots=cleanup_targets, on_complete=self.console.update_cleanup_progress, ) diff --git a/sqlmesh/core/plan/evaluator.py b/sqlmesh/core/plan/evaluator.py index 562f2ed60e..a8e2aa7919 100644 --- a/sqlmesh/core/plan/evaluator.py +++ b/sqlmesh/core/plan/evaluator.py @@ -51,7 +51,6 @@ class PlanEvaluator(abc.ABC): def evaluate( self, plan: EvaluatablePlan, - snapshot_evaluator: SnapshotEvaluator, circuit_breaker: t.Optional[t.Callable[[], bool]] = None, ) -> None: """Evaluates a plan by pushing snapshots and backfilling data. @@ -63,7 +62,6 @@ def evaluate( Args: plan: The plan to evaluate. - snapshot_evaluator: The snapshot evaluator to use. circuit_breaker: The circuit breaker to use. """ @@ -72,11 +70,13 @@ class BuiltInPlanEvaluator(PlanEvaluator): def __init__( self, state_sync: StateSync, + snapshot_evaluator: SnapshotEvaluator, create_scheduler: t.Callable[[t.Iterable[Snapshot]], Scheduler], default_catalog: t.Optional[str], console: t.Optional[Console] = None, ): self.state_sync = state_sync + self.snapshot_evaluator = snapshot_evaluator self.create_scheduler = create_scheduler self.default_catalog = default_catalog self.console = console or get_console() @@ -85,11 +85,9 @@ def __init__( def evaluate( self, plan: EvaluatablePlan, - snapshot_evaluator: SnapshotEvaluator, circuit_breaker: t.Optional[t.Callable[[], bool]] = None, ) -> None: self._circuit_breaker = circuit_breaker - self.snapshot_evaluator = snapshot_evaluator self.console.start_plan_evaluation(plan) analytics.collector.on_plan_apply_start( diff --git a/sqlmesh/core/plan/explainer.py b/sqlmesh/core/plan/explainer.py index d3c6480f74..ee829aeac1 100644 --- a/sqlmesh/core/plan/explainer.py +++ b/sqlmesh/core/plan/explainer.py @@ -20,7 +20,6 @@ from sqlmesh.utils import Verbosity, rich as srich, to_snake_case from sqlmesh.utils.date import to_ts from sqlmesh.utils.errors import SQLMeshError -from sqlmesh.core.snapshot.evaluator import SnapshotEvaluator logger = logging.getLogger(__name__) @@ -40,7 +39,6 @@ def __init__( def evaluate( self, plan: EvaluatablePlan, - snapshot_evaluator: SnapshotEvaluator, circuit_breaker: t.Optional[t.Callable[[], bool]] = None, ) -> None: plan_stages = stages.build_plan_stages(plan, self.state_reader, self.default_catalog) diff --git a/tests/conftest.py b/tests/conftest.py index a874bd7590..574c802c0e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -42,7 +42,7 @@ SnapshotDataVersion, SnapshotFingerprint, ) -from sqlmesh.utils import random_id, CorrelationId +from sqlmesh.utils import random_id from sqlmesh.utils.date import TimeLike, to_date from sqlmesh.utils.windows import IS_WINDOWS, fix_windows_path from sqlmesh.core.engine_adapter.shared import CatalogSupport @@ -266,12 +266,10 @@ def duck_conn() -> duckdb.DuckDBPyConnection: def push_plan(context: Context, plan: Plan) -> None: plan_evaluator = BuiltInPlanEvaluator( context.state_sync, + context.snapshot_evaluator, context.create_scheduler, context.default_catalog, ) - plan_evaluator.snapshot_evaluator = context.snapshot_evaluator( - CorrelationId.from_plan_id(plan.plan_id) - ) deployability_index = DeployabilityIndex.create(context.snapshots.values()) evaluatable_plan = plan.to_evaluatable() stages = plan_stages.build_plan_stages( diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index f68cb7ac47..766a788ac8 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -67,7 +67,6 @@ SnapshotInfoLike, SnapshotTableInfo, ) -from sqlmesh.utils import CorrelationId from sqlmesh.utils.date import TimeLike, now, to_date, to_datetime, to_timestamp from sqlmesh.utils.errors import NoChangesPlanError, SQLMeshError, PlanError, ConfigError from sqlmesh.utils.pydantic import validate_string @@ -1138,7 +1137,7 @@ def test_non_breaking_change_after_forward_only_in_dev( init_and_plan_context: t.Callable, has_view_binding: bool ): context, plan = init_and_plan_context("examples/sushi") - context.snapshot_evaluator().adapter.HAS_VIEW_BINDING = has_view_binding + context.snapshot_evaluator.adapter.HAS_VIEW_BINDING = has_view_binding context.apply(plan) model = context.get_model("sushi.waiter_revenue_by_day") @@ -6794,29 +6793,3 @@ def test_scd_type_2_full_restatement_no_start_date(init_and_plan_context: t.Call # valid_from should be the epoch, valid_to should be NaT assert str(row["valid_from"]) == "1970-01-01 00:00:00" assert pd.isna(row["valid_to"]) - - -def test_plan_evaluator_correlation_id(tmp_path: Path): - def _correlation_id_in_sqls(correlation_id: CorrelationId, mock_logger): - sqls = [call[0][0] for call in mock_logger.call_args_list] - return any(f"/* {correlation_id} */" in sql for sql in sqls) - - create_temp_file( - tmp_path, Path("models") / "test.sql", "MODEL (name test.a, kind FULL); SELECT 1 AS col" - ) - - # Case 1: Ensure that the correlation id (plan_id) is included in the SQL - with mock.patch("sqlmesh.core.engine_adapter.base.EngineAdapter._log_sql") as mock_logger: - ctx = Context(paths=[tmp_path], config=Config()) - plan = ctx.plan(auto_apply=True, no_prompts=True) - - correlation_id = CorrelationId.from_plan_id(plan.plan_id) - assert str(correlation_id) == f"SQLMESH_PLAN: {plan.plan_id}" - - assert _correlation_id_in_sqls(correlation_id, mock_logger) - - # Case 2: Ensure that the previous correlation id is not included in the SQL for other operations - with mock.patch("sqlmesh.core.engine_adapter.base.EngineAdapter._log_sql") as mock_logger: - ctx.snapshot_evaluator().adapter.execute("SELECT 1") - - assert not _correlation_id_in_sqls(correlation_id, mock_logger) diff --git a/tests/core/test_plan_evaluator.py b/tests/core/test_plan_evaluator.py index 467c3e60bd..a3735b08ed 100644 --- a/tests/core/test_plan_evaluator.py +++ b/tests/core/test_plan_evaluator.py @@ -11,7 +11,6 @@ stages as plan_stages, ) from sqlmesh.core.snapshot import SnapshotChangeCategory -from sqlmesh.utils import CorrelationId @pytest.fixture @@ -60,13 +59,11 @@ def test_builtin_evaluator_push(sushi_context: Context, make_snapshot): evaluator = BuiltInPlanEvaluator( sushi_context.state_sync, + sushi_context.snapshot_evaluator, sushi_context.create_scheduler, sushi_context.default_catalog, console=sushi_context.console, ) - evaluator.snapshot_evaluator = sushi_context.snapshot_evaluator( - CorrelationId.from_plan_id(plan.plan_id) - ) evaluatable_plan = plan.to_evaluatable() stages = plan_stages.build_plan_stages(