diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index fdac43c25..d9f68bce8 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -37,6 +37,7 @@ if TYPE_CHECKING: from argparse import Namespace + from collections.abc import Generator from codeflash.models.models import CodeOptimizationContext from codeflash.verification.verification_utils import TestConfig @@ -51,6 +52,46 @@ class FunctionProperties: staticmethod_class_name: Optional[str] +# ============================================================================= +# Discovery-scoped file/AST cache +# ============================================================================= + +_active_discovery_cache: dict[Path, str] | None = None +_active_ast_cache: dict[Path, ast.Module] | None = None + + +@contextlib.contextmanager +def discovery_cache() -> Generator[None, None, None]: + global _active_discovery_cache, _active_ast_cache + _active_discovery_cache = {} + _active_ast_cache = {} + try: + yield + finally: + _active_discovery_cache = None + _active_ast_cache = None + + +def read_file_cached(file_path: Path) -> str: + if _active_discovery_cache is not None: + if file_path not in _active_discovery_cache: + _active_discovery_cache[file_path] = file_path.read_text(encoding="utf-8") + return _active_discovery_cache[file_path] + return file_path.read_text(encoding="utf-8") + + +def parse_ast_cached(file_path: Path, source: str | None = None) -> ast.Module: + if _active_ast_cache is not None: + if file_path not in _active_ast_cache: + if source is None: + source = read_file_cached(file_path) + _active_ast_cache[file_path] = ast.parse(source) + return _active_ast_cache[file_path] + if source is None: + source = read_file_cached(file_path) + return ast.parse(source) + + # ============================================================================= # Multi-language support helpers # ============================================================================= @@ -135,7 +176,9 @@ def get_files_for_language( return files -def _is_js_ts_function_exported(file_path: Path, function_name: str) -> tuple[bool, str | None]: +def _is_js_ts_function_exported( + file_path: Path, function_name: str, source: str | None = None +) -> tuple[bool, str | None]: """Check if a JavaScript/TypeScript function is exported from its module. For JS/TS, functions that are not exported cannot be imported by tests, @@ -144,6 +187,7 @@ def _is_js_ts_function_exported(file_path: Path, function_name: str) -> tuple[bo Args: file_path: Path to the source file. function_name: Name of the function to check. + source: Pre-read file content. If None, reads from disk. Returns: Tuple of (is_exported, export_name). export_name may be 'default' for default exports. @@ -152,16 +196,16 @@ def _is_js_ts_function_exported(file_path: Path, function_name: str) -> tuple[bo from codeflash.languages.javascript.treesitter import get_analyzer_for_file try: - source = file_path.read_text(encoding="utf-8") + if source is None: + source = read_file_cached(file_path) analyzer = get_analyzer_for_file(file_path) return analyzer.is_function_exported(source, function_name) except Exception as e: logger.debug(f"Failed to check export status for {function_name}: {e}") - # Return True to avoid blocking in case of errors return True, None -def _is_js_ts_function_exists_but_not_exported(file_path: Path, function_name: str) -> bool: +def _is_js_ts_function_exists_but_not_exported(file_path: Path, function_name: str, source: str | None = None) -> bool: """Check if a JS/TS function exists in the file but is not exported. Returns True only if the function name is found as a defined function @@ -170,7 +214,8 @@ def _is_js_ts_function_exists_but_not_exported(file_path: Path, function_name: s from codeflash.languages.javascript.treesitter import get_analyzer_for_file try: - source = file_path.read_text(encoding="utf-8") + if source is None: + source = read_file_cached(file_path) analyzer = get_analyzer_for_file(file_path) all_funcs = analyzer.find_functions( source, include_methods=True, include_arrow_functions=True, require_name=True @@ -183,27 +228,6 @@ def _is_js_ts_function_exists_but_not_exported(file_path: Path, function_name: s return False -def _find_all_functions_via_language_support(file_path: Path) -> dict[Path, list[FunctionToOptimize]]: - """Find all optimizable functions using the language support abstraction. - - This function uses the registered language support for the file's language - to discover functions, then converts them to FunctionToOptimize instances. - """ - from codeflash.languages.base import FunctionFilterCriteria - - functions: dict[Path, list[FunctionToOptimize]] = {} - - try: - lang_support = get_language_support(file_path) - require_return = lang_support.language != Language.JAVA - criteria = FunctionFilterCriteria(require_return=require_return) - functions[file_path] = lang_support.discover_functions(file_path, criteria) - except Exception as e: - logger.debug(f"Failed to discover functions in {file_path}: {e}") - - return functions - - def get_functions_to_optimize( optimize_all: str | None, replay_test: list[Path] | None, @@ -221,12 +245,14 @@ def get_functions_to_optimize( functions: dict[Path, list[FunctionToOptimize]] trace_file_path: Path | None = None is_lsp = is_LSP_enabled() - with warnings.catch_warnings(): + with discovery_cache(), warnings.catch_warnings(): warnings.simplefilter(action="ignore", category=SyntaxWarning) if optimize_all: logger.info("!lsp|Finding all functions in the module '%s'…", optimize_all) console.rule() - functions = get_all_files_and_functions(Path(optimize_all), ignore_paths) + functions = get_all_files_and_functions( + Path(optimize_all), ignore_paths, tests_root=test_cfg.tests_root, module_root=module_root + ) elif replay_test: functions, trace_file_path = get_all_replay_test_functions( replay_test=replay_test, test_cfg=test_cfg, project_root_path=project_root @@ -236,6 +262,13 @@ def get_functions_to_optimize( console.rule() file = Path(file) if isinstance(file, str) else file functions = find_all_functions_in_file(file) + # Source already cached by find_all_functions_in_file above + _js_ts_source: str | None = None + if only_get_this_function is not None and is_language_supported(file): + _lang = get_language_support(file) + if _lang.language in (Language.JAVASCRIPT, Language.TYPESCRIPT): + with contextlib.suppress(Exception): + _js_ts_source = read_file_cached(file) if only_get_this_function is not None: split_function = only_get_this_function.split(".") if len(split_function) > 2: @@ -260,15 +293,13 @@ def get_functions_to_optimize( return functions, 0, None # For JS/TS: check if the function exists but is not exported - if is_language_supported(file): - lang_support = get_language_support(file) - if lang_support.language in (Language.JAVASCRIPT, Language.TYPESCRIPT): - if _is_js_ts_function_exists_but_not_exported(file, only_function_name): - exit_with_message( - f"Function '{only_function_name}' exists in {file} but is not exported.\n" - f"In JavaScript/TypeScript, only exported functions can be optimized.\n" - f"Add: export {{ {only_function_name} }}" - ) + if _js_ts_source is not None: + if _is_js_ts_function_exists_but_not_exported(file, only_function_name, source=_js_ts_source): + exit_with_message( + f"Function '{only_function_name}' exists in {file} but is not exported.\n" + f"In JavaScript/TypeScript, only exported functions can be optimized.\n" + f"Add: export {{ {only_function_name} }}" + ) found = closest_matching_file_function_name(only_get_this_function, functions) if found is not None: @@ -295,7 +326,7 @@ def get_functions_to_optimize( # It's a standalone function - check if the function is exported name_to_check = found_function.function_name - is_exported, _ = _is_js_ts_function_exported(file, name_to_check) + is_exported, _ = _is_js_ts_function_exported(file, name_to_check, source=_js_ts_source) if not is_exported: if found_function.parents: logger.debug( @@ -317,7 +348,12 @@ def get_functions_to_optimize( logger.info("Finding all functions modified in the current git diff ...") console.rule() ph("cli-optimizing-git-diff") - functions = get_functions_within_git_diff(uncommitted_changes=False) + functions = get_functions_within_git_diff( + uncommitted_changes=False, + tests_root=test_cfg.tests_root, + ignore_paths=ignore_paths, + module_root=module_root, + ) filtered_modified_functions, functions_count = filter_functions( functions, test_cfg.tests_root, ignore_paths, project_root, module_root, previous_checkpoint_functions ) @@ -326,9 +362,16 @@ def get_functions_to_optimize( return filtered_modified_functions, functions_count, trace_file_path -def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[Path, list[FunctionToOptimize]]: +def get_functions_within_git_diff( + uncommitted_changes: bool, + tests_root: Path | None = None, + ignore_paths: list[Path] | None = None, + module_root: Path | None = None, +) -> dict[Path, list[FunctionToOptimize]]: modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=uncommitted_changes) - return get_functions_within_lines(modified_lines) + return get_functions_within_lines( + modified_lines, tests_root=tests_root, ignore_paths=ignore_paths, module_root=module_root + ) def closest_matching_file_function_name( @@ -406,12 +449,20 @@ def get_functions_inside_a_commit(commit_hash: str) -> dict[Path, list[FunctionT return get_functions_within_lines(modified_lines) -def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[Path, list[FunctionToOptimize]]: +def get_functions_within_lines( + modified_lines: dict[str, list[int]], + tests_root: Path | None = None, + ignore_paths: list[Path] | None = None, + module_root: Path | None = None, +) -> dict[Path, list[FunctionToOptimize]]: functions: dict[Path, list[FunctionToOptimize]] = {} for path_str, lines_in_file in modified_lines.items(): path = Path(path_str) if not path.exists(): continue + if tests_root is not None and module_root is not None: + if not filter_files_optimized(path, tests_root, ignore_paths or [], module_root): + continue all_functions = find_all_functions_in_file(path) functions[path] = [ func @@ -424,7 +475,11 @@ def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[Pat def get_all_files_and_functions( - module_root_path: Path, ignore_paths: list[Path], language: Language | None = None + module_root_path: Path, + ignore_paths: list[Path], + language: Language | None = None, + tests_root: Path | None = None, + module_root: Path | None = None, ) -> dict[Path, list[FunctionToOptimize]]: """Get all optimizable functions from files in the module root. @@ -432,6 +487,8 @@ def get_all_files_and_functions( module_root_path: Root path to search for source files. ignore_paths: List of paths to ignore. language: Optional specific language to filter for. If None, includes all supported languages. + tests_root: Test root path for prefiltering files before reading (avoids unnecessary I/O). + module_root: Module root path for prefiltering files before reading. Returns: Dictionary mapping file paths to lists of FunctionToOptimize. @@ -439,6 +496,9 @@ def get_all_files_and_functions( """ functions: dict[Path, list[FunctionToOptimize]] = {} for file_path in get_files_for_language(module_root_path, ignore_paths, language): + if tests_root is not None and module_root is not None: + if not filter_files_optimized(file_path, tests_root, ignore_paths, module_root): + continue functions.update(find_all_functions_in_file(file_path).items()) # Randomize the order of the files to optimize to avoid optimizing the same file in the same order every time. # Helpful if an optimize-all run is stuck and we restart it. @@ -457,7 +517,7 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt lang_support = get_language_support(file_path) require_return = lang_support.language != Language.JAVA criteria = FunctionFilterCriteria(require_return=require_return) - source = file_path.read_text(encoding="utf-8") + source = read_file_cached(file_path) return {file_path: lang_support.discover_functions(source, file_path, criteria)} except Exception as e: logger.debug(f"Failed to discover functions in {file_path}: {e}") @@ -474,21 +534,20 @@ def get_all_replay_test_functions( trace_file_path: Path | None = None for replay_test_file in replay_test: try: - with replay_test_file.open("r", encoding="utf8") as f: - tree = ast.parse(f.read()) - for node in ast.walk(tree): - if isinstance(node, ast.Assign): - for target in node.targets: - if ( - isinstance(target, ast.Name) - and target.id == "trace_file_path" - and isinstance(node.value, ast.Constant) - and isinstance(node.value.value, str) - ): - trace_file_path = Path(node.value.value) - break - if trace_file_path: + tree = parse_ast_cached(replay_test_file) + for node in ast.walk(tree): + if isinstance(node, ast.Assign): + for target in node.targets: + if ( + isinstance(target, ast.Name) + and target.id == "trace_file_path" + and isinstance(node.value, ast.Constant) + and isinstance(node.value.value, str) + ): + trace_file_path = Path(node.value.value) break + if trace_file_path: + break if trace_file_path: break except Exception as e: @@ -602,7 +661,7 @@ def _get_java_replay_test_functions( from codeflash.languages.registry import get_language_support lang_support = get_language_support(source_file) - source_code = source_file.read_text(encoding="utf-8") + source_code = read_file_cached(source_file) all_functions = lang_support.discover_functions(source_code, source_file) for func in all_functions: @@ -730,11 +789,10 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None: def inspect_top_level_functions_or_methods( file_name: Path, function_or_method_name: str, class_name: str | None = None, line_no: int | None = None ) -> FunctionProperties | None: - with file_name.open(encoding="utf8") as file: - try: - ast_module = ast.parse(file.read()) - except Exception: - return None + try: + ast_module = parse_ast_cached(file_name) + except Exception: + return None visitor = TopLevelFunctionOrMethodVisitor( file_name=file_name, function_or_method_name=function_or_method_name, class_name=class_name, line_no=line_no ) diff --git a/codeflash/tracing/profile_stats.py b/codeflash/tracing/profile_stats.py index 36bf9e113..d52d377f8 100644 --- a/codeflash/tracing/profile_stats.py +++ b/codeflash/tracing/profile_stats.py @@ -1,25 +1,44 @@ -import json +from __future__ import annotations + import pstats import sqlite3 from copy import copy from pathlib import Path +from typing import Any, TextIO from codeflash.cli_cmds.console import logger class ProfileStats(pstats.Stats): + # Attributes set by pstats.Stats.init() — stubs don't expose them + files: list[str] + stream: TextIO + top_level: set[tuple[str, int, str]] + total_calls: int + prim_calls: int + total_tt: float + max_name_len: int + fcn_list: list[tuple[str, int, str]] | None + sort_arg_dict: dict[str, tuple[Any, ...]] + all_callees: dict[tuple[str, int, str], dict[tuple[str, int, str], tuple[int, int, float, float]]] | None + stats: dict[tuple[str, int, str], tuple[int, int, int | float, int | float, dict[Any, Any]]] + def __init__(self, trace_file_path: str, time_unit: str = "ns") -> None: assert Path(trace_file_path).is_file(), f"Trace file {trace_file_path} does not exist" assert time_unit in {"ns", "us", "ms", "s"}, f"Invalid time unit {time_unit}" self.trace_file_path = trace_file_path self.time_unit = time_unit logger.debug(hasattr(self, "create_stats")) - super().__init__(copy(self)) + super().__init__(copy(self)) # type: ignore[arg-type] # pstats uses duck-typed create_stats interface def create_stats(self) -> None: self.con = sqlite3.connect(self.trace_file_path) cur = self.con.cursor() - pdata = cur.execute("SELECT * FROM pstats").fetchall() + pdata = cur.execute( + "SELECT filename, line_number, function, class_name," + " call_count_nonrecursive, num_callers, total_time_ns, cumulative_time_ns" + " FROM pstats" + ).fetchall() self.con.close() time_conversion_factor = {"ns": 1, "us": 1e3, "ms": 1e6, "s": 1e9}[self.time_unit] self.stats = {} @@ -32,19 +51,7 @@ def create_stats(self) -> None: num_callers, total_time_ns, cumulative_time_ns, - callers, ) in pdata: - loaded_callers = json.loads(callers) - unmapped_callers = {} - for caller in loaded_callers: - caller_key = caller["key"] - if isinstance(caller_key, list): - caller_key = tuple(caller_key) - elif not isinstance(caller_key, tuple): - caller_key = (caller_key,) if not isinstance(caller_key, (list, tuple)) else tuple(caller_key) - unmapped_callers[caller_key] = caller["value"] - - # Create function key with class name if present (matching tracer.py format) function_name = f"{class_name}.{function}" if class_name else function self.stats[(filename, line_number, function_name)] = ( @@ -52,11 +59,10 @@ def create_stats(self) -> None: num_callers, total_time_ns / time_conversion_factor if time_conversion_factor != 1 else total_time_ns, cumulative_time_ns / time_conversion_factor if time_conversion_factor != 1 else cumulative_time_ns, - unmapped_callers, + {}, ) - def print_stats(self, *amount) -> pstats.Stats: # noqa: ANN002 - # Copied from pstats.Stats.print_stats and modified to print the correct time unit + def print_stats(self, *amount: str | float) -> ProfileStats: for filename in self.files: print(filename, file=self.stream) if self.files: @@ -74,8 +80,8 @@ def print_stats(self, *amount) -> pstats.Stats: # noqa: ANN002 _width, list_ = self.get_print_list(amount) if list_: self.print_title() - for func in list_: - self.print_line(func) + for fn in list_: + self.print_line(fn) print(file=self.stream) print(file=self.stream) return self diff --git a/tests/test_discovery_cache.py b/tests/test_discovery_cache.py new file mode 100644 index 000000000..6a7b9a8ae --- /dev/null +++ b/tests/test_discovery_cache.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +import ast +import tempfile +from pathlib import Path +from unittest.mock import patch + +from codeflash.discovery.functions_to_optimize import ( + discovery_cache, + find_all_functions_in_file, + inspect_top_level_functions_or_methods, + parse_ast_cached, + read_file_cached, +) + + +def test_read_file_cached_without_context_manager(tmp_path: Path) -> None: + f = tmp_path / "sample.py" + f.write_text("x = 1\n", encoding="utf-8") + assert read_file_cached(f) == "x = 1\n" + + +def test_read_file_cached_returns_same_object_within_context(tmp_path: Path) -> None: + f = tmp_path / "sample.py" + f.write_text("x = 1\n", encoding="utf-8") + with discovery_cache(): + result1 = read_file_cached(f) + result2 = read_file_cached(f) + assert result1 is result2 + + +def test_read_file_cached_does_not_persist_across_contexts(tmp_path: Path) -> None: + f = tmp_path / "sample.py" + f.write_text("x = 1\n", encoding="utf-8") + with discovery_cache(): + result1 = read_file_cached(f) + f.write_text("x = 2\n", encoding="utf-8") + with discovery_cache(): + result2 = read_file_cached(f) + assert result1 != result2 + + +def test_parse_ast_cached_returns_same_object_within_context(tmp_path: Path) -> None: + f = tmp_path / "sample.py" + f.write_text("def foo():\n return 1\n", encoding="utf-8") + with discovery_cache(): + tree1 = parse_ast_cached(f) + tree2 = parse_ast_cached(f) + assert tree1 is tree2 + assert isinstance(tree1, ast.Module) + + +def test_parse_ast_cached_uses_provided_source(tmp_path: Path) -> None: + f = tmp_path / "sample.py" + f.write_text("x = 1\n", encoding="utf-8") + source = "y = 2\n" + with discovery_cache(): + tree = parse_ast_cached(f, source=source) + assert any( + isinstance(n, ast.Assign) + and isinstance(n.targets[0], ast.Name) + and n.targets[0].id == "y" + for n in ast.walk(tree) + ) + + +def test_discovery_cache_avoids_redundant_reads(tmp_path: Path) -> None: + f = tmp_path / "module.py" + f.write_text("def bar():\n return 42\n", encoding="utf-8") + with discovery_cache(): + with patch.object(Path, "read_text", wraps=f.read_text) as mock_read: + read_file_cached(f) + read_file_cached(f) + read_file_cached(f) + assert mock_read.call_count == 1 + + +def test_find_all_functions_in_file_uses_cache(tmp_path: Path) -> None: + f = tmp_path / "module.py" + f.write_text("def compute(x):\n return x * 2\n", encoding="utf-8") + with discovery_cache(): + result = find_all_functions_in_file(f) + assert f in result + assert result[f][0].function_name == "compute" + + +def test_inspect_top_level_functions_uses_cache(tmp_path: Path) -> None: + f = tmp_path / "module.py" + f.write_text("def top_func(a, b):\n return a + b\n", encoding="utf-8") + with discovery_cache(): + props = inspect_top_level_functions_or_methods(f, "top_func") + assert props is not None + assert props.is_top_level + assert props.has_args + + +def test_find_and_inspect_share_cached_content(tmp_path: Path) -> None: + f = tmp_path / "module.py" + f.write_text( + "class MyClass:\n def method(self):\n return 1\n\ndef standalone():\n return 2\n", + encoding="utf-8", + ) + with discovery_cache(): + with patch.object(Path, "read_text", wraps=f.read_text) as mock_read: + find_all_functions_in_file(f) + props = inspect_top_level_functions_or_methods(f, "method", class_name="MyClass") + assert mock_read.call_count == 1 + assert props is not None + assert props.is_top_level + + +def test_discovery_results_correct_with_multiple_files(tmp_path: Path) -> None: + f1 = tmp_path / "a.py" + f1.write_text("def alpha():\n return 'a'\n", encoding="utf-8") + f2 = tmp_path / "b.py" + f2.write_text("def beta(x):\n return x + 1\n", encoding="utf-8") + + with discovery_cache(): + r1 = find_all_functions_in_file(f1) + r2 = find_all_functions_in_file(f2) + + assert r1[f1][0].function_name == "alpha" + assert r2[f2][0].function_name == "beta" + + +def test_cache_handles_invalid_syntax_gracefully(tmp_path: Path) -> None: + f = tmp_path / "broken.py" + f.write_text("def incomplete(:\n", encoding="utf-8") + with discovery_cache(): + result = find_all_functions_in_file(f) + assert result == {} + + +def test_cache_handles_nonexistent_file_in_parse_ast(tmp_path: Path) -> None: + f = tmp_path / "nonexistent.py" + with discovery_cache(): + try: + parse_ast_cached(f) + assert False, "Should have raised" + except (FileNotFoundError, OSError): + pass diff --git a/tests/test_discovery_prefilter.py b/tests/test_discovery_prefilter.py new file mode 100644 index 000000000..fe15370de --- /dev/null +++ b/tests/test_discovery_prefilter.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +from pathlib import Path +from unittest.mock import patch + +from codeflash.discovery.functions_to_optimize import ( + get_all_files_and_functions, + get_functions_within_lines, +) + + +def test_prefilter_skips_test_files(tmp_path: Path) -> None: + """Files in tests_root should be skipped before read_text() is called.""" + module_root = tmp_path / "src" + module_root.mkdir() + tests_root = tmp_path / "tests" + tests_root.mkdir() + + source_file = module_root / "app.py" + source_file.write_text("def compute():\n return 1\n", encoding="utf-8") + + test_file = tests_root / "test_app.py" + test_file.write_text("def test_compute():\n return True\n", encoding="utf-8") + + with patch("codeflash.discovery.functions_to_optimize.get_files_for_language") as mock_get_files: + mock_get_files.return_value = [source_file, test_file] + result = get_all_files_and_functions( + module_root, ignore_paths=[], tests_root=tests_root, module_root=module_root + ) + + assert source_file in result + assert test_file not in result + + +def test_prefilter_skips_ignored_paths(tmp_path: Path) -> None: + """Files in ignore_paths should be skipped before read_text() is called.""" + module_root = tmp_path / "src" + module_root.mkdir() + tests_root = tmp_path / "tests" + tests_root.mkdir() + ignored_dir = module_root / "vendor" + ignored_dir.mkdir() + + source_file = module_root / "app.py" + source_file.write_text("def compute():\n return 1\n", encoding="utf-8") + + vendor_file = ignored_dir / "lib.py" + vendor_file.write_text("def helper():\n return 2\n", encoding="utf-8") + + with patch("codeflash.discovery.functions_to_optimize.get_files_for_language") as mock_get_files: + mock_get_files.return_value = [source_file, vendor_file] + result = get_all_files_and_functions( + module_root, ignore_paths=[ignored_dir], tests_root=tests_root, module_root=module_root + ) + + assert source_file in result + assert vendor_file not in result + + +def test_prefilter_skips_files_outside_module_root(tmp_path: Path) -> None: + """Files outside module_root should be skipped before read_text() is called.""" + module_root = tmp_path / "src" + module_root.mkdir() + tests_root = tmp_path / "tests" + tests_root.mkdir() + other_dir = tmp_path / "other" + other_dir.mkdir() + + source_file = module_root / "app.py" + source_file.write_text("def compute():\n return 1\n", encoding="utf-8") + + outside_file = other_dir / "stray.py" + outside_file.write_text("def stray():\n return 3\n", encoding="utf-8") + + with patch("codeflash.discovery.functions_to_optimize.get_files_for_language") as mock_get_files: + mock_get_files.return_value = [source_file, outside_file] + result = get_all_files_and_functions( + module_root, ignore_paths=[], tests_root=tests_root, module_root=module_root + ) + + assert source_file in result + assert outside_file not in result + + +def test_prefilter_disabled_without_params(tmp_path: Path) -> None: + """Without tests_root/module_root, no prefiltering occurs (backward compat).""" + module_root = tmp_path / "src" + module_root.mkdir() + + source_file = module_root / "app.py" + source_file.write_text("def compute():\n return 1\n", encoding="utf-8") + + with patch("codeflash.discovery.functions_to_optimize.get_files_for_language") as mock_get_files: + mock_get_files.return_value = [source_file] + result = get_all_files_and_functions(module_root, ignore_paths=[]) + + assert source_file in result + + +def test_prefilter_in_get_functions_within_lines(tmp_path: Path) -> None: + """get_functions_within_lines should skip test files when prefilter params are provided.""" + module_root = tmp_path / "src" + module_root.mkdir() + tests_root = tmp_path / "tests" + tests_root.mkdir() + + source_file = module_root / "app.py" + source_file.write_text("def compute():\n return 1\n", encoding="utf-8") + + test_file = tests_root / "test_app.py" + test_file.write_text("def test_compute():\n return True\n", encoding="utf-8") + + modified_lines = { + str(source_file): [1, 2], + str(test_file): [1, 2], + } + + result = get_functions_within_lines( + modified_lines, tests_root=tests_root, ignore_paths=[], module_root=module_root + ) + + assert source_file in result + assert test_file not in result + + +def test_prefilter_avoids_reading_skipped_files(tmp_path: Path) -> None: + """Verify that find_all_functions_in_file is NOT called for prefiltered files (the core perf win).""" + module_root = tmp_path / "src" + module_root.mkdir() + tests_root = tmp_path / "tests" + tests_root.mkdir() + + source_file = module_root / "app.py" + source_file.write_text("def compute():\n return 1\n", encoding="utf-8") + + test_file = tests_root / "test_app.py" + test_file.write_text("def test_compute():\n return True\n", encoding="utf-8") + + with ( + patch("codeflash.discovery.functions_to_optimize.get_files_for_language") as mock_get_files, + patch("codeflash.discovery.functions_to_optimize.find_all_functions_in_file") as mock_find, + ): + mock_get_files.return_value = [source_file, test_file] + mock_find.return_value = {} + get_all_files_and_functions( + module_root, ignore_paths=[], tests_root=tests_root, module_root=module_root + ) + + # find_all_functions_in_file (which does read_text) should only be called for source_file + assert mock_find.call_count == 1 + mock_find.assert_called_once_with(source_file) + + +def test_prefilter_skips_submodule_paths(tmp_path: Path) -> None: + """Submodule paths should be skipped by prefilter.""" + module_root = tmp_path / "src" + module_root.mkdir() + tests_root = tmp_path / "tests" + tests_root.mkdir() + submodule_dir = module_root / "vendor_submodule" + submodule_dir.mkdir() + + source_file = module_root / "app.py" + source_file.write_text("def compute():\n return 1\n", encoding="utf-8") + + submodule_file = submodule_dir / "lib.py" + submodule_file.write_text("def helper():\n return 2\n", encoding="utf-8") + + with ( + patch("codeflash.discovery.functions_to_optimize.get_files_for_language") as mock_get_files, + patch( + "codeflash.discovery.functions_to_optimize.ignored_submodule_paths", return_value=[submodule_dir] + ), + ): + mock_get_files.return_value = [source_file, submodule_file] + result = get_all_files_and_functions( + module_root, ignore_paths=[], tests_root=tests_root, module_root=module_root + ) + + assert source_file in result + assert submodule_file not in result diff --git a/tests/test_js_ts_export_helpers.py b/tests/test_js_ts_export_helpers.py new file mode 100644 index 000000000..27546ed4b --- /dev/null +++ b/tests/test_js_ts_export_helpers.py @@ -0,0 +1,141 @@ +"""Tests for JS/TS export helper functions with pre-read source content. + +Verifies that _is_js_ts_function_exported and _is_js_ts_function_exists_but_not_exported +work correctly both with and without pre-read source content passed in. +""" + +from pathlib import Path + +import pytest + +from codeflash.discovery.functions_to_optimize import ( + _is_js_ts_function_exists_but_not_exported, + _is_js_ts_function_exported, +) + + +class TestIsJsTsFunctionExported: + def test_named_export_detected(self, tmp_path: Path) -> None: + js_file = tmp_path / "module.js" + js_file.write_text( + "export function add(a, b) {\n return a + b;\n}\n", + encoding="utf-8", + ) + is_exported, export_name = _is_js_ts_function_exported(js_file, "add") + assert is_exported is True + assert export_name == "add" + + def test_named_export_with_source_param(self, tmp_path: Path) -> None: + js_file = tmp_path / "module.js" + content = "export function add(a, b) {\n return a + b;\n}\n" + js_file.write_text(content, encoding="utf-8") + is_exported, export_name = _is_js_ts_function_exported(js_file, "add", source=content) + assert is_exported is True + assert export_name == "add" + + def test_default_export_detected(self, tmp_path: Path) -> None: + js_file = tmp_path / "module.js" + content = "function compute(x) {\n return x * 2;\n}\nexport default compute;\n" + js_file.write_text(content, encoding="utf-8") + is_exported, export_name = _is_js_ts_function_exported(js_file, "compute", source=content) + assert is_exported is True + assert export_name == "default" + + def test_non_exported_function(self, tmp_path: Path) -> None: + js_file = tmp_path / "module.js" + content = "function helper(x) {\n return x + 1;\n}\n\nexport function main() {\n return helper(5);\n}\n" + js_file.write_text(content, encoding="utf-8") + is_exported, export_name = _is_js_ts_function_exported(js_file, "helper", source=content) + assert is_exported is False + assert export_name is None + + def test_separate_export_clause(self, tmp_path: Path) -> None: + js_file = tmp_path / "module.js" + content = "function process(data) {\n return data;\n}\n\nexport { process };\n" + js_file.write_text(content, encoding="utf-8") + is_exported, export_name = _is_js_ts_function_exported(js_file, "process", source=content) + assert is_exported is True + assert export_name == "process" + + def test_aliased_export(self, tmp_path: Path) -> None: + js_file = tmp_path / "module.js" + content = "function internalName(x) {\n return x;\n}\n\nexport { internalName as publicName };\n" + js_file.write_text(content, encoding="utf-8") + is_exported, export_name = _is_js_ts_function_exported(js_file, "internalName", source=content) + assert is_exported is True + assert export_name == "publicName" + + def test_typescript_export(self, tmp_path: Path) -> None: + ts_file = tmp_path / "module.ts" + content = "export function greet(name: string): string {\n return `Hello, ${name}`;\n}\n" + ts_file.write_text(content, encoding="utf-8") + is_exported, export_name = _is_js_ts_function_exported(ts_file, "greet", source=content) + assert is_exported is True + assert export_name == "greet" + + def test_fallback_reads_from_disk_when_source_is_none(self, tmp_path: Path) -> None: + js_file = tmp_path / "module.js" + js_file.write_text( + "export function fromDisk(x) {\n return x;\n}\n", + encoding="utf-8", + ) + is_exported, export_name = _is_js_ts_function_exported(js_file, "fromDisk", source=None) + assert is_exported is True + assert export_name == "fromDisk" + + +class TestIsJsTsFunctionExistsButNotExported: + def test_unexported_function_detected(self, tmp_path: Path) -> None: + js_file = tmp_path / "module.js" + content = "function secret(x) {\n return x * 2;\n}\n\nexport function pub() {\n return 1;\n}\n" + js_file.write_text(content, encoding="utf-8") + assert _is_js_ts_function_exists_but_not_exported(js_file, "secret", source=content) is True + + def test_exported_function_returns_false(self, tmp_path: Path) -> None: + js_file = tmp_path / "module.js" + content = "export function pub(x) {\n return x;\n}\n" + js_file.write_text(content, encoding="utf-8") + assert _is_js_ts_function_exists_but_not_exported(js_file, "pub", source=content) is False + + def test_nonexistent_function_returns_false(self, tmp_path: Path) -> None: + js_file = tmp_path / "module.js" + content = "export function exists() {\n return 1;\n}\n" + js_file.write_text(content, encoding="utf-8") + assert _is_js_ts_function_exists_but_not_exported(js_file, "nonexistent", source=content) is False + + def test_fallback_reads_from_disk(self, tmp_path: Path) -> None: + js_file = tmp_path / "module.js" + js_file.write_text( + "function localOnly() {\n return 42;\n}\n", + encoding="utf-8", + ) + assert _is_js_ts_function_exists_but_not_exported(js_file, "localOnly", source=None) is True + + def test_typescript_unexported(self, tmp_path: Path) -> None: + ts_file = tmp_path / "utils.ts" + content = ( + "function internal(x: number): number {\n return x;\n}\n\n" + "export function external(y: number): number {\n return y + 1;\n}\n" + ) + ts_file.write_text(content, encoding="utf-8") + assert _is_js_ts_function_exists_but_not_exported(ts_file, "internal", source=content) is True + assert _is_js_ts_function_exists_but_not_exported(ts_file, "external", source=content) is False + + def test_arrow_function_unexported(self, tmp_path: Path) -> None: + js_file = tmp_path / "arrows.js" + content = "const helper = (x) => {\n return x + 1;\n};\n\nexport const main = () => {\n return 2;\n};\n" + js_file.write_text(content, encoding="utf-8") + assert _is_js_ts_function_exists_but_not_exported(js_file, "helper", source=content) is True + assert _is_js_ts_function_exists_but_not_exported(js_file, "main", source=content) is False + + def test_source_param_matches_disk_read(self, tmp_path: Path) -> None: + js_file = tmp_path / "consistent.js" + content = "function local() {\n return 'hi';\n}\n\nexport function exported() {\n return 'bye';\n}\n" + js_file.write_text(content, encoding="utf-8") + # Results should be identical whether source is passed or read from disk + assert _is_js_ts_function_exists_but_not_exported(js_file, "local", source=content) == ( + _is_js_ts_function_exists_but_not_exported(js_file, "local", source=None) + ) + assert _is_js_ts_function_exists_but_not_exported(js_file, "exported", source=content) == ( + _is_js_ts_function_exists_but_not_exported(js_file, "exported", source=None) + )