Skip to content

Commit a95ad44

Browse files
Merge branch 'main' into feat/quality-report
2 parents 33715c3 + cd0940a commit a95ad44

5 files changed

Lines changed: 951 additions & 62 deletions

File tree

src/bigquery_agent_analytics/categorical_evaluator.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@
6666
from pydantic import BaseModel
6767
from pydantic import Field
6868

69+
from bigquery_agent_analytics.evaluators import strip_markdown_fences
70+
6971
logger = logging.getLogger("bigquery_agent_analytics." + __name__)
7072

7173
DEFAULT_ENDPOINT = "gemini-2.5-flash"
@@ -127,6 +129,12 @@ class CategoricalEvaluationConfig(BaseModel):
127129
default=True,
128130
description="Include justification in output.",
129131
)
132+
max_output_tokens: int = Field(
133+
default=8192,
134+
ge=1,
135+
le=65536,
136+
description="Max output tokens for classification response.",
137+
)
130138
prompt_version: Optional[str] = Field(
131139
default=None,
132140
description="Tracks prompt version for reproducibility.",
@@ -241,6 +249,7 @@ def summary(self) -> str:
241249
WHERE {where}
242250
GROUP BY session_id
243251
HAVING LENGTH(transcript) > 10
252+
ORDER BY MAX(timestamp) DESC, session_id
244253
LIMIT @trace_limit
245254
"""
246255

@@ -267,6 +276,7 @@ def summary(self) -> str:
267276
WHERE {where}
268277
GROUP BY session_id
269278
HAVING LENGTH(transcript) > 10
279+
ORDER BY MAX(timestamp) DESC, session_id
270280
LIMIT @trace_limit
271281
)
272282
SELECT
@@ -278,7 +288,7 @@ def summary(self) -> str:
278288
'\\n\\nTranscript:\\n', transcript
279289
),
280290
endpoint => '{endpoint}',
281-
model_params => JSON '{{"generationConfig": {{"temperature": {temperature}, "maxOutputTokens": 1024}}}}',
291+
model_params => JSON '{{"generationConfig": {{"temperature": {temperature}, "maxOutputTokens": {max_output_tokens}}}}}',
282292
output_schema => 'classifications STRING'
283293
)).classifications AS classifications
284294
FROM session_transcripts
@@ -388,6 +398,7 @@ def build_ai_classify_query(
388398
WHERE {where}
389399
GROUP BY session_id
390400
HAVING LENGTH(transcript) > 10
401+
ORDER BY MAX(timestamp) DESC, session_id
391402
LIMIT @trace_limit
392403
)
393404
SELECT
@@ -411,6 +422,7 @@ def build_ai_generate_query(
411422
endpoint: str,
412423
temperature: float,
413424
connection_id: Optional[str] = None,
425+
max_output_tokens: int = 8192,
414426
) -> str:
415427
"""Builds the AI.GENERATE categorical classification query.
416428
@@ -457,6 +469,7 @@ def build_ai_generate_query(
457469
WHERE {where}
458470
GROUP BY session_id
459471
HAVING LENGTH(transcript) > 10
472+
ORDER BY MAX(timestamp) DESC, session_id
460473
LIMIT @trace_limit
461474
)
462475
SELECT
@@ -468,7 +481,7 @@ def build_ai_generate_query(
468481
'\\n\\nTranscript:\\n', transcript
469482
),{connection_clause}
470483
endpoint => '{_escape_sql_string_literal(endpoint)}',
471-
model_params => JSON '{{"generationConfig": {{"temperature": {temperature}, "maxOutputTokens": 1024}}}}',
484+
model_params => JSON '{{"generationConfig": {{"temperature": {temperature}, "maxOutputTokens": {max_output_tokens}}}}}',
472485
output_schema => 'classifications STRING'
473486
)).classifications AS classifications
474487
FROM session_transcripts
@@ -658,8 +671,12 @@ def parse_classifications(
658671
for m in config.metrics
659672
]
660673

