4949from datetime import datetime
5050from datetime import timezone
5151import logging
52+ import time
5253from typing import Any , Optional
5354
5455from google .cloud import bigquery
134135
135136_LIST_TRACES_QUERY = """\
136137 WITH trace_sessions AS (
137- SELECT DISTINCT session_id
138+ SELECT session_id
138139 FROM `{project}.{dataset}.{table}`
139140 WHERE {where}
141+ GROUP BY session_id
142+ ORDER BY MAX(timestamp) DESC, session_id
140143 LIMIT @trace_limit
141144)
142145SELECT
@@ -1250,7 +1253,7 @@ def evaluate_categorical(
12501253
12511254 # Try AI.GENERATE.
12521255 try :
1253- session_results = self ._categorical_ai_generate (
1256+ session_results , retry_meta = self ._categorical_ai_generate (
12541257 config ,
12551258 table ,
12561259 where ,
@@ -1264,6 +1267,8 @@ def evaluate_categorical(
12641267 config = config ,
12651268 )
12661269 report .details ["execution_mode" ] = "ai_generate"
1270+ if retry_meta :
1271+ report .details ["retry" ] = retry_meta
12671272 if classify_fallback_reason :
12681273 report .details ["classify_fallback_reason" ] = classify_fallback_reason
12691274 self ._persist_categorical_if_configured (report , config , endpoint )
@@ -1355,8 +1360,13 @@ def _categorical_ai_generate(
13551360 params : list ,
13561361 endpoint : str ,
13571362 connection_id : Optional [str ] = None ,
1358- ) -> list :
1359- """Classifies sessions using BigQuery AI.GENERATE."""
1363+ ) -> tuple [list , dict ]:
1364+ """Classifies sessions using BigQuery AI.GENERATE.
1365+
1366+ Sessions where AI.GENERATE returns NULL (e.g. due to rate
1367+ limiting or transient errors) are retried via the Gemini API
1368+ up to 3 times.
1369+ """
13601370 prompt = build_categorical_prompt (config )
13611371
13621372 query = build_ai_generate_query (
@@ -1367,6 +1377,7 @@ def _categorical_ai_generate(
13671377 endpoint = endpoint ,
13681378 temperature = config .temperature ,
13691379 connection_id = connection_id ,
1380+ max_output_tokens = config .max_output_tokens ,
13701381 )
13711382
13721383 query_params = list (params ) + [
@@ -1383,11 +1394,133 @@ def _categorical_ai_generate(
13831394 results = list (self .bq_client .query (query , job_config = job_config ).result ())
13841395
13851396 session_results = []
1397+ failed_sessions = {}
13861398 for row in results :
13871399 r = dict (row )
13881400 sid = r .get ("session_id" , "unknown" )
1389- session_results .append (parse_categorical_row (sid , r , config ))
1390- return session_results
1401+ parsed = parse_categorical_row (sid , r , config )
1402+ has_parse_error = any (m .parse_error for m in parsed .metrics )
1403+ if has_parse_error and r .get ("transcript" ):
1404+ failed_sessions [sid ] = r .get ("transcript" , "" )
1405+ session_results .append (parsed )
1406+
1407+ retry_meta = {}
1408+ if failed_sessions :
1409+ logger .warning (
1410+ "AI.GENERATE returned NULL/unparseable for %d session(s), "
1411+ "retrying via Gemini API: %s" ,
1412+ len (failed_sessions ),
1413+ ", " .join (failed_sessions .keys ()),
1414+ )
1415+ retried = self ._retry_failed_sessions (
1416+ failed_sessions ,
1417+ config ,
1418+ endpoint ,
1419+ max_retries = 3 ,
1420+ )
1421+ resolved = 0
1422+ if retried :
1423+ retried_map = {r .session_id : r for r in retried }
1424+ session_results = [
1425+ retried_map .get (sr .session_id , sr ) for sr in session_results
1426+ ]
1427+ resolved = sum (
1428+ 1 for r in retried if not any (m .parse_error for m in r .metrics )
1429+ )
1430+ logger .info (
1431+ "Gemini API retry resolved %d/%d failed sessions" ,
1432+ resolved ,
1433+ len (failed_sessions ),
1434+ )
1435+ retry_meta = {
1436+ "failed_count" : len (failed_sessions ),
1437+ "retry_attempted" : True ,
1438+ "retry_resolved" : resolved ,
1439+ "retry_unresolved" : len (failed_sessions ) - resolved ,
1440+ }
1441+
1442+ return session_results , retry_meta
1443+
1444+ def _retry_failed_sessions (
1445+ self ,
1446+ transcripts : dict [str , str ],
1447+ config : CategoricalEvaluationConfig ,
1448+ endpoint : str ,
1449+ max_retries : int = 3 ,
1450+ ) -> list :
1451+ """Retries classification for failed sessions via Gemini API.
1452+
1453+ Note: This method is synchronous and must not be called from
1454+ an async context with an already-running event loop.
1455+
1456+ Args:
1457+ transcripts: Maps session_id to transcript text.
1458+ config: Evaluation config.
1459+ endpoint: Model endpoint.
1460+ max_retries: Maximum number of retry attempts.
1461+
1462+ Returns:
1463+ List of CategoricalSessionResult for successfully retried
1464+ sessions.
1465+ """
1466+ remaining = dict (transcripts )
1467+ all_results = {}
1468+
1469+ for attempt in range (1 , max_retries + 1 ):
1470+ if not remaining :
1471+ break
1472+ if attempt > 1 :
1473+ backoff = 2 ** (attempt - 2 )
1474+ logger .info (
1475+ "Retry backoff: sleeping %ds before attempt %d" , backoff , attempt
1476+ )
1477+ time .sleep (backoff )
1478+ try :
1479+ results = _run_sync (
1480+ classify_sessions_via_api (remaining , config , endpoint )
1481+ )
1482+ still_failed = {}
1483+ for r in results :
1484+ has_error = any (m .parse_error for m in r .metrics )
1485+ if has_error :
1486+ if r .session_id in remaining :
1487+ still_failed [r .session_id ] = remaining [r .session_id ]
1488+ for m in r .metrics :
1489+ if m .parse_error :
1490+ logger .warning (
1491+ "Retry attempt %d, session %s, metric %s: "
1492+ "parse_error=True, raw_response=%s" ,
1493+ attempt ,
1494+ r .session_id ,
1495+ m .metric_name ,
1496+ repr (m .raw_response [:500 ] if m .raw_response else None ),
1497+ )
1498+ break
1499+ else :
1500+ all_results [r .session_id ] = r
1501+ remaining = still_failed
1502+ if remaining :
1503+ logger .warning (
1504+ "Retry attempt %d: %d sessions still unresolved" ,
1505+ attempt ,
1506+ len (remaining ),
1507+ )
1508+ except Exception as e : # Broad catch: retry loop logs + continues
1509+ logger .warning (
1510+ "Gemini API retry attempt %d failed: %s (type=%s)" ,
1511+ attempt ,
1512+ e ,
1513+ type (e ).__name__ ,
1514+ )
1515+
1516+ if remaining :
1517+ logger .warning (
1518+ "%d sessions still unresolved after %d retries" ,
1519+ len (remaining ),
1520+ max_retries ,
1521+ )
1522+
1523+ return list (all_results .values ())
13911524
13921525 def _categorical_api_fallback (
13931526 self ,
0 commit comments