Skip to content

Commit 7b581be

Browse files
committed
Move get_target_environment to ContextDiff
1 parent 3326b0c commit 7b581be

2 files changed

Lines changed: 30 additions & 25 deletions

File tree

sqlmesh/core/context.py

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2599,25 +2599,6 @@ def _snapshots(
25992599

26002600
return {name: stored_snapshots.get(s.snapshot_id, s) for name, s in snapshots.items()}
26012601

2602-
def _get_target_environment(self, environment: t.Optional[str] = None) -> t.Tuple[str, str]:
2603-
environment = environment or self.config.default_target_environment
2604-
environment = Environment.sanitize_name(environment)
2605-
2606-
initial_environment = environment
2607-
2608-
if self.config.plan.always_compare_against_prod:
2609-
prod = self.state_reader.get_environment(c.PROD)
2610-
if prod:
2611-
logger.warning(
2612-
f"Comparing against production environment instead of {environment}. Note that this may lead to "
2613-
"additional backfills as accumulated changes are still pushed to the target environment."
2614-
)
2615-
environment = c.PROD
2616-
else:
2617-
environment = environment or c.PROD
2618-
2619-
return environment.lower(), initial_environment.lower()
2620-
26212602
def _context_diff(
26222603
self,
26232604
environment: str,
@@ -2627,13 +2608,13 @@ def _context_diff(
26272608
ensure_finalized_snapshots: bool = False,
26282609
diff_rendered: bool = False,
26292610
) -> ContextDiff:
2630-
target_environment, initial_environment = self._get_target_environment(environment)
2611+
environment = Environment.sanitize_name(environment)
26312612

26322613
if force_no_diff:
26332614
return ContextDiff.create_no_diff(environment, self.state_reader)
26342615

26352616
return ContextDiff.create(
2636-
environment=target_environment,
2617+
environment=environment,
26372618
snapshots=snapshots or self.snapshots,
26382619
create_from=create_from or c.PROD,
26392620
state_reader=self.state_reader,
@@ -2644,7 +2625,7 @@ def _context_diff(
26442625
environment_statements=self._environment_statements,
26452626
gateway_managed_virtual_layer=self.config.gateway_managed_virtual_layer,
26462627
infer_python_dependencies=self.config.infer_python_dependencies,
2647-
initial_environment=initial_environment,
2628+
always_compare_against_prod=self.config.plan.always_compare_against_prod,
26482629
)
26492630

26502631
def _destroy(self) -> None:

sqlmesh/core/context_diff.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import sys
1616
import typing as t
17+
import logging
18+
1719
from difflib import ndiff, unified_diff
1820
from functools import cached_property
1921
from sqlmesh.core import constants as c
@@ -38,6 +40,8 @@
3840

3941
IGNORED_PACKAGES = {"sqlmesh", "sqlglot"}
4042

43+
logger = logging.getLogger(__name__)
44+
4145

4246
class ContextDiff(PydanticModel):
4347
"""ContextDiff is an object representing the difference between two environments.
@@ -106,6 +110,7 @@ def create(
106110
gateway_managed_virtual_layer: bool = False,
107111
infer_python_dependencies: bool = True,
108112
initial_environment: t.Optional[str] = None,
113+
always_compare_against_prod: bool = False,
109114
) -> ContextDiff:
110115
"""Create a ContextDiff object.
111116
@@ -130,10 +135,12 @@ def create(
130135
Returns:
131136
The ContextDiff object.
132137
"""
133-
environment = environment.lower()
134-
env = state_reader.get_environment(environment)
138+
initial_environment = environment
139+
environment = _get_target_environment(
140+
environment, state_reader, always_compare_against_prod
141+
)
135142

136-
initial_environment = initial_environment or environment
143+
env = state_reader.get_environment(environment)
137144
initial_env = (
138145
env
139146
if initial_environment == environment
@@ -492,6 +499,23 @@ def text_diff(self, name: str) -> str:
492499
return ""
493500

494501

502+
def _get_target_environment(
503+
environment: str, state_reader: StateReader, always_compare_against_prod: bool = False
504+
) -> str:
505+
if always_compare_against_prod:
506+
prod = state_reader.get_environment(c.PROD)
507+
if prod:
508+
logger.warning(
509+
f"Comparing against production environment instead of {environment}. Note that this may lead to "
510+
"additional backfills as accumulated changes are still pushed to the target environment."
511+
)
512+
environment = c.PROD
513+
else:
514+
environment = environment or c.PROD
515+
516+
return environment.lower()
517+
518+
495519
def _build_requirements(
496520
provided_requirements: t.Dict[str, str],
497521
excluded_requirements: t.Set[str],

0 commit comments

Comments
 (0)