diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index fdac43c25..4ef60e40c 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -197,7 +197,8 @@ def _find_all_functions_via_language_support(file_path: Path) -> dict[Path, list lang_support = get_language_support(file_path) require_return = lang_support.language != Language.JAVA criteria = FunctionFilterCriteria(require_return=require_return) - functions[file_path] = lang_support.discover_functions(file_path, criteria) + source = file_path.read_text(encoding="utf-8") + functions[file_path] = lang_support.discover_functions(source, file_path, criteria) except Exception as e: logger.debug(f"Failed to discover functions in {file_path}: {e}") @@ -226,7 +227,9 @@ def get_functions_to_optimize( if optimize_all: logger.info("!lsp|Finding all functions in the module '%s'…", optimize_all) console.rule() - functions = get_all_files_and_functions(Path(optimize_all), ignore_paths) + functions = get_all_files_and_functions( + Path(optimize_all), ignore_paths, tests_root=test_cfg.tests_root, module_root=module_root + ) elif replay_test: functions, trace_file_path = get_all_replay_test_functions( replay_test=replay_test, test_cfg=test_cfg, project_root_path=project_root @@ -317,7 +320,12 @@ def get_functions_to_optimize( logger.info("Finding all functions modified in the current git diff ...") console.rule() ph("cli-optimizing-git-diff") - functions = get_functions_within_git_diff(uncommitted_changes=False) + functions = get_functions_within_git_diff( + uncommitted_changes=False, + tests_root=test_cfg.tests_root, + ignore_paths=ignore_paths, + module_root=module_root, + ) filtered_modified_functions, functions_count = filter_functions( functions, test_cfg.tests_root, ignore_paths, project_root, module_root, previous_checkpoint_functions ) @@ -326,9 +334,16 @@ def get_functions_to_optimize( return filtered_modified_functions, functions_count, trace_file_path -def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[Path, list[FunctionToOptimize]]: +def get_functions_within_git_diff( + uncommitted_changes: bool, + tests_root: Path | None = None, + ignore_paths: list[Path] | None = None, + module_root: Path | None = None, +) -> dict[Path, list[FunctionToOptimize]]: modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=uncommitted_changes) - return get_functions_within_lines(modified_lines) + return get_functions_within_lines( + modified_lines, tests_root=tests_root, ignore_paths=ignore_paths, module_root=module_root + ) def closest_matching_file_function_name( @@ -406,12 +421,20 @@ def get_functions_inside_a_commit(commit_hash: str) -> dict[Path, list[FunctionT return get_functions_within_lines(modified_lines) -def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[Path, list[FunctionToOptimize]]: +def get_functions_within_lines( + modified_lines: dict[str, list[int]], + tests_root: Path | None = None, + ignore_paths: list[Path] | None = None, + module_root: Path | None = None, +) -> dict[Path, list[FunctionToOptimize]]: functions: dict[Path, list[FunctionToOptimize]] = {} for path_str, lines_in_file in modified_lines.items(): path = Path(path_str) if not path.exists(): continue + if tests_root is not None and module_root is not None: + if not filter_files_optimized(path, tests_root, ignore_paths or [], module_root): + continue all_functions = find_all_functions_in_file(path) functions[path] = [ func @@ -424,7 +447,11 @@ def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[Pat def get_all_files_and_functions( - module_root_path: Path, ignore_paths: list[Path], language: Language | None = None + module_root_path: Path, + ignore_paths: list[Path], + language: Language | None = None, + tests_root: Path | None = None, + module_root: Path | None = None, ) -> dict[Path, list[FunctionToOptimize]]: """Get all optimizable functions from files in the module root. @@ -432,6 +459,8 @@ def get_all_files_and_functions( module_root_path: Root path to search for source files. ignore_paths: List of paths to ignore. language: Optional specific language to filter for. If None, includes all supported languages. + tests_root: Test root path for prefiltering files before reading (avoids unnecessary I/O). + module_root: Module root path for prefiltering files before reading. Returns: Dictionary mapping file paths to lists of FunctionToOptimize. @@ -439,6 +468,9 @@ def get_all_files_and_functions( """ functions: dict[Path, list[FunctionToOptimize]] = {} for file_path in get_files_for_language(module_root_path, ignore_paths, language): + if tests_root is not None and module_root is not None: + if not filter_files_optimized(file_path, tests_root, ignore_paths, module_root): + continue functions.update(find_all_functions_in_file(file_path).items()) # Randomize the order of the files to optimize to avoid optimizing the same file in the same order every time. # Helpful if an optimize-all run is stuck and we restart it. diff --git a/tests/test_discovery_prefilter.py b/tests/test_discovery_prefilter.py new file mode 100644 index 000000000..fe15370de --- /dev/null +++ b/tests/test_discovery_prefilter.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +from pathlib import Path +from unittest.mock import patch + +from codeflash.discovery.functions_to_optimize import ( + get_all_files_and_functions, + get_functions_within_lines, +) + + +def test_prefilter_skips_test_files(tmp_path: Path) -> None: + """Files in tests_root should be skipped before read_text() is called.""" + module_root = tmp_path / "src" + module_root.mkdir() + tests_root = tmp_path / "tests" + tests_root.mkdir() + + source_file = module_root / "app.py" + source_file.write_text("def compute():\n return 1\n", encoding="utf-8") + + test_file = tests_root / "test_app.py" + test_file.write_text("def test_compute():\n return True\n", encoding="utf-8") + + with patch("codeflash.discovery.functions_to_optimize.get_files_for_language") as mock_get_files: + mock_get_files.return_value = [source_file, test_file] + result = get_all_files_and_functions( + module_root, ignore_paths=[], tests_root=tests_root, module_root=module_root + ) + + assert source_file in result + assert test_file not in result + + +def test_prefilter_skips_ignored_paths(tmp_path: Path) -> None: + """Files in ignore_paths should be skipped before read_text() is called.""" + module_root = tmp_path / "src" + module_root.mkdir() + tests_root = tmp_path / "tests" + tests_root.mkdir() + ignored_dir = module_root / "vendor" + ignored_dir.mkdir() + + source_file = module_root / "app.py" + source_file.write_text("def compute():\n return 1\n", encoding="utf-8") + + vendor_file = ignored_dir / "lib.py" + vendor_file.write_text("def helper():\n return 2\n", encoding="utf-8") + + with patch("codeflash.discovery.functions_to_optimize.get_files_for_language") as mock_get_files: + mock_get_files.return_value = [source_file, vendor_file] + result = get_all_files_and_functions( + module_root, ignore_paths=[ignored_dir], tests_root=tests_root, module_root=module_root + ) + + assert source_file in result + assert vendor_file not in result + + +def test_prefilter_skips_files_outside_module_root(tmp_path: Path) -> None: + """Files outside module_root should be skipped before read_text() is called.""" + module_root = tmp_path / "src" + module_root.mkdir() + tests_root = tmp_path / "tests" + tests_root.mkdir() + other_dir = tmp_path / "other" + other_dir.mkdir() + + source_file = module_root / "app.py" + source_file.write_text("def compute():\n return 1\n", encoding="utf-8") + + outside_file = other_dir / "stray.py" + outside_file.write_text("def stray():\n return 3\n", encoding="utf-8") + + with patch("codeflash.discovery.functions_to_optimize.get_files_for_language") as mock_get_files: + mock_get_files.return_value = [source_file, outside_file] + result = get_all_files_and_functions( + module_root, ignore_paths=[], tests_root=tests_root, module_root=module_root + ) + + assert source_file in result + assert outside_file not in result + + +def test_prefilter_disabled_without_params(tmp_path: Path) -> None: + """Without tests_root/module_root, no prefiltering occurs (backward compat).""" + module_root = tmp_path / "src" + module_root.mkdir() + + source_file = module_root / "app.py" + source_file.write_text("def compute():\n return 1\n", encoding="utf-8") + + with patch("codeflash.discovery.functions_to_optimize.get_files_for_language") as mock_get_files: + mock_get_files.return_value = [source_file] + result = get_all_files_and_functions(module_root, ignore_paths=[]) + + assert source_file in result + + +def test_prefilter_in_get_functions_within_lines(tmp_path: Path) -> None: + """get_functions_within_lines should skip test files when prefilter params are provided.""" + module_root = tmp_path / "src" + module_root.mkdir() + tests_root = tmp_path / "tests" + tests_root.mkdir() + + source_file = module_root / "app.py" + source_file.write_text("def compute():\n return 1\n", encoding="utf-8") + + test_file = tests_root / "test_app.py" + test_file.write_text("def test_compute():\n return True\n", encoding="utf-8") + + modified_lines = { + str(source_file): [1, 2], + str(test_file): [1, 2], + } + + result = get_functions_within_lines( + modified_lines, tests_root=tests_root, ignore_paths=[], module_root=module_root + ) + + assert source_file in result + assert test_file not in result + + +def test_prefilter_avoids_reading_skipped_files(tmp_path: Path) -> None: + """Verify that find_all_functions_in_file is NOT called for prefiltered files (the core perf win).""" + module_root = tmp_path / "src" + module_root.mkdir() + tests_root = tmp_path / "tests" + tests_root.mkdir() + + source_file = module_root / "app.py" + source_file.write_text("def compute():\n return 1\n", encoding="utf-8") + + test_file = tests_root / "test_app.py" + test_file.write_text("def test_compute():\n return True\n", encoding="utf-8") + + with ( + patch("codeflash.discovery.functions_to_optimize.get_files_for_language") as mock_get_files, + patch("codeflash.discovery.functions_to_optimize.find_all_functions_in_file") as mock_find, + ): + mock_get_files.return_value = [source_file, test_file] + mock_find.return_value = {} + get_all_files_and_functions( + module_root, ignore_paths=[], tests_root=tests_root, module_root=module_root + ) + + # find_all_functions_in_file (which does read_text) should only be called for source_file + assert mock_find.call_count == 1 + mock_find.assert_called_once_with(source_file) + + +def test_prefilter_skips_submodule_paths(tmp_path: Path) -> None: + """Submodule paths should be skipped by prefilter.""" + module_root = tmp_path / "src" + module_root.mkdir() + tests_root = tmp_path / "tests" + tests_root.mkdir() + submodule_dir = module_root / "vendor_submodule" + submodule_dir.mkdir() + + source_file = module_root / "app.py" + source_file.write_text("def compute():\n return 1\n", encoding="utf-8") + + submodule_file = submodule_dir / "lib.py" + submodule_file.write_text("def helper():\n return 2\n", encoding="utf-8") + + with ( + patch("codeflash.discovery.functions_to_optimize.get_files_for_language") as mock_get_files, + patch( + "codeflash.discovery.functions_to_optimize.ignored_submodule_paths", return_value=[submodule_dir] + ), + ): + mock_get_files.return_value = [source_file, submodule_file] + result = get_all_files_and_functions( + module_root, ignore_paths=[], tests_root=tests_root, module_root=module_root + ) + + assert source_file in result + assert submodule_file not in result