2121import static com .google .adk .plugins .agentanalytics .JsonFormatter .convertToJsonNode ;
2222import static com .google .adk .plugins .agentanalytics .JsonFormatter .smartTruncate ;
2323import static com .google .adk .plugins .agentanalytics .JsonFormatter .toJavaObject ;
24- import static java .util .concurrent .TimeUnit .MILLISECONDS ;
2524
2625import com .google .adk .agents .BaseAgent ;
2726import com .google .adk .agents .CallbackContext ;
4140import com .google .adk .tools .ToolContext ;
4241import com .google .adk .tools .mcp .AbstractMcpTool ;
4342import com .google .adk .utils .AgentEnums .AgentOrigin ;
44- import com .google .api .gax .core .FixedCredentialsProvider ;
45- import com .google .api .gax .retrying .RetrySettings ;
4643import com .google .auth .oauth2 .GoogleCredentials ;
4744import com .google .cloud .bigquery .BigQuery ;
4845import com .google .cloud .bigquery .BigQueryException ;
5350import com .google .cloud .bigquery .Table ;
5451import com .google .cloud .bigquery .TableId ;
5552import com .google .cloud .bigquery .TableInfo ;
56- import com .google .cloud .bigquery .storage .v1 .BigQueryWriteClient ;
57- import com .google .cloud .bigquery .storage .v1 .BigQueryWriteSettings ;
58- import com .google .cloud .bigquery .storage .v1 .StreamWriter ;
5953import com .google .common .annotations .VisibleForTesting ;
60- import com .google .common .base .VerifyException ;
6154import com .google .common .collect .ImmutableList ;
6255import com .google .common .collect .ImmutableMap ;
6356import com .google .genai .types .Content ;
7063import java .util .HashMap ;
7164import java .util .Map ;
7265import java .util .Optional ;
73- import java .util .concurrent .Executors ;
74- import java .util .concurrent .ScheduledExecutorService ;
75- import java .util .concurrent .ThreadFactory ;
7666import java .util .concurrent .atomic .AtomicLong ;
7767import java .util .logging .Level ;
7868import java .util .logging .Logger ;
@@ -100,11 +90,8 @@ public class BigQueryAgentAnalyticsPlugin extends BasePlugin {
10090
10191 private final BigQueryLoggerConfig config ;
10292 private final BigQuery bigQuery ;
103- private final BigQueryWriteClient writeClient ;
104- private final ScheduledExecutorService executor ;
10593 private final Object tableEnsuredLock = new Object ();
106- @ VisibleForTesting final BatchProcessor batchProcessor ;
107- @ VisibleForTesting final TraceManager traceManager ;
94+ private final PluginState state ;
10895 private volatile boolean tableEnsured = false ;
10996
11097 public BigQueryAgentAnalyticsPlugin (BigQueryLoggerConfig config ) throws IOException {
@@ -113,28 +100,14 @@ public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config) throws IOExcept
113100
114101 public BigQueryAgentAnalyticsPlugin (BigQueryLoggerConfig config , BigQuery bigQuery )
115102 throws IOException {
103+ this (config , bigQuery , new PluginState (config ));
104+ }
105+
106+ BigQueryAgentAnalyticsPlugin (BigQueryLoggerConfig config , BigQuery bigQuery , PluginState state ) {
116107 super ("bigquery_agent_analytics" );
117108 this .config = config ;
118109 this .bigQuery = bigQuery ;
119- ThreadFactory threadFactory =
120- r -> new Thread (r , "bq-analytics-plugin-" + threadCounter .getAndIncrement ());
121- this .executor = Executors .newScheduledThreadPool (1 , threadFactory );
122- this .writeClient = createWriteClient (config );
123- this .traceManager = createTraceManager ();
124-
125- if (config .enabled ()) {
126- StreamWriter writer = createWriter (config );
127- this .batchProcessor =
128- new BatchProcessor (
129- writer ,
130- config .batchSize (),
131- config .batchFlushInterval (),
132- config .queueMaxSize (),
133- executor );
134- this .batchProcessor .start ();
135- } else {
136- this .batchProcessor = null ;
137- }
110+ this .state = state ;
138111 }
139112
140113 private static BigQuery createBigQuery (BigQueryLoggerConfig config ) throws IOException {
@@ -194,7 +167,7 @@ private void ensureTableExists(BigQuery bigQuery, BigQueryLoggerConfig config) {
194167
195168 try {
196169 if (config .createViews ()) {
197- var unused = executor .submit (() -> createAnalyticsViews (bigQuery , config ));
170+ var unused = state . getExecutor () .submit (() -> createAnalyticsViews (bigQuery , config ));
198171 }
199172 } catch (RuntimeException e ) {
200173 logger .log (Level .WARNING , "Failed to create/update BigQuery views for table: " + tableId , e );
@@ -209,48 +182,6 @@ private void processBigQueryException(BigQueryException e, String logMessage) {
209182 }
210183 }
211184
212- protected BigQueryWriteClient createWriteClient (BigQueryLoggerConfig config ) throws IOException {
213- if (config .credentials () != null ) {
214- return BigQueryWriteClient .create (
215- BigQueryWriteSettings .newBuilder ()
216- .setCredentialsProvider (FixedCredentialsProvider .create (config .credentials ()))
217- .build ());
218- }
219- return BigQueryWriteClient .create ();
220- }
221-
222- protected String getStreamName (BigQueryLoggerConfig config ) {
223- return String .format (
224- "projects/%s/datasets/%s/tables/%s/streams/_default" ,
225- config .projectId (), config .datasetId (), config .tableName ());
226- }
227-
228- protected StreamWriter createWriter (BigQueryLoggerConfig config ) {
229- BigQueryLoggerConfig .RetryConfig retryConfig = config .retryConfig ();
230- RetrySettings retrySettings =
231- RetrySettings .newBuilder ()
232- .setMaxAttempts (retryConfig .maxRetries ())
233- .setInitialRetryDelay (
234- org .threeten .bp .Duration .ofMillis (retryConfig .initialDelay ().toMillis ()))
235- .setRetryDelayMultiplier (retryConfig .multiplier ())
236- .setMaxRetryDelay (org .threeten .bp .Duration .ofMillis (retryConfig .maxDelay ().toMillis ()))
237- .build ();
238-
239- String streamName = getStreamName (config );
240- try {
241- return StreamWriter .newBuilder (streamName , writeClient )
242- .setRetrySettings (retrySettings )
243- .setWriterSchema (BigQuerySchema .getArrowSchema ())
244- .build ();
245- } catch (Exception e ) {
246- throw new VerifyException ("Failed to create StreamWriter for " + streamName , e );
247- }
248- }
249-
250- protected TraceManager createTraceManager () {
251- return new TraceManager ();
252- }
253-
254185 private void logEvent (
255186 String eventType ,
256187 InvocationContext invocationContext ,
@@ -265,7 +196,7 @@ private void logEvent(
265196 Object content ,
266197 boolean isContentTruncated ,
267198 Optional <EventData > eventData ) {
268- if (!config .enabled () || batchProcessor == null ) {
199+ if (!config .enabled ()) {
269200 return ;
270201 }
271202 if (!config .eventAllowlist ().isEmpty () && !config .eventAllowlist ().contains (eventType )) {
@@ -274,6 +205,8 @@ private void logEvent(
274205 if (config .eventDenylist ().contains (eventType )) {
275206 return ;
276207 }
208+ String invocationId = invocationContext .invocationId ();
209+ BatchProcessor processor = state .getBatchProcessor (invocationId );
277210 // Ensure table exists before logging.
278211 ensureTableExistsOnce ();
279212 // Log common fields
@@ -301,11 +234,12 @@ private void logEvent(
301234 row .put ("attributes" , convertToJsonNode (getAttributes (data , invocationContext )));
302235
303236 addTraceDetails (row , invocationContext , eventData );
304- batchProcessor .append (row );
237+ processor .append (row );
305238 }
306239
307240 private void addTraceDetails (
308241 Map <String , Object > row , InvocationContext invocationContext , Optional <EventData > eventData ) {
242+ TraceManager traceManager = state .getTraceManager (invocationContext .invocationId ());
309243 String traceId =
310244 eventData
311245 .flatMap (EventData ::traceIdOverride )
@@ -336,7 +270,7 @@ private void addTraceDetails(
336270 private Map <String , Object > getAttributes (
337271 EventData eventData , InvocationContext invocationContext ) {
338272 Map <String , Object > attributes = new HashMap <>(eventData .extraAttributes ());
339-
273+ TraceManager traceManager = state . getTraceManager ( invocationContext . invocationId ());
340274 attributes .put ("root_agent_name" , traceManager .getRootAgentName ());
341275 eventData .model ().ifPresent (m -> attributes .put ("model" , m ));
342276 eventData .modelVersion ().ifPresent (mv -> attributes .put ("model_version" , mv ));
@@ -375,25 +309,17 @@ private Map<String, Object> getAttributes(
375309
376310 @ Override
377311 public Completable close () {
378- if (batchProcessor != null ) {
379- batchProcessor .close ();
380- }
381- if (writeClient != null ) {
382- writeClient .close ();
383- }
384- try {
385- executor .shutdown ();
386- if (!executor .awaitTermination (config .shutdownTimeout ().toMillis (), MILLISECONDS )) {
387- executor .shutdownNow ();
388- }
389- } catch (InterruptedException e ) {
390- executor .shutdownNow ();
391- Thread .currentThread ().interrupt ();
392- }
312+ state .close ();
393313 return Completable .complete ();
394314 }
395315
316+ @ VisibleForTesting
317+ PluginState getState () {
318+ return state ;
319+ }
320+
396321 private Optional <EventData > getCompletedEventData (InvocationContext invocationContext ) {
322+ TraceManager traceManager = state .getTraceManager (invocationContext .invocationId ());
397323 String traceId = traceManager .getTraceId (invocationContext );
398324 // Pop the invocation span from the trace manager.
399325 Optional <RecordData > popped = traceManager .popSpan ();
@@ -426,7 +352,9 @@ public Maybe<Content> onUserMessageCallback(
426352 InvocationContext invocationContext , Content userMessage ) {
427353 return Maybe .fromAction (
428354 () -> {
429- traceManager .ensureInvocationSpan (invocationContext );
355+ state
356+ .getTraceManager (invocationContext .invocationId ())
357+ .ensureInvocationSpan (invocationContext );
430358 logEvent ("USER_MESSAGE_RECEIVED" , invocationContext , userMessage , Optional .empty ());
431359 if (userMessage .parts ().isPresent ()) {
432360 for (Part part : userMessage .parts ().get ()) {
@@ -510,7 +438,7 @@ public Maybe<Event> onEventCallback(InvocationContext invocationContext, Event e
510438
511439 @ Override
512440 public Maybe <Content > beforeRunCallback (InvocationContext invocationContext ) {
513- traceManager .ensureInvocationSpan (invocationContext );
441+ state . getTraceManager ( invocationContext . invocationId ()) .ensureInvocationSpan (invocationContext );
514442 return Maybe .fromAction (
515443 () -> logEvent ("INVOCATION_STARTING" , invocationContext , null , Optional .empty ()));
516444 }
@@ -524,16 +452,25 @@ public Completable afterRunCallback(InvocationContext invocationContext) {
524452 invocationContext ,
525453 null ,
526454 getCompletedEventData (invocationContext ));
527- batchProcessor .flush ();
528- traceManager .clearStack ();
455+ BatchProcessor processor = state .removeProcessor (invocationContext .invocationId ());
456+ if (processor != null ) {
457+ processor .flush ();
458+ processor .close ();
459+ }
460+ TraceManager traceManager = state .removeTraceManager (invocationContext .invocationId ());
461+ if (traceManager != null ) {
462+ traceManager .clearStack ();
463+ }
529464 });
530465 }
531466
532467 @ Override
533468 public Maybe <Content > beforeAgentCallback (BaseAgent agent , CallbackContext callbackContext ) {
534469 return Maybe .fromAction (
535470 () -> {
536- traceManager .pushSpan ("agent:" + agent .name ());
471+ state
472+ .getTraceManager (callbackContext .invocationContext ().invocationId ())
473+ .pushSpan ("agent:" + agent .name ());
537474 logEvent ("AGENT_STARTING" , callbackContext .invocationContext (), null , Optional .empty ());
538475 });
539476 }
@@ -622,7 +559,9 @@ public Maybe<LlmResponse> beforeModelCallback(
622559 .setModel (req .model ().orElse ("" ))
623560 .setExtraAttributes (attributes )
624561 .build ();
625- traceManager .pushSpan ("llm_request" );
562+ state
563+ .getTraceManager (callbackContext .invocationContext ().invocationId ())
564+ .pushSpan ("llm_request" );
626565 logEvent ("LLM_REQUEST" , callbackContext .invocationContext (), req , Optional .of (eventData ));
627566 });
628567 }
@@ -632,6 +571,8 @@ public Maybe<LlmResponse> afterModelCallback(
632571 CallbackContext callbackContext , LlmResponse llmResponse ) {
633572 return Maybe .fromAction (
634573 () -> {
574+ TraceManager traceManager =
575+ state .getTraceManager (callbackContext .invocationContext ().invocationId ());
635576 // TODO(b/495809488): Add formatting of the content
636577 ParsedContent parsedContent =
637578 JsonFormatter .parse (llmResponse .content ().orElse (null ), config .maxContentLength ());
@@ -728,6 +669,8 @@ public Maybe<LlmResponse> onModelErrorCallback(
728669 CallbackContext callbackContext , LlmRequest .Builder llmRequest , Throwable error ) {
729670 return Maybe .fromAction (
730671 () -> {
672+ TraceManager traceManager =
673+ state .getTraceManager (callbackContext .invocationContext ().invocationId ());
731674 InvocationContext invocationContext = callbackContext .invocationContext ();
732675 Optional <RecordData > popped = traceManager .popSpan ();
733676 String spanId = popped .map (RecordData ::spanId ).orElse (null );
@@ -762,7 +705,7 @@ public Maybe<Map<String, Object>> beforeToolCallback(
762705 ImmutableMap <String , Object > contentMap =
763706 ImmutableMap .of (
764707 "tool_origin" , getToolOrigin (tool ), "tool" , tool .name (), "args" , res .node ());
765- traceManager .pushSpan ("tool" );
708+ state . getTraceManager ( toolContext . invocationContext (). invocationId ()) .pushSpan ("tool" );
766709 logEvent ("TOOL_STARTING" , toolContext .invocationContext (), contentMap , Optional .empty ());
767710 });
768711 }
@@ -775,6 +718,8 @@ public Maybe<Map<String, Object>> afterToolCallback(
775718 Map <String , Object > result ) {
776719 return Maybe .fromAction (
777720 () -> {
721+ TraceManager traceManager =
722+ state .getTraceManager (toolContext .invocationContext ().invocationId ());
778723 Optional <RecordData > popped = traceManager .popSpan ();
779724 TruncationResult truncationResult = smartTruncate (result , config .maxContentLength ());
780725 ImmutableMap <String , Object > contentMap =
@@ -812,6 +757,8 @@ public Maybe<Map<String, Object>> onToolErrorCallback(
812757 BaseTool tool , Map <String , Object > toolArgs , ToolContext toolContext , Throwable error ) {
813758 return Maybe .fromAction (
814759 () -> {
760+ TraceManager traceManager =
761+ state .getTraceManager (toolContext .invocationContext ().invocationId ());
815762 Optional <RecordData > popped = traceManager .popSpan ();
816763 TruncationResult truncationResult = smartTruncate (toolArgs , config .maxContentLength ());
817764 String toolOrigin = getToolOrigin (tool );
0 commit comments