|
| 1 | +# Copyright 2026 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Multi-trial evaluation runner with pass@k / pass^k metrics. |
| 16 | +
|
| 17 | +Wraps any ``BigQueryTraceEvaluator`` to run N trials per task and |
| 18 | +compute probabilistic pass-rate metrics that account for agent |
| 19 | +non-determinism. |
| 20 | +
|
| 21 | +Example usage:: |
| 22 | +
|
| 23 | + from bigquery_agent_analytics import ( |
| 24 | + BigQueryTraceEvaluator, TrialRunner, |
| 25 | + ) |
| 26 | +
|
| 27 | + evaluator = BigQueryTraceEvaluator( |
| 28 | + project_id="my-project", |
| 29 | + dataset_id="analytics", |
| 30 | + ) |
| 31 | + runner = TrialRunner(evaluator, num_trials=5) |
| 32 | +
|
| 33 | + report = await runner.run_trials( |
| 34 | + session_id="sess-123", |
| 35 | + golden_trajectory=[{"tool_name": "search", "args": {}}], |
| 36 | + ) |
| 37 | + print(report.pass_at_k, report.pass_pow_k) |
| 38 | +""" |
| 39 | + |
| 40 | +from __future__ import annotations |
| 41 | + |
| 42 | +import asyncio |
| 43 | +import logging |
| 44 | +import math |
| 45 | +import statistics |
| 46 | +from typing import Any, Optional |
| 47 | + |
| 48 | +from pydantic import BaseModel |
| 49 | +from pydantic import Field |
| 50 | + |
| 51 | +from .trace_evaluator import BigQueryTraceEvaluator |
| 52 | +from .trace_evaluator import EvalStatus |
| 53 | +from .trace_evaluator import MatchType |
| 54 | + |
# Package-qualified logger name so the module logs under the
# "bigquery_agent_analytics" hierarchy.
logger = logging.getLogger(f"bigquery_agent_analytics.{__name__}")
| 56 | + |
| 57 | + |
| 58 | +# ------------------------------------------------------------------ # |
| 59 | +# Data Models # |
| 60 | +# ------------------------------------------------------------------ # |
| 61 | + |
| 62 | + |
class TrialResult(BaseModel):
    """Outcome of one evaluation trial."""

    # Position of this trial within the run (0-based).
    trial_index: int = Field(description="Zero-based trial index.")
    # Overall pass/fail verdict for the trial.
    passed: bool = Field(description="Whether this trial passed.")
    # Per-metric numeric scores reported by the evaluator.
    scores: dict[str, float] = Field(
        default_factory=dict, description="Metric scores for this trial."
    )
    # Extra per-trial payload passed through from the evaluator result.
    details: dict[str, Any] = Field(
        default_factory=dict, description="Additional trial details."
    )
| 76 | + |
| 77 | + |
class MultiTrialReport(BaseModel):
    """Summary of N evaluation trials for a single task."""

    # Session whose trials are summarized here.
    session_id: str = Field(description="The session ID evaluated.")
    # How many trials were executed.
    num_trials: int = Field(description="Number of trials run.")
    # Raw per-trial outcomes that back the aggregates below.
    trial_results: list[TrialResult] = Field(
        default_factory=list, description="Individual trial results."
    )
    # Probability of at least one success among k trials.
    pass_at_k: float = Field(
        default=0.0, description="P(>=1 pass in k trials)."
    )
    # Probability that every one of the k trials succeeds.
    pass_pow_k: float = Field(
        default=0.0, description="P(all k trials pass)."
    )
    # Simple success ratio: passed / total.
    per_trial_pass_rate: float = Field(
        default=0.0, description="Fraction of trials that passed."
    )
    # Average of each metric over all trials.
    mean_scores: dict[str, float] = Field(
        default_factory=dict,
        description="Mean score per metric across trials.",
    )
    # Sample standard deviation of each metric across trials.
    score_std_dev: dict[str, float] = Field(
        default_factory=dict,
        description="Standard deviation per metric across trials.",
    )
| 107 | + |
| 108 | + |
| 109 | +# ------------------------------------------------------------------ # |
| 110 | +# Static Helpers # |
| 111 | +# ------------------------------------------------------------------ # |
| 112 | + |
| 113 | + |
def compute_pass_at_k(
    num_trials: int,
    num_passed: int,
    k: Optional[int] = None,
) -> float:
    """Computes pass@k: P(>=1 pass among k sampled trials).

    Uses the unbiased estimator ``1 - C(n-c, k) / C(n, k)`` where
    n = num_trials and c = num_passed.

    Note: previously k was hard-wired to ``num_trials``, which made the
    combinatorial branch unreachable (for any 1 <= c < n, ``n - c < n``
    always holds, so the result was always 1.0). The new optional ``k``
    parameter makes the estimator usable for k < n while keeping the
    default behavior identical.

    Args:
        num_trials: Total number of trials run (n).
        num_passed: Number of trials that passed (c).
        k: Sample size for the estimator. Defaults to ``num_trials``
            (pass@n), which preserves the original behavior.

    Returns:
        Probability that at least one of k sampled trials passes.
    """
    n = num_trials
    c = num_passed
    if k is None:
        # Backward-compatible default: evaluate pass@n.
        k = n

    if n <= 0 or k <= 0 or c <= 0:
        return 0.0
    if c >= n:
        return 1.0
    # Cannot sample more trials than were run.
    k = min(k, n)

    # If there are fewer than k failing trials, every k-subset contains
    # at least one pass: C(n-c, k) = 0 => pass@k = 1.
    if n - c < k:
        return 1.0

    # Compute C(n-c, k) / C(n, k) as prod_{i<k} (n-c-i)/(n-i) in log
    # space to avoid float overflow for large n.
    log_numerator = sum(math.log(n - c - i) for i in range(k))
    log_denominator = sum(math.log(n - i) for i in range(k))

    return 1.0 - math.exp(log_numerator - log_denominator)
| 151 | + |
| 152 | + |
def compute_pass_pow_k(
    num_trials: int,
    num_passed: int,
) -> float:
    """Computes pass^k: P(all k trials pass).

    Estimated as ``(num_passed / num_trials) ** num_trials``.

    Args:
        num_trials: Total number of trials.
        num_passed: Number of trials that passed.

    Returns:
        Probability that all trials pass.
    """
    # Degenerate cases: no trials run, or nothing passed.
    if num_trials <= 0 or num_passed <= 0:
        return 0.0
    return (num_passed / num_trials) ** num_trials
| 174 | + |
| 175 | + |
| 176 | +# ------------------------------------------------------------------ # |
| 177 | +# TrialRunner # |
| 178 | +# ------------------------------------------------------------------ # |
| 179 | + |
| 180 | + |
class TrialRunner:
    """Runs multiple evaluation trials and computes aggregate metrics.

    Wraps a ``BigQueryTraceEvaluator`` and runs N trials per task,
    computing pass@k and pass^k metrics that account for agent
    non-determinism (e.g. LLM judges produce different scores each
    call).

    Example::

        runner = TrialRunner(evaluator, num_trials=5, concurrency=3)
        report = await runner.run_trials(
            session_id="sess-123",
            golden_trajectory=[...],
        )
    """

    def __init__(
        self,
        evaluator: BigQueryTraceEvaluator,
        num_trials: int = 5,
        concurrency: int = 3,
    ) -> None:
        """Initializes the TrialRunner.

        Args:
            evaluator: The trace evaluator to wrap.
            num_trials: Number of trials to run per task.
            concurrency: Maximum concurrent evaluations.
        """
        self.evaluator = evaluator
        self.num_trials = num_trials
        self.concurrency = concurrency

    async def run_trials(
        self,
        session_id: str,
        golden_trajectory: Optional[list[dict]] = None,
        golden_response: Optional[str] = None,
        match_type: MatchType = MatchType.EXACT,
        task_description: Optional[str] = None,
        use_llm_judge: bool = False,
        thresholds: Optional[dict[str, float]] = None,
    ) -> MultiTrialReport:
        """Runs N trials of evaluation for a single session.

        Args:
            session_id: The session ID to evaluate.
            golden_trajectory: Expected tool call sequence.
            golden_response: Expected final response.
            match_type: Type of trajectory matching.
            task_description: Task description for LLM judge.
            use_llm_judge: Whether to use LLM-as-judge.
            thresholds: Metric thresholds for pass/fail.

        Returns:
            MultiTrialReport with aggregate metrics.
        """
        # Cap concurrent evaluator calls at self.concurrency.
        semaphore = asyncio.Semaphore(self.concurrency)

        async def _run_one(trial_index: int) -> TrialResult:
            # One evaluation pass; the semaphore bounds parallelism.
            async with semaphore:
                result = await self.evaluator.evaluate_session(
                    session_id=session_id,
                    golden_trajectory=golden_trajectory,
                    golden_response=golden_response,
                    match_type=match_type,
                    task_description=task_description,
                    use_llm_judge=use_llm_judge,
                    thresholds=thresholds,
                )
                return TrialResult(
                    trial_index=trial_index,
                    passed=result.eval_status == EvalStatus.PASSED,
                    scores=result.scores,
                    details=result.details,
                )

        # All trials are scheduled at once; the semaphore throttles them.
        tasks = [_run_one(i) for i in range(self.num_trials)]
        trial_results = list(await asyncio.gather(*tasks))

        return self._build_report(session_id, trial_results)

    async def run_trials_batch(
        self,
        eval_dataset: list[dict[str, Any]],
        match_type: MatchType = MatchType.EXACT,
        use_llm_judge: bool = False,
    ) -> list[MultiTrialReport]:
        """Runs multi-trial evaluation for a batch of tasks.

        Tasks are processed sequentially; only the trials within a task
        run concurrently (bounded by ``self.concurrency``).

        Args:
            eval_dataset: List of dicts with session_id,
                expected_trajectory, etc.
            match_type: Type of trajectory matching.
            use_llm_judge: Whether to use LLM-as-judge.

        Returns:
            List of MultiTrialReport, one per task.
        """
        reports = []
        for item in eval_dataset:
            report = await self.run_trials(
                session_id=item["session_id"],
                golden_trajectory=item.get("expected_trajectory"),
                golden_response=item.get("expected_response"),
                match_type=match_type,
                task_description=item.get("task_description"),
                use_llm_judge=use_llm_judge,
                thresholds=item.get("thresholds"),
            )
            reports.append(report)
        return reports

    def _build_report(
        self,
        session_id: str,
        trial_results: list[TrialResult],
    ) -> MultiTrialReport:
        """Builds a MultiTrialReport from trial results.

        Args:
            session_id: The session the trials belong to.
            trial_results: Individual trial outcomes.

        Returns:
            Aggregate report with pass@k / pass^k and score statistics.
        """
        num_trials = len(trial_results)
        if num_trials == 0:
            # No trials: return an empty report instead of dividing by zero.
            return MultiTrialReport(
                session_id=session_id,
                num_trials=0,
            )

        num_passed = sum(1 for t in trial_results if t.passed)

        # Union of metric names seen in any trial. A trial missing a metric
        # contributes 0.0 below. NOTE(review): this deflates the mean for
        # partially-reported metrics — confirm that is intended.
        all_metric_names: set[str] = set()
        for t in trial_results:
            all_metric_names.update(t.scores.keys())

        mean_scores: dict[str, float] = {}
        score_std_dev: dict[str, float] = {}

        for metric in sorted(all_metric_names):
            values = [t.scores.get(metric, 0.0) for t in trial_results]
            mean_scores[metric] = statistics.mean(values)
            # statistics.stdev requires at least two samples.
            score_std_dev[metric] = (
                statistics.stdev(values) if len(values) >= 2 else 0.0
            )

        return MultiTrialReport(
            session_id=session_id,
            num_trials=num_trials,
            trial_results=trial_results,
            pass_at_k=compute_pass_at_k(num_trials, num_passed),
            pass_pow_k=compute_pass_pow_k(num_trials, num_passed),
            per_trial_pass_rate=num_passed / num_trials,
            mean_scores=mean_scores,
            score_std_dev=score_std_dev,
        )