diff --git a/codeflash/tracing/profile_stats.py b/codeflash/tracing/profile_stats.py index 36bf9e113..d52d377f8 100644 --- a/codeflash/tracing/profile_stats.py +++ b/codeflash/tracing/profile_stats.py @@ -1,25 +1,44 @@ -import json +from __future__ import annotations + import pstats import sqlite3 from copy import copy from pathlib import Path +from typing import Any, TextIO from codeflash.cli_cmds.console import logger class ProfileStats(pstats.Stats): + # Attributes set by pstats.Stats.init() — stubs don't expose them + files: list[str] + stream: TextIO + top_level: set[tuple[str, int, str]] + total_calls: int + prim_calls: int + total_tt: float + max_name_len: int + fcn_list: list[tuple[str, int, str]] | None + sort_arg_dict: dict[str, tuple[Any, ...]] + all_callees: dict[tuple[str, int, str], dict[tuple[str, int, str], tuple[int, int, float, float]]] | None + stats: dict[tuple[str, int, str], tuple[int, int, int | float, int | float, dict[Any, Any]]] + def __init__(self, trace_file_path: str, time_unit: str = "ns") -> None: assert Path(trace_file_path).is_file(), f"Trace file {trace_file_path} does not exist" assert time_unit in {"ns", "us", "ms", "s"}, f"Invalid time unit {time_unit}" self.trace_file_path = trace_file_path self.time_unit = time_unit logger.debug(hasattr(self, "create_stats")) - super().__init__(copy(self)) + super().__init__(copy(self)) # type: ignore[arg-type] # pstats uses duck-typed create_stats interface def create_stats(self) -> None: self.con = sqlite3.connect(self.trace_file_path) cur = self.con.cursor() - pdata = cur.execute("SELECT * FROM pstats").fetchall() + pdata = cur.execute( + "SELECT filename, line_number, function, class_name," + " call_count_nonrecursive, num_callers, total_time_ns, cumulative_time_ns" + " FROM pstats" + ).fetchall() self.con.close() time_conversion_factor = {"ns": 1, "us": 1e3, "ms": 1e6, "s": 1e9}[self.time_unit] self.stats = {} @@ -32,19 +51,7 @@ def create_stats(self) -> None: num_callers, total_time_ns, cumulative_time_ns, - callers, ) in pdata: - loaded_callers = json.loads(callers) - unmapped_callers = {} - for caller in loaded_callers: - caller_key = caller["key"] - if isinstance(caller_key, list): - caller_key = tuple(caller_key) - elif not isinstance(caller_key, tuple): - caller_key = (caller_key,) if not isinstance(caller_key, (list, tuple)) else tuple(caller_key) - unmapped_callers[caller_key] = caller["value"] - - # Create function key with class name if present (matching tracer.py format) function_name = f"{class_name}.{function}" if class_name else function self.stats[(filename, line_number, function_name)] = ( @@ -52,11 +59,10 @@ def create_stats(self) -> None: num_callers, total_time_ns / time_conversion_factor if time_conversion_factor != 1 else total_time_ns, cumulative_time_ns / time_conversion_factor if time_conversion_factor != 1 else cumulative_time_ns, - unmapped_callers, + {}, ) - def print_stats(self, *amount) -> pstats.Stats: # noqa: ANN002 - # Copied from pstats.Stats.print_stats and modified to print the correct time unit + def print_stats(self, *amount: str | float) -> ProfileStats: for filename in self.files: print(filename, file=self.stream) if self.files: @@ -74,8 +80,8 @@ def print_stats(self, *amount) -> pstats.Stats: # noqa: ANN002 _width, list_ = self.get_print_list(amount) if list_: self.print_title() - for func in list_: - self.print_line(func) + for fn in list_: + self.print_line(fn) print(file=self.stream) print(file=self.stream) return self