674+
# Strip markdown code blocks (```json ... ```) that models often wrap
675+
# around JSON output. Uses the shared helper from evaluators.py.
676+
text = strip_markdown_fences(raw_json)
677+
661678
try:
662-
parsed = json.loads(raw_json)
679+
parsed = json.loads(text)
663680
except (json.JSONDecodeError, TypeError):
664681
return [
665682
CategoricalMetricResult(
@@ -852,11 +869,24 @@ async def classify_sessions_via_api(
852869
contents=full_prompt,
853870
config=types.GenerateContentConfig(
854871
temperature=config.temperature,
855-
max_output_tokens=1024,
872+
max_output_tokens=config.max_output_tokens,
856873
),
857874
)
858-
raw_text = response.text.strip()
875+
raw_text = response.text.strip() if response.text else ""
859876
metrics = parse_classifications(raw_text, config)
877+
has_parse_error = any(m.parse_error for m in metrics)
878+
if has_parse_error:
879+
finish_reason = None
880+
if response.candidates:
881+
finish_reason = response.candidates[0].finish_reason
882+
logger.warning(
883+
"API parse error for session %s: finish_reason=%s, "
884+
"raw_text_len=%d, raw_text=%s",
885+
sid,
886+
finish_reason,
887+
len(raw_text),
888+
repr(raw_text[:500]),
889+
)
860890
results.append(
861891
CategoricalSessionResult(
862892
session_id=sid,
@@ -865,9 +895,10 @@ async def classify_sessions_via_api(
865895
)
866896
except Exception as e:
867897
logger.warning(
868-
"Categorical API classification failed for %s: %s",
898+
"Categorical API classification EXCEPTION for %s: %s (type=%s)",
869899
sid,
870900
e,
901+
type(e).__name__,
871902
)
872903
results.append(
873904
CategoricalSessionResult(

src/bigquery_agent_analytics/client.py

Lines changed: 139 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from datetime import datetime
5050
from datetime import timezone
5151
import logging
52+
import time
5253
from typing import Any, Optional
5354

5455
from google.cloud import bigquery
@@ -134,9 +135,11 @@
134135

135136
_LIST_TRACES_QUERY = """\
136137
WITH trace_sessions AS (
137-
SELECT DISTINCT session_id
138+
SELECT session_id
138139
FROM `{project}.{dataset}.{table}`
139140
WHERE {where}
141+
GROUP BY session_id
142+
ORDER BY MAX(timestamp) DESC, session_id
140143
LIMIT @trace_limit
141144
)
142145
SELECT
@@ -1250,7 +1253,7 @@ def evaluate_categorical(
12501253

12511254
# Try AI.GENERATE.
12521255
try:
1253-
session_results = self._categorical_ai_generate(
1256+
session_results, retry_meta = self._categorical_ai_generate(
12541257
config,
12551258
table,
12561259
where,
@@ -1264,6 +1267,8 @@ def evaluate_categorical(
12641267
config=config,
12651268
)
12661269
report.details["execution_mode"] = "ai_generate"
1270+
if retry_meta:
1271+
report.details["retry"] = retry_meta
12671272
if classify_fallback_reason:
12681273
report.details["classify_fallback_reason"] = classify_fallback_reason
12691274
self._persist_categorical_if_configured(report, config, endpoint)
@@ -1355,8 +1360,13 @@ def _categorical_ai_generate(
13551360
params: list,
13561361
endpoint: str,
13571362
connection_id: Optional[str] = None,
1358-
) -> list:
1359-
"""Classifies sessions using BigQuery AI.GENERATE."""
1363+
) -> tuple[list, dict]:
1364+
"""Classifies sessions using BigQuery AI.GENERATE.
1365+
1366+
Sessions where AI.GENERATE returns NULL (e.g. due to rate
1367+
limiting or transient errors) are retried via the Gemini API
1368+
up to 3 times.
1369+
"""
13601370
prompt = build_categorical_prompt(config)
13611371

13621372
query = build_ai_generate_query(
@@ -1367,6 +1377,7 @@ def _categorical_ai_generate(
13671377
endpoint=endpoint,
13681378
temperature=config.temperature,
13691379
connection_id=connection_id,
1380+
max_output_tokens=config.max_output_tokens,
13701381
)
13711382

13721383
query_params = list(params) + [
@@ -1383,11 +1394,133 @@ def _categorical_ai_generate(
13831394
results = list(self.bq_client.query(query, job_config=job_config).result())
13841395

13851396
session_results = []
1397+
failed_sessions = {}
13861398
for row in results:
13871399
r = dict(row)
13881400
sid = r.get("session_id", "unknown")
1389-
session_results.append(parse_categorical_row(sid, r, config))
1390-
return session_results
1401+
parsed = parse_categorical_row(sid, r, config)
1402+
has_parse_error = any(m.parse_error for m in parsed.metrics)
1403+
if has_parse_error and r.get("transcript"):
1404+
failed_sessions[sid] = r.get("transcript", "")
1405+
session_results.append(parsed)
1406+
1407+
retry_meta = {}
1408+
if failed_sessions:
1409+
logger.warning(
1410+
"AI.GENERATE returned NULL/unparseable for %d session(s), "
1411+
"retrying via Gemini API: %s",
1412+
len(failed_sessions),
1413+
", ".join(failed_sessions.keys()),
1414+
)
1415+
retried = self._retry_failed_sessions(
1416+
failed_sessions,
1417+
config,
1418+
endpoint,
1419+
max_retries=3,
1420+
)
1421+
resolved = 0
1422+
if retried:
1423+
retried_map = {r.session_id: r for r in retried}
1424+
session_results = [
1425+
retried_map.get(sr.session_id, sr) for sr in session_results
1426+
]
1427+
resolved = sum(
1428+
1 for r in retried if not any(m.parse_error for m in r.metrics)
1429+
)
1430+
logger.info(
1431+
"Gemini API retry resolved %d/%d failed sessions",
1432+
resolved,
1433+
len(failed_sessions),
1434+
)
1435+
retry_meta = {
1436+
"failed_count": len(failed_sessions),
1437+
"retry_attempted": True,
1438+
"retry_resolved": resolved,
1439+
"retry_unresolved": len(failed_sessions) - resolved,
1440+
}
1441+
1442+
return session_results, retry_meta
1443+
1444+
def _retry_failed_sessions(
1445+
self,
1446+
transcripts: dict[str, str],
1447+
config: CategoricalEvaluationConfig,
1448+
endpoint: str,
1449+
max_retries: int = 3,
1450+
) -> list:
1451+
"""Retries classification for failed sessions via Gemini API.
1452+
1453+
Note: This method is synchronous and must not be called from
1454+
an async context with an already-running event loop.
1455+
1456+
Args:
1457+
transcripts: Maps session_id to transcript text.
1458+
config: Evaluation config.
1459+
endpoint: Model endpoint.
1460+
max_retries: Maximum number of retry attempts.
1461+
1462+
Returns:
1463+
List of CategoricalSessionResult for successfully retried
1464+
sessions.
1465+
"""
1466+
remaining = dict(transcripts)
1467+
all_results = {}
1468+
1469+
for attempt in range(1, max_retries + 1):
1470+
if not remaining:
1471+
break
1472+
if attempt > 1:
1473+
backoff = 2 ** (attempt - 2)
1474+
logger.info(
1475+
"Retry backoff: sleeping %ds before attempt %d", backoff, attempt
1476+
)
1477+
time.sleep(backoff)
1478+
try:
1479+
results = _run_sync(
1480+
classify_sessions_via_api(remaining, config, endpoint)
1481+
)
1482+
still_failed = {}
1483+
for r in results:
1484+
has_error = any(m.parse_error for m in r.metrics)
1485+
if has_error:
1486+
if r.session_id in remaining:
1487+
still_failed[r.session_id] = remaining[r.session_id]
1488+
for m in r.metrics:
1489+
if m.parse_error:
1490+
logger.warning(
1491+
"Retry attempt %d, session %s, metric %s: "
1492+
"parse_error=True, raw_response=%s",
1493+
attempt,
1494+
r.session_id,
1495+
m.metric_name,
1496+
repr(m.raw_response[:500] if m.raw_response else None),
1497+
)
1498+
break
1499+
else:
1500+
all_results[r.session_id] = r
1501+
remaining = still_failed
1502+
if remaining:
1503+
logger.warning(
1504+
"Retry attempt %d: %d sessions still unresolved",
1505+
attempt,
1506+
len(remaining),
1507+
)
1508+
except Exception as e: # Broad catch: retry loop logs + continues
1509+
logger.warning(
1510+
"Gemini API retry attempt %d failed: %s (type=%s)",
1511+
attempt,
1512+
e,
1513+
type(e).__name__,
1514+
)
1515+
1516+
if remaining:
1517+
logger.warning(
1518+
"%d sessions still unresolved after %d retries",
1519+
len(remaining),
1520+
max_retries,
1521+
)
1522+
1523+
return list(all_results.values())
13911524

13921525
def _categorical_api_fallback(
13931526
self,

0 commit comments

Comments
 (0)