Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 94 additions & 31 deletions sqlmesh/core/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import uuid
import logging
import textwrap
from itertools import zip_longest
from pathlib import Path
from hyperscript import h
from rich.console import Console as RichConsole
Expand All @@ -26,6 +27,7 @@
from rich.tree import Tree
from sqlglot import exp

from sqlmesh.core.test.result import ModelTextTestResult
from sqlmesh.core.environment import EnvironmentNamingInfo, EnvironmentSummary
from sqlmesh.core.linter.rule import RuleViolation
from sqlmesh.core.model import Model
Expand All @@ -46,6 +48,7 @@
NodeAuditsErrors,
format_destructive_change_msg,
)
from sqlmesh.utils.rich import strip_ansi_codes

if t.TYPE_CHECKING:
import ipywidgets as widgets
Expand Down Expand Up @@ -316,6 +319,17 @@ def log_destructive_change(
"""Display a destructive change error or warning to the user."""


class UnitTestConsole(abc.ABC):
@abc.abstractmethod
def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None:
"""Display the test result and output.

Args:
result: The unittest test result that contains metrics like num success, fails, ect.
target_dialect: The dialect that tests were run against. Assumes all tests run against the same dialect.
"""


class Console(
PlanBuilderConsole,
LinterConsole,
Expand All @@ -327,6 +341,7 @@ class Console(
DifferenceConsole,
TableDiffConsole,
BaseConsole,
UnitTestConsole,
abc.ABC,
):
"""Abstract base class for defining classes used for displaying information to the user and also interact
Expand Down Expand Up @@ -460,18 +475,6 @@ def plan(
fail. Default: False
"""

@abc.abstractmethod
def log_test_results(
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
) -> None:
"""Display the test result and output.

Args:
result: The unittest test result that contains metrics like num success, fails, ect.
output: The generated output from the unittest.
target_dialect: The dialect that tests were run against. Assumes all tests run against the same dialect.
"""

@abc.abstractmethod
def show_sql(self, sql: str) -> None:
"""Display to the user SQL."""
Expand Down Expand Up @@ -668,9 +671,7 @@ def plan(
if auto_apply:
plan_builder.apply()

def log_test_results(
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
) -> None:
def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None:
pass

def show_sql(self, sql: str) -> None:
Expand Down Expand Up @@ -1952,10 +1953,12 @@ def _prompt_promote(self, plan_builder: PlanBuilder) -> None:
):
plan_builder.apply()

def log_test_results(
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
) -> None:
def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None:
divider_length = 70

self._log_test_details(result)
self._print("\n")

if result.wasSuccessful():
self._print("=" * divider_length)
self._print(
Expand All @@ -1972,9 +1975,13 @@ def log_test_results(
)
for test, _ in result.failures + result.errors:
if isinstance(test, ModelTest):
self._print(f"Failure Test: {test.model.name} {test.test_name}")
self._print(f"Failure Test: {test.path}::{test.test_name}")
self._print("=" * divider_length)
self._print(output)

def _captured_unit_test_results(self, result: ModelTextTestResult) -> str:
with self.console.capture() as capture:
self._log_test_details(result)
return strip_ansi_codes(capture.get())

def show_sql(self, sql: str) -> None:
self._print(Syntax(sql, "sql", word_wrap=True), crop=False)
Expand Down Expand Up @@ -2492,6 +2499,63 @@ def show_linter_violations(
else:
self.log_warning(msg)

def _log_test_details(self, result: ModelTextTestResult) -> None:
"""
This is a helper method that encapsulates the logic for logging the relevant unittest for the result.
The top level method (`log_test_results`) reuses `_log_test_details` differently based on the console.

Args:
result: The unittest test result that contains metrics like num success, fails, ect.
"""
tests_run = result.testsRun
errors = result.errors
failures = result.failures
skipped = result.skipped
is_success = not (errors or failures)

infos = []
if failures:
infos.append(f"failures={len(failures)}")
if errors:
infos.append(f"errors={len(errors)}")
if skipped:
infos.append(f"skipped={skipped}")

self._print("\n", end="")

for (test_case, failure), test_failure_tables in zip_longest( # type: ignore
failures, result.failure_tables
):
self._print(unittest.TextTestResult.separator1)
self._print(f"FAIL: {test_case}")

if test_description := test_case.shortDescription():
self._print(test_description)
self._print(f"{unittest.TextTestResult.separator2}")

if not test_failure_tables:
self._print(failure)
else:
for failure_table in test_failure_tables:
self._print(failure_table)
self._print("\n", end="")

for test_case, error in errors:
self._print(unittest.TextTestResult.separator1)
self._print(f"ERROR: {test_case}")
self._print(f"{unittest.TextTestResult.separator2}")
self._print(error)

# Output final report
self._print(unittest.TextTestResult.separator2)
test_duration_msg = f" in {result.duration:.3f}s" if result.duration else ""
self._print(
f"\nRan {tests_run} {'tests' if tests_run > 1 else 'test'}{test_duration_msg} \n"
)
self._print(
f"{'OK' if is_success else 'FAILED'}{' (' + ', '.join(infos) + ')' if infos else ''}"
)


def _cells_match(x: t.Any, y: t.Any) -> bool:
"""Helper function to compare two cells and returns true if they're equal, handling array objects."""
Expand Down Expand Up @@ -2763,9 +2827,7 @@ def radio_button_selected(change: t.Dict[str, t.Any]) -> None:
)
self.display(radio)

def log_test_results(
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
) -> None:
def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None:
import ipywidgets as widgets

divider_length = 70
Expand All @@ -2781,12 +2843,14 @@ def log_test_results(
h(
"span",
{"style": {**shared_style, **success_color}},
f"Successfully Ran {str(result.testsRun)} Tests Against {target_dialect}",
f"Successfully Ran {str(result.testsRun)} tests against {target_dialect}",
)
)
footer = str(h("span", {"style": shared_style}, "=" * divider_length))
self.display(widgets.HTML("<br>".join([header, message, footer])))
else:
output = self._captured_unit_test_results(result)

fail_color = {"color": "#db3737"}
fail_shared_style = {**shared_style, **fail_color}
header = str(h("span", {"style": fail_shared_style}, "-" * divider_length))
Expand Down Expand Up @@ -3137,21 +3201,22 @@ def stop_promotion_progress(self, success: bool = True) -> None:
def log_success(self, message: str) -> None:
self._print(message)

def log_test_results(
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
) -> None:
def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None:
if result.wasSuccessful():
self._print(
f"**Successfully Ran `{str(result.testsRun)}` Tests Against `{target_dialect}`**\n\n"
)
else:
self._print("```")
self._log_test_details(result)
self._print("```\n\n")

self._print(
f"**Num Successful Tests: {result.testsRun - len(result.failures) - len(result.errors)}**\n\n"
)
for test, _ in result.failures + result.errors:
if isinstance(test, ModelTest):
self._print(f"* Failure Test: `{test.model.name}` - `{test.test_name}`\n\n")
self._print(f"```{output}```\n\n")

def log_skipped_models(self, snapshot_names: t.Set[str]) -> None:
if snapshot_names:
Expand Down Expand Up @@ -3530,9 +3595,7 @@ def show_model_difference_summary(
for modified in context_diff.modified_snapshots:
self._write(f" Modified: {modified}")

def log_test_results(
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
) -> None:
def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None:
self._write("Test Results:", result)

def show_sql(self, sql: str) -> None:
Expand Down
28 changes: 13 additions & 15 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import time
import traceback
import typing as t
import unittest.result
from functools import cached_property
from io import StringIO
from itertools import chain
Expand Down Expand Up @@ -2053,7 +2052,7 @@ def test(

test_meta = self.load_model_tests(tests=tests, patterns=match_patterns)

return run_tests(
result = run_tests(
model_test_metadata=test_meta,
models=self._models,
config=self.config,
Expand All @@ -2066,6 +2065,13 @@ def test(
default_catalog_dialect=self.config.dialect or "",
)

self.console.log_test_results(
result,
self.test_connection_config._engine_adapter.DIALECT,
)

return result

@python_api_analytics
def audit(
self,
Expand Down Expand Up @@ -2488,28 +2494,20 @@ def import_state(self, input_file: Path, clear: bool = False, confirm: bool = Tr

def _run_tests(
self, verbosity: Verbosity = Verbosity.DEFAULT
) -> t.Tuple[unittest.result.TestResult, str]:
) -> t.Tuple[ModelTextTestResult, str]:
test_output_io = StringIO()
result = self.test(stream=test_output_io, verbosity=verbosity)
return result, test_output_io.getvalue()

def _run_plan_tests(
self, skip_tests: bool = False
) -> t.Tuple[t.Optional[unittest.result.TestResult], t.Optional[str]]:
def _run_plan_tests(self, skip_tests: bool = False) -> t.Optional[ModelTextTestResult]:
if not skip_tests:
result, test_output = self._run_tests()
if result.testsRun > 0:
self.console.log_test_results(
result,
test_output,
self.test_connection_config._engine_adapter.DIALECT,
)
result = self.test()
if not result.wasSuccessful():
raise PlanError(
"Cannot generate plan due to failing test(s). Fix test(s) and run again."
)
return result, test_output
return None, None
return result
return None

@property
def _model_tables(self) -> t.Dict[str, str]:
Expand Down
Loading