6565import java .io .IOException ;
6666import java .time .Duration ;
6767import java .time .Instant ;
68+ import java .util .Collection ;
6869import java .util .HashMap ;
6970import java .util .Map ;
7071import java .util .Optional ;
72+ import java .util .concurrent .ConcurrentHashMap ;
7173import java .util .concurrent .Executors ;
7274import java .util .concurrent .ScheduledExecutorService ;
7375import java .util .concurrent .ThreadFactory ;
@@ -101,8 +103,11 @@ public class BigQueryAgentAnalyticsPlugin extends BasePlugin {
101103 private final BigQueryWriteClient writeClient ;
102104 private final ScheduledExecutorService executor ;
103105 private final Object tableEnsuredLock = new Object ();
104- @ VisibleForTesting final BatchProcessor batchProcessor ;
105- @ VisibleForTesting final TraceManager traceManager ;
106+ // Map of invocation ID to BatchProcessor.
107+ private final ConcurrentHashMap <String , BatchProcessor > batchProcessors =
108+ new ConcurrentHashMap <>();
109+ // Map of invocation ID to TraceManager.
110+ private final ConcurrentHashMap <String , TraceManager > traceManagers = new ConcurrentHashMap <>();
106111 private volatile boolean tableEnsured = false ;
107112
108113 public BigQueryAgentAnalyticsPlugin (BigQueryLoggerConfig config ) throws IOException {
@@ -118,21 +123,16 @@ public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config, BigQuery bigQue
118123 r -> new Thread (r , "bq-analytics-plugin-" + threadCounter .getAndIncrement ());
119124 this .executor = Executors .newScheduledThreadPool (1 , threadFactory );
120125 this .writeClient = createWriteClient (config );
121- this .traceManager = createTraceManager ();
122-
123- if (config .enabled ()) {
124- StreamWriter writer = createWriter (config );
125- this .batchProcessor =
126- new BatchProcessor (
127- writer ,
128- config .batchSize (),
129- config .batchFlushInterval (),
130- config .queueMaxSize (),
131- executor );
132- this .batchProcessor .start ();
133- } else {
134- this .batchProcessor = null ;
135- }
126+ }
127+
128+ @ VisibleForTesting
129+ TraceManager getTraceManager (String invocationId ) {
130+ return traceManagers .computeIfAbsent (invocationId , id -> new TraceManager ());
131+ }
132+
133+ @ VisibleForTesting
134+ BatchProcessor getBatchProcessor (String invocationId ) {
135+ return batchProcessors .get (invocationId );
136136 }
137137
138138 private static BigQuery createBigQuery (BigQueryLoggerConfig config ) throws IOException {
@@ -234,10 +234,6 @@ protected StreamWriter createWriter(BigQueryLoggerConfig config) {
234234 }
235235 }
236236
237- protected TraceManager createTraceManager () {
238- return new TraceManager ();
239- }
240-
241237 private void logEvent (
242238 String eventType ,
243239 InvocationContext invocationContext ,
@@ -252,7 +248,7 @@ private void logEvent(
252248 Object content ,
253249 boolean isContentTruncated ,
254250 Optional <EventData > eventData ) {
255- if (!config .enabled () || batchProcessor == null ) {
251+ if (!config .enabled ()) {
256252 return ;
257253 }
258254 if (!config .eventAllowlist ().isEmpty () && !config .eventAllowlist ().contains (eventType )) {
@@ -261,6 +257,26 @@ private void logEvent(
261257 if (config .eventDenylist ().contains (eventType )) {
262258 return ;
263259 }
260+ String invocationId = invocationContext .invocationId ();
261+ BatchProcessor processor =
262+ batchProcessors .computeIfAbsent (
263+ invocationId ,
264+ id -> {
265+ StreamWriter writer = createWriter (config );
266+ BatchProcessor p =
267+ new BatchProcessor (
268+ writer ,
269+ config .batchSize (),
270+ config .batchFlushInterval (),
271+ config .queueMaxSize (),
272+ executor );
273+ p .start ();
274+ return p ;
275+ });
276+ if (processor == null ) {
277+ logger .severe ("Failed to create BatchProcessor for invocationId: " + invocationId );
278+ return ;
279+ }
264280 // Ensure table exists before logging.
265281 ensureTableExistsOnce ();
266282 // Log common fields
@@ -288,11 +304,12 @@ private void logEvent(
288304 row .put ("attributes" , convertToJsonNode (getAttributes (data , invocationContext )));
289305
290306 addTraceDetails (row , invocationContext , eventData );
291- batchProcessor .append (row );
307+ processor .append (row );
292308 }
293309
294310 private void addTraceDetails (
295311 Map <String , Object > row , InvocationContext invocationContext , Optional <EventData > eventData ) {
312+ TraceManager traceManager = getTraceManager (invocationContext .invocationId ());
296313 String traceId =
297314 eventData
298315 .flatMap (EventData ::traceIdOverride )
@@ -323,7 +340,7 @@ private void addTraceDetails(
323340 private Map <String , Object > getAttributes (
324341 EventData eventData , InvocationContext invocationContext ) {
325342 Map <String , Object > attributes = new HashMap <>(eventData .extraAttributes ());
326-
343+ TraceManager traceManager = getTraceManager ( invocationContext . invocationId ());
327344 attributes .put ("root_agent_name" , traceManager .getRootAgentName ());
328345 eventData .model ().ifPresent (m -> attributes .put ("model" , m ));
329346 eventData .modelVersion ().ifPresent (mv -> attributes .put ("model_version" , mv ));
@@ -360,11 +377,22 @@ private Map<String, Object> getAttributes(
360377 return attributes ;
361378 }
362379
380+ @ VisibleForTesting
381+ protected Collection <BatchProcessor > getBatchProcessors () {
382+ return batchProcessors .values ();
383+ }
384+
363385 @ Override
364386 public Completable close () {
365- if (batchProcessor != null ) {
366- batchProcessor .close ();
387+ for (BatchProcessor processor : getBatchProcessors ()) {
388+ processor .close ();
389+ }
390+ for (TraceManager traceManager : traceManagers .values ()) {
391+ traceManager .clearStack ();
367392 }
393+ batchProcessors .clear ();
394+ traceManagers .clear ();
395+
368396 if (writeClient != null ) {
369397 writeClient .close ();
370398 }
@@ -381,6 +409,7 @@ public Completable close() {
381409 }
382410
383411 private Optional <EventData > getCompletedEventData (InvocationContext invocationContext ) {
412+ TraceManager traceManager = getTraceManager (invocationContext .invocationId ());
384413 String traceId = traceManager .getTraceId (invocationContext );
385414 // Pop the invocation span from the trace manager.
386415 Optional <RecordData > popped = traceManager .popSpan ();
@@ -413,7 +442,7 @@ public Maybe<Content> onUserMessageCallback(
413442 InvocationContext invocationContext , Content userMessage ) {
414443 return Maybe .fromAction (
415444 () -> {
416- traceManager .ensureInvocationSpan (invocationContext );
445+ getTraceManager ( invocationContext . invocationId ()) .ensureInvocationSpan (invocationContext );
417446 logEvent ("USER_MESSAGE_RECEIVED" , invocationContext , userMessage , Optional .empty ());
418447 if (userMessage .parts ().isPresent ()) {
419448 for (Part part : userMessage .parts ().get ()) {
@@ -497,11 +526,16 @@ public Maybe<Event> onEventCallback(InvocationContext invocationContext, Event e
497526
498527 @ Override
499528 public Maybe <Content > beforeRunCallback (InvocationContext invocationContext ) {
500- traceManager .ensureInvocationSpan (invocationContext );
529+ getTraceManager ( invocationContext . invocationId ()) .ensureInvocationSpan (invocationContext );
501530 return Maybe .fromAction (
502531 () -> logEvent ("INVOCATION_STARTING" , invocationContext , null , Optional .empty ()));
503532 }
504533
534+ @ VisibleForTesting
535+ protected BatchProcessor removeProcessor (String invocationId ) {
536+ return batchProcessors .remove (invocationId );
537+ }
538+
505539 @ Override
506540 public Completable afterRunCallback (InvocationContext invocationContext ) {
507541 return Completable .fromAction (
@@ -511,16 +545,24 @@ public Completable afterRunCallback(InvocationContext invocationContext) {
511545 invocationContext ,
512546 null ,
513547 getCompletedEventData (invocationContext ));
514- batchProcessor .flush ();
515- traceManager .clearStack ();
548+ BatchProcessor processor = removeProcessor (invocationContext .invocationId ());
549+ if (processor != null ) {
550+ processor .flush ();
551+ processor .close ();
552+ }
553+ TraceManager traceManager = traceManagers .remove (invocationContext .invocationId ());
554+ if (traceManager != null ) {
555+ traceManager .clearStack ();
556+ }
516557 });
517558 }
518559
519560 @ Override
520561 public Maybe <Content > beforeAgentCallback (BaseAgent agent , CallbackContext callbackContext ) {
521562 return Maybe .fromAction (
522563 () -> {
523- traceManager .pushSpan ("agent:" + agent .name ());
564+ getTraceManager (callbackContext .invocationContext ().invocationId ())
565+ .pushSpan ("agent:" + agent .name ());
524566 logEvent ("AGENT_STARTING" , callbackContext .invocationContext (), null , Optional .empty ());
525567 });
526568 }
@@ -609,7 +651,8 @@ public Maybe<LlmResponse> beforeModelCallback(
609651 .setModel (req .model ().orElse ("" ))
610652 .setExtraAttributes (attributes )
611653 .build ();
612- traceManager .pushSpan ("llm_request" );
654+ getTraceManager (callbackContext .invocationContext ().invocationId ())
655+ .pushSpan ("llm_request" );
613656 logEvent ("LLM_REQUEST" , callbackContext .invocationContext (), req , Optional .of (eventData ));
614657 });
615658 }
@@ -619,6 +662,8 @@ public Maybe<LlmResponse> afterModelCallback(
619662 CallbackContext callbackContext , LlmResponse llmResponse ) {
620663 return Maybe .fromAction (
621664 () -> {
665+ TraceManager traceManager =
666+ getTraceManager (callbackContext .invocationContext ().invocationId ());
622667 // TODO(b/495809488): Add formatting of the content
623668 ParsedContent parsedContent =
624669 JsonFormatter .parse (llmResponse .content ().orElse (null ), config .maxContentLength ());
@@ -715,6 +760,8 @@ public Maybe<LlmResponse> onModelErrorCallback(
715760 CallbackContext callbackContext , LlmRequest .Builder llmRequest , Throwable error ) {
716761 return Maybe .fromAction (
717762 () -> {
763+ TraceManager traceManager =
764+ getTraceManager (callbackContext .invocationContext ().invocationId ());
718765 InvocationContext invocationContext = callbackContext .invocationContext ();
719766 Optional <RecordData > popped = traceManager .popSpan ();
720767 String spanId = popped .map (RecordData ::spanId ).orElse (null );
@@ -749,7 +796,7 @@ public Maybe<Map<String, Object>> beforeToolCallback(
749796 ImmutableMap <String , Object > contentMap =
750797 ImmutableMap .of (
751798 "tool_origin" , getToolOrigin (tool ), "tool" , tool .name (), "args" , res .node ());
752- traceManager .pushSpan ("tool" );
799+ getTraceManager ( toolContext . invocationContext (). invocationId ()) .pushSpan ("tool" );
753800 logEvent ("TOOL_STARTING" , toolContext .invocationContext (), contentMap , Optional .empty ());
754801 });
755802 }
@@ -762,6 +809,8 @@ public Maybe<Map<String, Object>> afterToolCallback(
762809 Map <String , Object > result ) {
763810 return Maybe .fromAction (
764811 () -> {
812+ TraceManager traceManager =
813+ getTraceManager (toolContext .invocationContext ().invocationId ());
765814 Optional <RecordData > popped = traceManager .popSpan ();
766815 TruncationResult truncationResult = smartTruncate (result , config .maxContentLength ());
767816 ImmutableMap <String , Object > contentMap =
@@ -799,6 +848,8 @@ public Maybe<Map<String, Object>> onToolErrorCallback(
799848 BaseTool tool , Map <String , Object > toolArgs , ToolContext toolContext , Throwable error ) {
800849 return Maybe .fromAction (
801850 () -> {
851+ TraceManager traceManager =
852+ getTraceManager (toolContext .invocationContext ().invocationId ());
802853 Optional <RecordData > popped = traceManager .popSpan ();
803854 TruncationResult truncationResult = smartTruncate (toolArgs , config .maxContentLength ());
804855 String toolOrigin = getToolOrigin (tool );
0 commit comments