diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index e272442e67..b73f2d576f 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -7,6 +7,7 @@ import uuid import logging import textwrap +from itertools import zip_longest from pathlib import Path from hyperscript import h from rich.console import Console as RichConsole @@ -26,6 +27,7 @@ from rich.tree import Tree from sqlglot import exp +from sqlmesh.core.test.result import ModelTextTestResult from sqlmesh.core.environment import EnvironmentNamingInfo, EnvironmentSummary from sqlmesh.core.linter.rule import RuleViolation from sqlmesh.core.model import Model @@ -46,6 +48,7 @@ NodeAuditsErrors, format_destructive_change_msg, ) +from sqlmesh.utils.rich import strip_ansi_codes if t.TYPE_CHECKING: import ipywidgets as widgets @@ -316,6 +319,17 @@ def log_destructive_change( """Display a destructive change error or warning to the user.""" +class UnitTestConsole(abc.ABC): + @abc.abstractmethod + def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None: + """Display the test result and output. + + Args: + result: The unittest test result that contains metrics like num success, fails, ect. + target_dialect: The dialect that tests were run against. Assumes all tests run against the same dialect. + """ + + class Console( PlanBuilderConsole, LinterConsole, @@ -327,6 +341,7 @@ class Console( DifferenceConsole, TableDiffConsole, BaseConsole, + UnitTestConsole, abc.ABC, ): """Abstract base class for defining classes used for displaying information to the user and also interact @@ -460,18 +475,6 @@ def plan( fail. Default: False """ - @abc.abstractmethod - def log_test_results( - self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str - ) -> None: - """Display the test result and output. - - Args: - result: The unittest test result that contains metrics like num success, fails, ect. - output: The generated output from the unittest. - target_dialect: The dialect that tests were run against. Assumes all tests run against the same dialect. - """ - @abc.abstractmethod def show_sql(self, sql: str) -> None: """Display to the user SQL.""" @@ -668,9 +671,7 @@ def plan( if auto_apply: plan_builder.apply() - def log_test_results( - self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str - ) -> None: + def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None: pass def show_sql(self, sql: str) -> None: @@ -1952,10 +1953,12 @@ def _prompt_promote(self, plan_builder: PlanBuilder) -> None: ): plan_builder.apply() - def log_test_results( - self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str - ) -> None: + def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None: divider_length = 70 + + self._log_test_details(result) + self._print("\n") + if result.wasSuccessful(): self._print("=" * divider_length) self._print( @@ -1972,9 +1975,13 @@ def log_test_results( ) for test, _ in result.failures + result.errors: if isinstance(test, ModelTest): - self._print(f"Failure Test: {test.model.name} {test.test_name}") + self._print(f"Failure Test: {test.path}::{test.test_name}") self._print("=" * divider_length) - self._print(output) + + def _captured_unit_test_results(self, result: ModelTextTestResult) -> str: + with self.console.capture() as capture: + self._log_test_details(result) + return strip_ansi_codes(capture.get()) def show_sql(self, sql: str) -> None: self._print(Syntax(sql, "sql", word_wrap=True), crop=False) @@ -2492,6 +2499,63 @@ def show_linter_violations( else: self.log_warning(msg) + def _log_test_details(self, result: ModelTextTestResult) -> None: + """ + This is a helper method that encapsulates the logic for logging the relevant unittest for the result. + The top level method (`log_test_results`) reuses `_log_test_details` differently based on the console. + + Args: + result: The unittest test result that contains metrics like num success, fails, ect. + """ + tests_run = result.testsRun + errors = result.errors + failures = result.failures + skipped = result.skipped + is_success = not (errors or failures) + + infos = [] + if failures: + infos.append(f"failures={len(failures)}") + if errors: + infos.append(f"errors={len(errors)}") + if skipped: + infos.append(f"skipped={skipped}") + + self._print("\n", end="") + + for (test_case, failure), test_failure_tables in zip_longest( # type: ignore + failures, result.failure_tables + ): + self._print(unittest.TextTestResult.separator1) + self._print(f"FAIL: {test_case}") + + if test_description := test_case.shortDescription(): + self._print(test_description) + self._print(f"{unittest.TextTestResult.separator2}") + + if not test_failure_tables: + self._print(failure) + else: + for failure_table in test_failure_tables: + self._print(failure_table) + self._print("\n", end="") + + for test_case, error in errors: + self._print(unittest.TextTestResult.separator1) + self._print(f"ERROR: {test_case}") + self._print(f"{unittest.TextTestResult.separator2}") + self._print(error) + + # Output final report + self._print(unittest.TextTestResult.separator2) + test_duration_msg = f" in {result.duration:.3f}s" if result.duration else "" + self._print( + f"\nRan {tests_run} {'tests' if tests_run > 1 else 'test'}{test_duration_msg} \n" + ) + self._print( + f"{'OK' if is_success else 'FAILED'}{' (' + ', '.join(infos) + ')' if infos else ''}" + ) + def _cells_match(x: t.Any, y: t.Any) -> bool: """Helper function to compare two cells and returns true if they're equal, handling array objects.""" @@ -2763,9 +2827,7 @@ def radio_button_selected(change: t.Dict[str, t.Any]) -> None: ) self.display(radio) - def log_test_results( - self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str - ) -> None: + def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None: import ipywidgets as widgets divider_length = 70 @@ -2781,12 +2843,14 @@ def log_test_results( h( "span", {"style": {**shared_style, **success_color}}, - f"Successfully Ran {str(result.testsRun)} Tests Against {target_dialect}", + f"Successfully Ran {str(result.testsRun)} tests against {target_dialect}", ) ) footer = str(h("span", {"style": shared_style}, "=" * divider_length)) self.display(widgets.HTML("
".join([header, message, footer]))) else: + output = self._captured_unit_test_results(result) + fail_color = {"color": "#db3737"} fail_shared_style = {**shared_style, **fail_color} header = str(h("span", {"style": fail_shared_style}, "-" * divider_length)) @@ -3137,21 +3201,22 @@ def stop_promotion_progress(self, success: bool = True) -> None: def log_success(self, message: str) -> None: self._print(message) - def log_test_results( - self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str - ) -> None: + def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None: if result.wasSuccessful(): self._print( f"**Successfully Ran `{str(result.testsRun)}` Tests Against `{target_dialect}`**\n\n" ) else: + self._print("```") + self._log_test_details(result) + self._print("```\n\n") + self._print( f"**Num Successful Tests: {result.testsRun - len(result.failures) - len(result.errors)}**\n\n" ) for test, _ in result.failures + result.errors: if isinstance(test, ModelTest): self._print(f"* Failure Test: `{test.model.name}` - `{test.test_name}`\n\n") - self._print(f"```{output}```\n\n") def log_skipped_models(self, snapshot_names: t.Set[str]) -> None: if snapshot_names: @@ -3530,9 +3595,7 @@ def show_model_difference_summary( for modified in context_diff.modified_snapshots: self._write(f" Modified: {modified}") - def log_test_results( - self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str - ) -> None: + def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None: self._write("Test Results:", result) def show_sql(self, sql: str) -> None: diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 0450827d6e..73f69ab5df 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -40,7 +40,6 @@ import time import traceback import typing as t -import unittest.result from functools import cached_property from io import StringIO from itertools import chain @@ -2053,7 +2052,7 @@ def test( test_meta = self.load_model_tests(tests=tests, patterns=match_patterns) - return run_tests( + result = run_tests( model_test_metadata=test_meta, models=self._models, config=self.config, @@ -2066,6 +2065,13 @@ def test( default_catalog_dialect=self.config.dialect or "", ) + self.console.log_test_results( + result, + self.test_connection_config._engine_adapter.DIALECT, + ) + + return result + @python_api_analytics def audit( self, @@ -2488,28 +2494,20 @@ def import_state(self, input_file: Path, clear: bool = False, confirm: bool = Tr def _run_tests( self, verbosity: Verbosity = Verbosity.DEFAULT - ) -> t.Tuple[unittest.result.TestResult, str]: + ) -> t.Tuple[ModelTextTestResult, str]: test_output_io = StringIO() result = self.test(stream=test_output_io, verbosity=verbosity) return result, test_output_io.getvalue() - def _run_plan_tests( - self, skip_tests: bool = False - ) -> t.Tuple[t.Optional[unittest.result.TestResult], t.Optional[str]]: + def _run_plan_tests(self, skip_tests: bool = False) -> t.Optional[ModelTextTestResult]: if not skip_tests: - result, test_output = self._run_tests() - if result.testsRun > 0: - self.console.log_test_results( - result, - test_output, - self.test_connection_config._engine_adapter.DIALECT, - ) + result = self.test() if not result.wasSuccessful(): raise PlanError( "Cannot generate plan due to failing test(s). Fix test(s) and run again." ) - return result, test_output - return None, None + return result + return None @property def _model_tables(self) -> t.Dict[str, str]: diff --git a/sqlmesh/core/test/definition.py b/sqlmesh/core/test/definition.py index e43b8b215c..8ceb8a4447 100644 --- a/sqlmesh/core/test/definition.py +++ b/sqlmesh/core/test/definition.py @@ -1,5 +1,7 @@ from __future__ import annotations +import sys + import datetime import threading import typing as t @@ -10,6 +12,7 @@ from pathlib import Path from unittest.mock import patch + from io import StringIO from sqlglot import Dialect, exp from sqlglot.optimizer.annotate_types import annotate_types @@ -24,6 +27,8 @@ from sqlmesh.utils.date import date_dict, pandas_timestamp_to_pydatetime, to_datetime from sqlmesh.utils.errors import ConfigError, TestError from sqlmesh.utils.yaml import load as yaml_load +from sqlmesh.utils import Verbosity +from sqlmesh.utils.rich import df_to_table if t.TYPE_CHECKING: import pandas as pd @@ -60,6 +65,7 @@ def __init__( preserve_fixtures: bool = False, default_catalog: str | None = None, concurrency: bool = False, + verbosity: Verbosity = Verbosity.DEFAULT, ) -> None: """ModelTest encapsulates a unit test for a model. @@ -83,6 +89,7 @@ def __init__( self.default_catalog = default_catalog self.dialect = dialect self.concurrency = concurrency + self.verbosity = verbosity self._fixture_table_cache: t.Dict[str, exp.Table] = {} self._normalized_column_name_cache: t.Dict[str, str] = {} @@ -134,6 +141,11 @@ def __init__( super().__init__() + def defaultTestResult(self) -> unittest.TestResult: + from sqlmesh.core.test.result import ModelTextTestResult + + return ModelTextTestResult(stream=sys.stdout, descriptions=True, verbosity=self.verbosity) + def shortDescription(self) -> t.Optional[str]: return self.body.get("description") @@ -290,23 +302,53 @@ def _to_hashable(x: t.Any) -> t.Any: check_like=True, # Ignore column order ) except AssertionError as e: + # There are 2 concepts at play here: + # 1. The Exception args will contain the error message plus the diff dataframe table stringified + # (backwards compatibility with existing tests, possible to serialize/send over network etc) + # 2. Each test will also transform these diff dataframes into Rich tables, which will be the ones that'll + # be surfaced to the user through Console for better UX (versus stringified dataframes) + # + # This is a bit of a hack, but it's a way to get the best of both worlds. + args: t.List[t.Any] = [] if expected.shape != actual.shape: _raise_if_unexpected_columns(expected.columns, actual.columns) - error_msg = "Data mismatch (rows are different)" + args.append("Data mismatch (rows are different)") missing_rows = _row_difference(expected, actual) if not missing_rows.empty: - error_msg += f"\n\nMissing rows:\n\n{missing_rows}" + args[0] += f"\n\nMissing rows:\n\n{missing_rows}" + args.append(df_to_table("Missing rows", missing_rows)) unexpected_rows = _row_difference(actual, expected) + if not unexpected_rows.empty: - error_msg += f"\n\nUnexpected rows:\n\n{unexpected_rows}" + args[0] += f"\n\nUnexpected rows:\n\n{unexpected_rows}" + args.append(df_to_table("Unexpected rows", unexpected_rows)) - e.args = (error_msg,) else: diff = expected.compare(actual).rename(columns={"self": "exp", "other": "act"}) - e.args = (f"Data mismatch (exp: expected, act: actual)\n\n{diff}",) + + args.append(f"Data mismatch (exp: expected, act: actual)\n\n{diff}") + + diff.rename(columns={"exp": "Expected", "act": "Actual"}, inplace=True) + if self.verbosity == Verbosity.DEFAULT: + args.extend( + df_to_table("Data mismatch", df) for df in _split_df_by_column_pairs(diff) + ) + else: + from pandas import MultiIndex + + levels = t.cast(MultiIndex, diff.columns).levels[0] + for col in levels: + col_diff = diff[col] + if not col_diff.empty: + table = df_to_table( + f"[bold red]Column '{col}' mismatch[/bold red]", col_diff + ) + args.append(table) + + e.args = (*args,) raise e @@ -328,6 +370,7 @@ def create_test( preserve_fixtures: bool = False, default_catalog: str | None = None, concurrency: bool = False, + verbosity: Verbosity = Verbosity.DEFAULT, ) -> t.Optional[ModelTest]: """Create a SqlModelTest or a PythonModelTest. @@ -373,6 +416,7 @@ def create_test( preserve_fixtures, default_catalog, concurrency, + verbosity, ) except Exception as e: raise TestError(f"Failed to create test {test_name} ({path})\n{str(e)}") @@ -692,6 +736,7 @@ def __init__( preserve_fixtures: bool = False, default_catalog: str | None = None, concurrency: bool = False, + verbosity: Verbosity = Verbosity.DEFAULT, ) -> None: """PythonModelTest encapsulates a unit test for a Python model. @@ -718,6 +763,7 @@ def __init__( preserve_fixtures, default_catalog, concurrency, + verbosity, ) self.context = TestExecutionContext( @@ -951,3 +997,43 @@ def _normalize_df_value(value: t.Any) -> t.Any: return {k: _normalize_df_value(v) for k, v in zip(value["key"], value["value"])} return {k: _normalize_df_value(v) for k, v in value.items()} return value + + +def _split_df_by_column_pairs(df: pd.DataFrame, pairs_per_chunk: int = 4) -> t.List[pd.DataFrame]: + """Split a dataframe into chunks of column pairs. + + Args: + df: The dataframe to split + pairs_per_chunk: Number of column pairs per chunk (default: 4) + + Returns: + List of dataframes, each containing an even number of columns + """ + total_columns = len(df.columns) + + # If we have fewer columns than pairs_per_chunk * 2, return the original df + if total_columns <= pairs_per_chunk * 2: + return [df] + + # Calculate number of chunks needed to split columns evenly + num_chunks = (total_columns + (pairs_per_chunk * 2 - 1)) // (pairs_per_chunk * 2) + + # Calculate columns per chunk to ensure equal distribution + # We round down to nearest even number to ensure each chunk has even columns + columns_per_chunk = (total_columns // num_chunks) & ~1 # Round down to nearest even number + remainder = total_columns - (columns_per_chunk * num_chunks) + + chunks = [] + start_idx = 0 + + # Distribute columns evenly across chunks + for i in range(num_chunks): + # Add 2 columns to early chunks if there's a remainder + # This ensures we always add pairs of columns + extra = 2 if i < remainder // 2 else 0 + end_idx = start_idx + columns_per_chunk + extra + chunk = df.iloc[:, start_idx:end_idx] + chunks.append(chunk) + start_idx = end_idx + + return chunks diff --git a/sqlmesh/core/test/result.py b/sqlmesh/core/test/result.py index cdba66b612..8621b8b10a 100644 --- a/sqlmesh/core/test/result.py +++ b/sqlmesh/core/test/result.py @@ -15,10 +15,13 @@ class ModelTextTestResult(unittest.TextTestResult): successes: t.List[unittest.TestCase] def __init__(self, *args: t.Any, **kwargs: t.Any): + self.console = kwargs.pop("console", None) super().__init__(*args, **kwargs) self.successes = [] self.original_failures: t.List[t.Tuple[unittest.TestCase, ErrorType]] = [] + self.failure_tables: t.List[t.Tuple[t.Any, ...]] = [] self.original_errors: t.List[t.Tuple[unittest.TestCase, ErrorType]] = [] + self.duration: t.Optional[float] = None def addSubTest( self, @@ -41,6 +44,12 @@ def addSubTest( super().addSubTest(test, subtest, err) + def _print_char(self, char: str) -> None: + from sqlmesh.core.console import TerminalConsole + + if isinstance(self.console, TerminalConsole): + self.console._print(char, end="") + def addFailure(self, test: unittest.TestCase, err: ErrorType) -> None: """Called when the test case test signals a failure. @@ -51,7 +60,18 @@ def addFailure(self, test: unittest.TestCase, err: ErrorType) -> None: err: A tuple of the form returned by sys.exc_info(), i.e., (type, value, traceback). """ exctype, value, _ = err + + if value and value.args: + exception_msg, rich_tables = value.args[:1], value.args[1:] + value.args = exception_msg + + if rich_tables: + self.failure_tables.append(rich_tables) + + self._print_char("F") + self.original_failures.append((test, err)) + # Intentionally ignore the traceback to hide it from the user return super().addFailure(test, (exctype, value, None)) # type: ignore @@ -64,6 +84,9 @@ def addError(self, test: unittest.TestCase, err: ErrorType) -> None: """ exctype, value, _ = err self.original_errors.append((test, err)) + + self._print_char("E") + # Intentionally ignore the traceback to hide it from the user return super().addError(test, (exctype, value, None)) # type: ignore @@ -74,52 +97,24 @@ def addSuccess(self, test: unittest.TestCase) -> None: test: The test case """ super().addSuccess(test) - self.successes.append(test) - def log_test_report(self, test_duration: float) -> None: - """ - Log the test report following unittest's conventions. + self._print_char(".") - Args: - test_duration: The duration of the tests. - """ - tests_run = self.testsRun - errors = self.errors - failures = self.failures - skipped = self.skipped - - is_success = not (errors or failures) - - infos = [] - if failures: - infos.append(f"failures={len(failures)}") - if errors: - infos.append(f"errors={len(errors)}") - if skipped: - infos.append(f"skipped={skipped}") - - stream = self.stream - - stream.write("\n") - - for test_case, failure in failures: - stream.writeln(unittest.TextTestResult.separator1) - stream.writeln(f"FAIL: {test_case}") - if test_description := test_case.shortDescription(): - stream.writeln(test_description) - stream.writeln(unittest.TextTestResult.separator2) - stream.writeln(failure) - - for test_case, error in errors: - stream.writeln(unittest.TextTestResult.separator1) - stream.writeln(f"ERROR: {test_case}") - stream.writeln(error) - - # Output final report - stream.writeln(unittest.TextTestResult.separator2) - stream.writeln( - f"Ran {tests_run} {'tests' if tests_run > 1 else 'test'} in {test_duration:.3f}s \n" - ) - stream.writeln( - f"{'OK' if is_success else 'FAILED'}{' (' + ', '.join(infos) + ')' if infos else ''}" - ) + self.successes.append(test) + + def merge(self, other: ModelTextTestResult) -> None: + if other.successes: + self.addSuccess(other.successes[0]) + elif other.errors: + for error_test, error in other.original_errors: + self.addError(error_test, error) + elif other.failures: + for failure_test, failure in other.original_failures: + self.addFailure(failure_test, failure) + + self.failure_tables.extend(other.failure_tables) + elif other.skipped: + skipped_args = other.skipped[0] + self.addSkip(skipped_args[0], skipped_args[1]) + + self.testsRun += 1 diff --git a/sqlmesh/core/test/runner.py b/sqlmesh/core/test/runner.py index d2a54d68e8..c098a46d84 100644 --- a/sqlmesh/core/test/runner.py +++ b/sqlmesh/core/test/runner.py @@ -1,10 +1,10 @@ from __future__ import annotations -import sys import time import threading import typing as t import unittest +from io import StringIO import concurrent from concurrent.futures import ThreadPoolExecutor @@ -16,7 +16,6 @@ ModelTestMetadata as ModelTestMetadata, ) from sqlmesh.core.config.connection import BaseDuckDBConnectionConfig - from sqlmesh.core.test.result import ModelTextTestResult as ModelTextTestResult from sqlmesh.utils import UniqueKeyDict, Verbosity @@ -106,10 +105,13 @@ def run_tests( lock = threading.Lock() + from sqlmesh.core.console import get_console + combined_results = ModelTextTestResult( - stream=unittest.runner._WritelnDecorator(stream or sys.stderr), # type: ignore + stream=unittest.runner._WritelnDecorator(stream or StringIO()), # type: ignore verbosity=2 if verbosity >= Verbosity.VERBOSE else 1, descriptions=True, + console=get_console(), ) metadata_to_adapter = create_testing_engine_adapters( @@ -136,6 +138,7 @@ def _run_single_test( default_catalog=default_catalog, preserve_fixtures=preserve_fixtures, concurrency=num_workers > 1, + verbosity=verbosity, ) if not test: @@ -147,19 +150,7 @@ def _run_single_test( ) with lock: - if result.successes: - combined_results.addSuccess(result.successes[0]) - elif result.errors: - for error_test, error in result.original_errors: - combined_results.addError(error_test, error) - elif result.failures: - for failure_test, failure in result.original_failures: - combined_results.addFailure(failure_test, failure) - elif result.skipped: - skipped_args = result.skipped[0] - combined_results.addSkip(skipped_args[0], skipped_args[1]) - - combined_results.testsRun += 1 + combined_results.merge(result) return result @@ -183,6 +174,6 @@ def _run_single_test( end_time = time.perf_counter() - combined_results.log_test_report(test_duration=end_time - start_time) + combined_results.duration = end_time - start_time return combined_results diff --git a/sqlmesh/integrations/github/cicd/command.py b/sqlmesh/integrations/github/cicd/command.py index b360f3366e..cedee1fa58 100644 --- a/sqlmesh/integrations/github/cicd/command.py +++ b/sqlmesh/integrations/github/cicd/command.py @@ -63,20 +63,19 @@ def check_required_approvers(ctx: click.Context) -> None: def _run_tests(controller: GithubController) -> bool: controller.update_test_check(status=GithubCheckStatus.IN_PROGRESS) try: - result, output = controller.run_tests() + result, _ = controller.run_tests() controller.update_test_check( status=GithubCheckStatus.COMPLETED, # Conclusion will be updated with final status based on test results conclusion=GithubCheckConclusion.NEUTRAL, result=result, - output=output, ) return result.wasSuccessful() except Exception: controller.update_test_check( status=GithubCheckStatus.COMPLETED, conclusion=GithubCheckConclusion.FAILURE, - output=traceback.format_exc(), + traceback=traceback.format_exc(), ) return False diff --git a/sqlmesh/integrations/github/cicd/controller.py b/sqlmesh/integrations/github/cicd/controller.py index ad1b4b52b2..5a0ad36d71 100644 --- a/sqlmesh/integrations/github/cicd/controller.py +++ b/sqlmesh/integrations/github/cicd/controller.py @@ -8,7 +8,6 @@ import re import traceback import typing as t -import unittest from enum import Enum from typing import List from pathlib import Path @@ -20,6 +19,7 @@ from sqlmesh.core import constants as c from sqlmesh.core.console import SNAPSHOT_CHANGE_CATEGORY_STR, get_console, MarkdownConsole from sqlmesh.core.context import Context +from sqlmesh.core.test.result import ModelTextTestResult from sqlmesh.core.environment import Environment from sqlmesh.core.plan import Plan, PlanBuilder from sqlmesh.core.snapshot.definition import ( @@ -494,7 +494,7 @@ def get_plan_summary(self, plan: Plan) -> str: except PlanError as e: return f"Plan failed to generate. Check for pending or unresolved changes. Error: {e}" - def run_tests(self) -> t.Tuple[unittest.result.TestResult, str]: + def run_tests(self) -> t.Tuple[ModelTextTestResult, str]: """ Run tests for the PR """ @@ -734,8 +734,8 @@ def update_test_check( self, status: GithubCheckStatus, conclusion: t.Optional[GithubCheckConclusion] = None, - result: t.Optional[unittest.result.TestResult] = None, - output: t.Optional[str] = None, + result: t.Optional[ModelTextTestResult] = None, + traceback: t.Optional[str] = None, ) -> None: """ Updates the status of tests for code in the PR @@ -743,15 +743,13 @@ def update_test_check( def conclusion_handler( conclusion: GithubCheckConclusion, - result: t.Optional[unittest.result.TestResult], - output: t.Optional[str], + result: t.Optional[ModelTextTestResult], ) -> t.Tuple[GithubCheckConclusion, str, t.Optional[str]]: if result: # Clear out console self._console.consume_captured_output() self._console.log_test_results( result, - output, self._context.test_connection_config._engine_adapter.DIALECT, ) test_summary = self._console.consume_captured_output() @@ -762,8 +760,11 @@ def conclusion_handler( else GithubCheckConclusion.FAILURE ) return test_conclusion, test_title, test_summary + if traceback: + self._console._print(traceback) + test_title = "Skipped Tests" if conclusion.is_skipped else "Tests Failed" - return conclusion, test_title, output + return conclusion, test_title, traceback self._update_check_handler( check_name="SQLMesh - Run Unit Tests", @@ -776,7 +777,7 @@ def conclusion_handler( }[status], None, ), - conclusion_handler=functools.partial(conclusion_handler, result=result, output=output), + conclusion_handler=functools.partial(conclusion_handler, result=result), ) def update_required_approval_check( diff --git a/sqlmesh/magics.py b/sqlmesh/magics.py index e74019a743..95260170fe 100644 --- a/sqlmesh/magics.py +++ b/sqlmesh/magics.py @@ -1,5 +1,7 @@ from __future__ import annotations +from io import StringIO + import functools import logging import typing as t @@ -1032,6 +1034,7 @@ def run_test(self, context: Context, line: str) -> None: tests=args.tests, verbosity=Verbosity(args.verbose), preserve_fixtures=args.preserve_fixtures, + stream=StringIO(), # consume the output instead of redirecting to stdout ) @magic_arguments() diff --git a/sqlmesh/utils/rich.py b/sqlmesh/utils/rich.py index 6ebeab3114..589dd0b50f 100644 --- a/sqlmesh/utils/rich.py +++ b/sqlmesh/utils/rich.py @@ -1,8 +1,17 @@ +from __future__ import annotations + import typing as t +import re + from rich.console import Console from rich.progress import Column, ProgressColumn, Task, Text from rich.theme import Theme +from rich.table import Table +from rich.align import Align + +if t.TYPE_CHECKING: + import pandas as pd theme = Theme( { @@ -46,3 +55,50 @@ def render(self, task: Task) -> Text: f"{completed:{total_width}d}{self.separator}{total}", style="progress.download", ) + + +def strip_ansi_codes(text: str) -> str: + """Strip ANSI color codes and styling from text.""" + ansi_escape = re.compile(r"\x1b\[[0-9;]*[a-zA-Z]") + return ansi_escape.sub("", text).strip() + + +def df_to_table( + header: str, + df: pd.DataFrame, + show_index: bool = True, + index_name: str = "Row", +) -> Table: + """Convert a pandas.DataFrame obj into a rich.Table obj. + Args: + df (DataFrame): A Pandas DataFrame to be converted to a rich Table. + rich_table (Table): A rich Table that should be populated by the DataFrame values. + show_index (bool): Add a column with a row count to the table. Defaults to True. + index_name (str, optional): The column name to give to the index column. Defaults to None, showing no value. + Returns: + Table: The rich Table instance passed, populated with the DataFrame values.""" + + rich_table = Table(title=f"[bold red]{header}[/bold red]", show_lines=True, min_width=60) + if show_index: + index_name = str(index_name) if index_name else "" + rich_table.add_column(Align.center(index_name)) + + for column in df.columns: + column_name = column if isinstance(column, str) else ": ".join(str(col) for col in column) + + # Color coding unit test columns (expected/actual), can be removed or refactored if df_to_table is used elswhere too + lower = column_name.lower() + if "expected" in lower: + column_name = f"[green]{column_name}[/green]" + elif "actual" in lower: + column_name = f"[red]{column_name}[/red]" + + rich_table.add_column(Align.center(column_name)) + + for index, value_list in enumerate(df.values.tolist()): + row = [str(index)] if show_index else [] + row += [str(x) for x in value_list] + center = [Align.center(x) for x in row] + rich_table.add_row(*center) + + return rich_table diff --git a/tests/core/test_plan.py b/tests/core/test_plan.py index 540a1384e2..11a2b46fb0 100644 --- a/tests/core/test_plan.py +++ b/tests/core/test_plan.py @@ -5,7 +5,7 @@ import pytest from sqlmesh.utils.metaprogramming import Executable -from tests.core.test_table_diff import create_test_console, strip_ansi_codes +from tests.core.test_table_diff import create_test_console import time_machine from pytest_mock.plugin import MockerFixture from sqlglot import parse_one @@ -42,6 +42,7 @@ yesterday_ds, ) from sqlmesh.utils.errors import PlanError +from sqlmesh.utils.rich import strip_ansi_codes def test_forward_only_plan_sets_version(make_snapshot, mocker: MockerFixture): diff --git a/tests/core/test_table_diff.py b/tests/core/test_table_diff.py index bf491d77a7..092357f88d 100644 --- a/tests/core/test_table_diff.py +++ b/tests/core/test_table_diff.py @@ -3,7 +3,6 @@ import pandas as pd # noqa: TID253 from sqlglot import exp from sqlmesh.core import dialect as d -import re import typing as t from io import StringIO from rich.console import Console @@ -14,6 +13,7 @@ from sqlmesh.core.table_diff import TableDiff, SchemaDiff import numpy as np # noqa: TID253 from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.utils.rich import strip_ansi_codes pytestmark = pytest.mark.slow @@ -45,12 +45,6 @@ def capture_console_output(method_name: str, **kwargs) -> str: console_output.close() -def strip_ansi_codes(text: str) -> str: - """Strip ANSI color codes and styling from text.""" - ansi_escape = re.compile(r"\x1b\[[0-9;]*[a-zA-Z]") - return ansi_escape.sub("", text).strip() - - def test_data_diff(sushi_context_fixed_date, capsys, caplog): model = sushi_context_fixed_date.models['"memory"."sushi"."customer_revenue_by_day"'] diff --git a/tests/core/test_test.py b/tests/core/test_test.py index 9478f4aa6b..7d65a818f1 100644 --- a/tests/core/test_test.py +++ b/tests/core/test_test.py @@ -6,7 +6,7 @@ from pathlib import Path import unittest from unittest.mock import call, patch -from shutil import copyfile +from shutil import copyfile, rmtree import pandas as pd # noqa: TID253 import pytest @@ -31,6 +31,7 @@ from sqlmesh.core.model import Model, SqlModel, load_sql_based_model, model from sqlmesh.core.test.definition import ModelTest, PythonModelTest, SqlModelTest from sqlmesh.core.test.result import ModelTextTestResult +from sqlmesh.utils import Verbosity from sqlmesh.utils.errors import ConfigError, SQLMeshError, TestError from sqlmesh.utils.yaml import dump as dump_yaml from sqlmesh.utils.yaml import load as load_yaml @@ -2218,6 +2219,7 @@ def test_test_with_resolve_template_macro(tmp_path: Path): _check_successful_or_raise(context.test()) +@use_terminal_console def test_test_output(tmp_path: Path) -> None: init_example_project(tmp_path, dialect="duckdb") @@ -2243,8 +2245,8 @@ def test_test_output(tmp_path: Path) -> None: rows: - item_id: 1 num_orders: 2 - - item_id: 2 - num_orders: 2 + - item_id: 4 + num_orders: 3 """ ) @@ -2255,40 +2257,130 @@ def test_test_output(tmp_path: Path) -> None: ) context = Context(paths=tmp_path, config=config) - # Case 1: Assert the log report is structured correctly - with capture_output() as output: + # Case 1: Ensure the log report is structured correctly + with capture_output() as captured_output: context.test() + output = captured_output.stdout + # Order may change due to concurrent execution - assert "F." in output.stderr or ".F" in output.stderr + assert "F." in output or ".F" in output assert ( - f"""====================================================================== -FAIL: test_example_full_model ({new_test_file}) -This is a test + f"""This is a test ---------------------------------------------------------------------- -AssertionError: Data mismatch (exp: expected, act: actual) - - num_orders - exp act -1 2.0 1.0 + Data mismatch +┏━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓ +┃ ┃ item_id: ┃ ┃ num_orders: ┃ num_orders: ┃ +┃ Row ┃ Expected ┃ item_id: Actual ┃ Expected ┃ Actual ┃ +┡━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩ +│ 0 │ 4.0 │ 2.0 │ 3.0 │ 1.0 │ +└─────┴─────────────────┴─────────────────┴─────────────────┴──────────────────┘ ----------------------------------------------------------------------""" - in output.stderr + in output ) - assert "Ran 2 tests" in output.stderr - assert "FAILED (failures=1)" in output.stderr + assert "Ran 2 tests" in output + assert "FAILED (failures=1)" in output - # Case 2: Assert that concurrent execution is working properly + # Case 2: Ensure that the verbose log report is structured correctly + with capture_output() as captured_output: + context.test(verbosity=Verbosity.VERBOSE) + + output = captured_output.stdout + + assert ( + f"""This is a test +---------------------------------------------------------------------- + Column 'item_id' mismatch +┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓ +┃ Row ┃ Expected ┃ Actual ┃ +┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩ +│ 0 │ 4.0 │ 2.0 │ +└─────────────┴────────────────────────┴───────────────────┘ + + Column 'num_orders' mismatch +┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓ +┃ Row ┃ Expected ┃ Actual ┃ +┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩ +│ 0 │ 3.0 │ 1.0 │ +└─────────────┴────────────────────────┴───────────────────┘ + +----------------------------------------------------------------------""" + in output + ) + + # Case 3: Assert that concurrent execution is working properly for i in range(50): copyfile(original_test_file, tmp_path / "tests" / f"test_success_{i}.yaml") copyfile(new_test_file, tmp_path / "tests" / f"test_failure_{i}.yaml") - with capture_output() as output: + with capture_output() as captured_output: context.test() - assert "Ran 102 tests" in output.stderr - assert "FAILED (failures=51)" in output.stderr + output = captured_output.stdout + + assert "Ran 102 tests" in output + assert "FAILED (failures=51)" in output + + # Case 4: Test that wide tables are split into even chunks for default verbosity + rmtree(tmp_path / "tests") + + wide_model_query = ( + "SELECT 1 AS col_1, 2 AS col_2, 3 AS col_3, 4 AS col_4, 5 AS col_5, 6 AS col_6, 7 AS col_7" + ) + + context.upsert_model( + _create_model( + meta="MODEL(name test.test_wide_model)", + query=wide_model_query, + default_catalog=context.default_catalog, + ) + ) + + tests_dir = tmp_path / "tests" + tests_dir.mkdir() + + wide_test_file = tmp_path / "tests" / "test_wide_model.yaml" + wide_test_file_content = """ + test_wide_model: + model: test.test_wide_model + outputs: + query: + rows: + - col_1: 6 + col_2: 5 + col_3: 4 + col_4: 3 + col_5: 2 + col_6: 1 + col_7: 0 + + """ + + wide_test_file.write_text(wide_test_file_content) + + with capture_output() as captured_output: + context.test() + + assert ( + """Data mismatch +┏━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓ +┃ ┃ col_1: ┃ col_1: ┃ col_2: ┃ col_2: ┃ col_3: ┃ col_3: ┃ col_4: ┃ col_4: ┃ +┃ Row ┃ Expec… ┃ Actual ┃ Expec… ┃ Actual ┃ Expec… ┃ Actual ┃ Expect… ┃ Actual ┃ +┡━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩ +│ 0 │ 6 │ 1 │ 5 │ 2 │ 4 │ 3 │ 3 │ 4 │ +└─────┴────────┴────────┴────────┴────────┴────────┴────────┴─────────┴────────┘ + + Data mismatch +┏━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━┓ +┃ ┃ col_5: ┃ col_5: ┃ col_6: ┃ col_6: ┃ col_7: ┃ col_7: ┃ +┃ Row ┃ Expected ┃ Actual ┃ Expected ┃ Actual ┃ Expected ┃ Actual ┃ +┡━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━┩ +│ 0 │ 2 │ 5 │ 1 │ 6 │ 0 │ 7 │ +└─────┴───────────┴───────────┴───────────┴───────────┴───────────┴────────────┘""" + in captured_output.stdout + ) @use_terminal_console @@ -2330,15 +2422,15 @@ def test_test_output_with_invalid_model_name(tmp_path: Path) -> None: with capture_output() as output: context.test() - assert ( - f"""Model '"invalid_model"' was not found at {wrong_test_file}""" - in mock_logger.call_args[0][0] - ) - assert ( - ".\n----------------------------------------------------------------------\nRan 1 test in" - in output.stderr - ) - assert "OK" in output.stderr + assert ( + f"""Model '"invalid_model"' was not found at {wrong_test_file}""" + in mock_logger.call_args[0][0] + ) + assert ( + ".\n----------------------------------------------------------------------\n\nRan 1 test in" + in output.stdout + ) + assert "OK" in output.stdout def test_number_of_tests_found(tmp_path: Path) -> None: @@ -2553,6 +2645,7 @@ def upstream_table_python(context, **kwargs): ) +@use_terminal_console @pytest.mark.parametrize("is_error", [True, False]) def test_model_test_text_result_reporting_no_traceback( sushi_context: Context, full_model_with_two_ctes: SqlModel, is_error: bool @@ -2596,10 +2689,10 @@ def test_model_test_text_result_reporting_no_traceback( else: result.addFailure(test, (e.__class__, e, e.__traceback__)) - result.log_test_report(0) + with capture_output() as captured_output: + get_console().log_test_results(result, "duckdb") - stream.seek(0) - output = stream.read() + output = captured_output.stdout # Make sure that the traceback is not printed assert "Traceback" not in output diff --git a/tests/integrations/github/cicd/test_github_commands.py b/tests/integrations/github/cicd/test_github_commands.py index 17cd9fc7b7..9a22a74974 100644 --- a/tests/integrations/github/cicd/test_github_commands.py +++ b/tests/integrations/github/cicd/test_github_commands.py @@ -9,6 +9,7 @@ from sqlmesh.core import constants as c from sqlmesh.core.plan import Plan +from sqlmesh.core.test.result import ModelTextTestResult from sqlmesh.core.user import User, UserRole from sqlmesh.integrations.github.cicd import command from sqlmesh.integrations.github.cicd.config import GithubCICDBotConfig, MergeMethod @@ -448,11 +449,11 @@ def test_run_all_test_failed( github_client, bot_config=GithubCICDBotConfig(merge_method=MergeMethod.MERGE), ) - test_result = TestResult() + test_result = ModelTextTestResult(stream=None, descriptions=True, verbosity=0) test_result.testsRun += 1 - test_result.addFailure(TestCase(), (None, None, None)) + test_result.addFailure(TestCase(), (TestError, TestError("some error"), None)) controller._context._run_tests = mocker.MagicMock( - side_effect=lambda **kwargs: (test_result, "some error") + side_effect=lambda **kwargs: (test_result, "") ) controller._context.users = [ User(username="test", github_username="test_github", roles=[UserRole.REQUIRED_APPROVER]) @@ -474,15 +475,9 @@ def test_run_all_test_failed( assert GithubCheckConclusion(test_checks_runs[2]["conclusion"]).is_failure assert test_checks_runs[2]["output"]["title"] == "Tests Failed" assert ( - test_checks_runs[2]["output"]["summary"] - == """**Num Successful Tests: 0** - - -```some error``` - - -""" + """sqlmesh.utils.errors.TestError: some error""" in test_checks_runs[2]["output"]["summary"] ) + assert """**Num Successful Tests: 0**""" in test_checks_runs[2]["output"]["summary"] assert "SQLMesh - Prod Plan Preview" in controller._check_run_mapping prod_plan_preview_checks_runs = controller._check_run_mapping[ diff --git a/tests/integrations/jupyter/test_magics.py b/tests/integrations/jupyter/test_magics.py index ac9be5cc9c..a4d98e1963 100644 --- a/tests/integrations/jupyter/test_magics.py +++ b/tests/integrations/jupyter/test_magics.py @@ -707,20 +707,6 @@ def test_test(notebook, sushi_context): assert test_file.read_text() == """test_customer_revenue_by_day: TESTING\n""" -def test_run_test(notebook, sushi_context): - with capture_output() as output: - notebook.run_line_magic( - magic_name="run_test", - line=f"{sushi_context.path / 'tests' / 'test_customer_revenue_by_day.yaml'}::test_customer_revenue_by_day", - ) - - assert not output.stdout - # TODO: Does it make sense for this to go to stderr? - assert "Ran 1 test" in output.stderr - assert "OK" in output.stderr - assert not output.outputs - - @pytest.mark.slow def test_audit(notebook, loaded_sushi_context, convert_all_html_output_to_text): with capture_output() as output: diff --git a/web/server/api/endpoints/commands.py b/web/server/api/endpoints/commands.py index b879fb07cd..5db3c85d66 100644 --- a/web/server/api/endpoints/commands.py +++ b/web/server/api/endpoints/commands.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import io import typing as t from fastapi import APIRouter, Body, Depends, Request, Response @@ -142,12 +141,10 @@ async def test( context: Context = Depends(get_loaded_context), ) -> models.TestResult: """Run one or all model tests""" - test_output = io.StringIO() try: result = context.test( tests=[str(context.path / Path(test))] if test else None, verbosity=verbosity, - stream=test_output, ) except Exception: import traceback @@ -157,11 +154,7 @@ async def test( message="Unable to run tests", origin="API -> commands -> test", ) - context.console.log_test_results( - result, - test_output.getvalue(), - context.test_connection_config._engine_adapter.DIALECT, - ) + context.console.log_test_results(result, context.test_connection_config._engine_adapter.DIALECT) def _test_path(test: ModelTest) -> t.Optional[str]: if path := test.path_relative_to(context.path): diff --git a/web/server/console.py b/web/server/console.py index 6077c3fb9b..2cda0af697 100644 --- a/web/server/console.py +++ b/web/server/console.py @@ -3,7 +3,6 @@ import asyncio import json import typing as t -import unittest from fastapi.encoders import jsonable_encoder from sse_starlette.sse import ServerSentEvent from sqlmesh.core.snapshot.definition import Interval, Intervals @@ -12,6 +11,7 @@ from sqlmesh.core.plan.definition import EvaluatablePlan from sqlmesh.core.snapshot import Snapshot, SnapshotInfoLike, SnapshotTableInfo from sqlmesh.core.test import ModelTest +from sqlmesh.core.test.result import ModelTextTestResult from sqlmesh.utils.date import now_timestamp from web.server import models from web.server.exceptions import ApiException @@ -258,9 +258,7 @@ def log_event( ) ) - def log_test_results( - self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str - ) -> None: + def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None: if result.wasSuccessful(): self.log_event( event=models.EventName.TESTS, @@ -279,6 +277,8 @@ def log_test_results( details=details, ) ) + + output = self._captured_unit_test_results(result) self.log_event( event=models.EventName.TESTS, data=models.ReportTestsFailure(