Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 90 additions & 22 deletions sqlmesh/core/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import uuid
import logging
import textwrap
from itertools import zip_longest
from pathlib import Path
from hyperscript import h
from rich.console import Console as RichConsole
Expand All @@ -26,6 +27,7 @@
from rich.tree import Tree
from sqlglot import exp

from sqlmesh.core.test.result import ModelTextTestResult
from sqlmesh.core.environment import EnvironmentNamingInfo, EnvironmentSummary
from sqlmesh.core.linter.rule import RuleViolation
from sqlmesh.core.model import Model
Expand All @@ -46,6 +48,7 @@
NodeAuditsErrors,
format_destructive_change_msg,
)
from sqlmesh.utils.rich import strip_ansi_codes

if t.TYPE_CHECKING:
import ipywidgets as widgets
Expand Down Expand Up @@ -316,6 +319,12 @@ def log_destructive_change(
"""Display a destructive change error or warning to the user."""


class UnitTestConsole(abc.ABC):
@abc.abstractmethod
def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None:
"""Display the test result and output."""


class Console(
PlanBuilderConsole,
LinterConsole,
Expand All @@ -327,6 +336,7 @@ class Console(
DifferenceConsole,
TableDiffConsole,
BaseConsole,
UnitTestConsole,
abc.ABC,
):
"""Abstract base class for defining classes used for displaying information to the user and also interact
Expand Down Expand Up @@ -461,9 +471,7 @@ def plan(
"""

@abc.abstractmethod
def log_test_results(
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
) -> None:
def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None:
"""Display the test result and output.
Comment thread
VaggelisD marked this conversation as resolved.
Outdated

Args:
Expand Down Expand Up @@ -496,6 +504,10 @@ def loading_start(self, message: t.Optional[str] = None) -> uuid.UUID:
def loading_stop(self, id: uuid.UUID) -> None:
"""Stop loading for the given id."""

@abc.abstractmethod
def log_unit_test_results(self, result: ModelTextTestResult) -> None:
"""Print the unit test results."""

Comment thread
VaggelisD marked this conversation as resolved.
Outdated

class NoopConsole(Console):
def start_plan_evaluation(self, plan: EvaluatablePlan) -> None:
Expand Down Expand Up @@ -668,9 +680,7 @@ def plan(
if auto_apply:
plan_builder.apply()

def log_test_results(
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
) -> None:
def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None:
pass

def show_sql(self, sql: str) -> None:
Expand Down Expand Up @@ -777,6 +787,9 @@ def start_destroy(self) -> bool:
def stop_destroy(self, success: bool = True) -> None:
pass

def log_unit_test_results(self, result: ModelTextTestResult) -> None:
pass


def make_progress_bar(
message: str,
Expand Down Expand Up @@ -1952,10 +1965,12 @@ def _prompt_promote(self, plan_builder: PlanBuilder) -> None:
):
plan_builder.apply()

def log_test_results(
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
) -> None:
def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None:
divider_length = 70

self.log_unit_test_results(result)
self._print("\n")

if result.wasSuccessful():
self._print("=" * divider_length)
self._print(
Expand All @@ -1972,9 +1987,13 @@ def log_test_results(
)
for test, _ in result.failures + result.errors:
if isinstance(test, ModelTest):
self._print(f"Failure Test: {test.model.name} {test.test_name}")
self._print(f"Failure Test: {test.path}::{test.test_name}")
self._print("=" * divider_length)
self._print(output)

def _captured_unit_test_results(self, result: ModelTextTestResult) -> str:
with self.console.capture() as capture:
self.log_unit_test_results(result)
return strip_ansi_codes(capture.get())

def show_sql(self, sql: str) -> None:
self._print(Syntax(sql, "sql", word_wrap=True), crop=False)
Expand Down Expand Up @@ -2492,6 +2511,56 @@ def show_linter_violations(
else:
self.log_warning(msg)

def log_unit_test_results(self, result: ModelTextTestResult) -> None:
tests_run = result.testsRun
errors = result.errors
failures = result.failures
skipped = result.skipped
is_success = not (errors or failures)

infos = []
if failures:
infos.append(f"failures={len(failures)}")
if errors:
infos.append(f"errors={len(errors)}")
if skipped:
infos.append(f"skipped={skipped}")

self._print("\n", end="")

for (test_case, failure), test_failure_tables in zip_longest( # type: ignore
failures, result.failure_tables
):
self._print(unittest.TextTestResult.separator1)
self._print(f"FAIL: {test_case}")

if test_description := test_case.shortDescription():
self._print(test_description)
self._print(f"{unittest.TextTestResult.separator2}")

if not test_failure_tables:
self._print(failure)
else:
for failure_table in test_failure_tables:
self._print(failure_table)
self._print("\n", end="")

for test_case, error in errors:
self._print(unittest.TextTestResult.separator1)
self._print(f"ERROR: {test_case}")
self._print(f"{unittest.TextTestResult.separator2}")
self._print(error)

# Output final report
self._print(unittest.TextTestResult.separator2)
test_duration_msg = f" in {result.duration:.3f}s" if result.duration else ""
self._print(
f"\nRan {tests_run} {'tests' if tests_run > 1 else 'test'}{test_duration_msg} \n"
)
self._print(
f"{'OK' if is_success else 'FAILED'}{' (' + ', '.join(infos) + ')' if infos else ''}"
)


def _cells_match(x: t.Any, y: t.Any) -> bool:
"""Helper function to compare two cells and returns true if they're equal, handling array objects."""
Expand Down Expand Up @@ -2763,9 +2832,7 @@ def radio_button_selected(change: t.Dict[str, t.Any]) -> None:
)
self.display(radio)

def log_test_results(
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
) -> None:
def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None:
import ipywidgets as widgets

divider_length = 70
Expand All @@ -2781,12 +2848,14 @@ def log_test_results(
h(
"span",
{"style": {**shared_style, **success_color}},
f"Successfully Ran {str(result.testsRun)} Tests Against {target_dialect}",
f"Successfully Ran {str(result.testsRun)} tests against {target_dialect}",
)
)
footer = str(h("span", {"style": shared_style}, "=" * divider_length))
self.display(widgets.HTML("<br>".join([header, message, footer])))
else:
output = self._captured_unit_test_results(result)

fail_color = {"color": "#db3737"}
fail_shared_style = {**shared_style, **fail_color}
header = str(h("span", {"style": fail_shared_style}, "-" * divider_length))
Expand Down Expand Up @@ -3137,21 +3206,22 @@ def stop_promotion_progress(self, success: bool = True) -> None:
def log_success(self, message: str) -> None:
self._print(message)

def log_test_results(
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
) -> None:
def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None:
if result.wasSuccessful():
self._print(
f"**Successfully Ran `{str(result.testsRun)}` Tests Against `{target_dialect}`**\n\n"
)
else:
self._print("```")
self.log_unit_test_results(result)
self._print("```\n\n")

self._print(
f"**Num Successful Tests: {result.testsRun - len(result.failures) - len(result.errors)}**\n\n"
)
for test, _ in result.failures + result.errors:
if isinstance(test, ModelTest):
self._print(f"* Failure Test: `{test.model.name}` - `{test.test_name}`\n\n")
self._print(f"```{output}```\n\n")

def log_skipped_models(self, snapshot_names: t.Set[str]) -> None:
if snapshot_names:
Expand Down Expand Up @@ -3530,9 +3600,7 @@ def show_model_difference_summary(
for modified in context_diff.modified_snapshots:
self._write(f" Modified: {modified}")

def log_test_results(
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
) -> None:
def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None:
self._write("Test Results:", result)

def show_sql(self, sql: str) -> None:
Expand Down
32 changes: 16 additions & 16 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import time
import traceback
import typing as t
import unittest.result
from functools import cached_property
from io import StringIO
from itertools import chain
Expand Down Expand Up @@ -2044,6 +2043,7 @@ def test(
verbosity: Verbosity = Verbosity.DEFAULT,
preserve_fixtures: bool = False,
stream: t.Optional[t.TextIO] = None,
log_results: bool = True,
Comment thread
VaggelisD marked this conversation as resolved.
Outdated
) -> ModelTextTestResult:
"""Discover and run model tests"""
if verbosity >= Verbosity.VERBOSE:
Expand All @@ -2053,7 +2053,7 @@ def test(

test_meta = self.load_model_tests(tests=tests, patterns=match_patterns)

return run_tests(
result = run_tests(
model_test_metadata=test_meta,
models=self._models,
config=self.config,
Expand All @@ -2066,6 +2066,14 @@ def test(
default_catalog_dialect=self.config.dialect or "",
)

if log_results:
self.console.log_test_results(
result,
self.test_connection_config._engine_adapter.DIALECT,
)

return result

@python_api_analytics
def audit(
self,
Expand Down Expand Up @@ -2488,28 +2496,20 @@ def import_state(self, input_file: Path, clear: bool = False, confirm: bool = Tr

def _run_tests(
self, verbosity: Verbosity = Verbosity.DEFAULT
) -> t.Tuple[unittest.result.TestResult, str]:
) -> t.Tuple[ModelTextTestResult, str]:
test_output_io = StringIO()
result = self.test(stream=test_output_io, verbosity=verbosity)
result = self.test(stream=test_output_io, verbosity=verbosity, log_results=False)
return result, test_output_io.getvalue()

def _run_plan_tests(
self, skip_tests: bool = False
) -> t.Tuple[t.Optional[unittest.result.TestResult], t.Optional[str]]:
def _run_plan_tests(self, skip_tests: bool = False) -> t.Optional[ModelTextTestResult]:
if not skip_tests:
result, test_output = self._run_tests()
if result.testsRun > 0:
self.console.log_test_results(
result,
test_output,
self.test_connection_config._engine_adapter.DIALECT,
)
result = self.test()
if not result.wasSuccessful():
raise PlanError(
"Cannot generate plan due to failing test(s). Fix test(s) and run again."
)
return result, test_output
return None, None
return result
return None

@property
def _model_tables(self) -> t.Dict[str, str]:
Expand Down
Loading