Skip to content

Commit 815efa9

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Make BigQueryAgentAnalyticsPlugin state per-invocation
This change introduces per-invocation instances of BatchProcessor and TraceManager, managed by ConcurrentHashMaps keyed by invocation ID. This ensures that analytics and tracing data are isolated for each concurrent invocation. BatchProcessors and TraceManagers are created lazily on the first event for a given invocation and are cleaned up when the invocation completes. PiperOrigin-RevId: 897370846
1 parent 8ef99f9 commit 815efa9

3 files changed

Lines changed: 213 additions & 68 deletions

File tree

core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java

Lines changed: 84 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,11 @@
6565
import java.io.IOException;
6666
import java.time.Duration;
6767
import java.time.Instant;
68+
import java.util.Collection;
6869
import java.util.HashMap;
6970
import java.util.Map;
7071
import java.util.Optional;
72+
import java.util.concurrent.ConcurrentHashMap;
7173
import java.util.concurrent.Executors;
7274
import java.util.concurrent.ScheduledExecutorService;
7375
import java.util.concurrent.ThreadFactory;
@@ -101,8 +103,11 @@ public class BigQueryAgentAnalyticsPlugin extends BasePlugin {
101103
private final BigQueryWriteClient writeClient;
102104
private final ScheduledExecutorService executor;
103105
private final Object tableEnsuredLock = new Object();
104-
@VisibleForTesting final BatchProcessor batchProcessor;
105-
@VisibleForTesting final TraceManager traceManager;
106+
// Map of invocation ID to BatchProcessor.
107+
private final ConcurrentHashMap<String, BatchProcessor> batchProcessors =
108+
new ConcurrentHashMap<>();
109+
// Map of invocation ID to TraceManager.
110+
private final ConcurrentHashMap<String, TraceManager> traceManagers = new ConcurrentHashMap<>();
106111
private volatile boolean tableEnsured = false;
107112

108113
public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config) throws IOException {
@@ -118,21 +123,16 @@ public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config, BigQuery bigQue
118123
r -> new Thread(r, "bq-analytics-plugin-" + threadCounter.getAndIncrement());
119124
this.executor = Executors.newScheduledThreadPool(1, threadFactory);
120125
this.writeClient = createWriteClient(config);
121-
this.traceManager = createTraceManager();
122-
123-
if (config.enabled()) {
124-
StreamWriter writer = createWriter(config);
125-
this.batchProcessor =
126-
new BatchProcessor(
127-
writer,
128-
config.batchSize(),
129-
config.batchFlushInterval(),
130-
config.queueMaxSize(),
131-
executor);
132-
this.batchProcessor.start();
133-
} else {
134-
this.batchProcessor = null;
135-
}
126+
}
127+
128+
@VisibleForTesting
129+
TraceManager getTraceManager(String invocationId) {
130+
return traceManagers.computeIfAbsent(invocationId, id -> new TraceManager());
131+
}
132+
133+
@VisibleForTesting
134+
BatchProcessor getBatchProcessor(String invocationId) {
135+
return batchProcessors.get(invocationId);
136136
}
137137

138138
private static BigQuery createBigQuery(BigQueryLoggerConfig config) throws IOException {
@@ -234,10 +234,6 @@ protected StreamWriter createWriter(BigQueryLoggerConfig config) {
234234
}
235235
}
236236

237-
protected TraceManager createTraceManager() {
238-
return new TraceManager();
239-
}
240-
241237
private void logEvent(
242238
String eventType,
243239
InvocationContext invocationContext,
@@ -252,7 +248,7 @@ private void logEvent(
252248
Object content,
253249
boolean isContentTruncated,
254250
Optional<EventData> eventData) {
255-
if (!config.enabled() || batchProcessor == null) {
251+
if (!config.enabled()) {
256252
return;
257253
}
258254
if (!config.eventAllowlist().isEmpty() && !config.eventAllowlist().contains(eventType)) {
@@ -261,6 +257,26 @@ private void logEvent(
261257
if (config.eventDenylist().contains(eventType)) {
262258
return;
263259
}
260+
String invocationId = invocationContext.invocationId();
261+
BatchProcessor processor =
262+
batchProcessors.computeIfAbsent(
263+
invocationId,
264+
id -> {
265+
StreamWriter writer = createWriter(config);
266+
BatchProcessor p =
267+
new BatchProcessor(
268+
writer,
269+
config.batchSize(),
270+
config.batchFlushInterval(),
271+
config.queueMaxSize(),
272+
executor);
273+
p.start();
274+
return p;
275+
});
276+
if (processor == null) {
277+
logger.severe("Failed to create BatchProcessor for invocationId: " + invocationId);
278+
return;
279+
}
264280
// Ensure table exists before logging.
265281
ensureTableExistsOnce();
266282
// Log common fields
@@ -288,11 +304,12 @@ private void logEvent(
288304
row.put("attributes", convertToJsonNode(getAttributes(data, invocationContext)));
289305

290306
addTraceDetails(row, invocationContext, eventData);
291-
batchProcessor.append(row);
307+
processor.append(row);
292308
}
293309

294310
private void addTraceDetails(
295311
Map<String, Object> row, InvocationContext invocationContext, Optional<EventData> eventData) {
312+
TraceManager traceManager = getTraceManager(invocationContext.invocationId());
296313
String traceId =
297314
eventData
298315
.flatMap(EventData::traceIdOverride)
@@ -323,7 +340,7 @@ private void addTraceDetails(
323340
private Map<String, Object> getAttributes(
324341
EventData eventData, InvocationContext invocationContext) {
325342
Map<String, Object> attributes = new HashMap<>(eventData.extraAttributes());
326-
343+
TraceManager traceManager = getTraceManager(invocationContext.invocationId());
327344
attributes.put("root_agent_name", traceManager.getRootAgentName());
328345
eventData.model().ifPresent(m -> attributes.put("model", m));
329346
eventData.modelVersion().ifPresent(mv -> attributes.put("model_version", mv));
@@ -360,11 +377,22 @@ private Map<String, Object> getAttributes(
360377
return attributes;
361378
}
362379

380+
@VisibleForTesting
381+
protected Collection<BatchProcessor> getBatchProcessors() {
382+
return batchProcessors.values();
383+
}
384+
363385
@Override
364386
public Completable close() {
365-
if (batchProcessor != null) {
366-
batchProcessor.close();
387+
for (BatchProcessor processor : getBatchProcessors()) {
388+
processor.close();
389+
}
390+
for (TraceManager traceManager : traceManagers.values()) {
391+
traceManager.clearStack();
367392
}
393+
batchProcessors.clear();
394+
traceManagers.clear();
395+
368396
if (writeClient != null) {
369397
writeClient.close();
370398
}
@@ -381,6 +409,7 @@ public Completable close() {
381409
}
382410

383411
private Optional<EventData> getCompletedEventData(InvocationContext invocationContext) {
412+
TraceManager traceManager = getTraceManager(invocationContext.invocationId());
384413
String traceId = traceManager.getTraceId(invocationContext);
385414
// Pop the invocation span from the trace manager.
386415
Optional<RecordData> popped = traceManager.popSpan();
@@ -413,7 +442,7 @@ public Maybe<Content> onUserMessageCallback(
413442
InvocationContext invocationContext, Content userMessage) {
414443
return Maybe.fromAction(
415444
() -> {
416-
traceManager.ensureInvocationSpan(invocationContext);
445+
getTraceManager(invocationContext.invocationId()).ensureInvocationSpan(invocationContext);
417446
logEvent("USER_MESSAGE_RECEIVED", invocationContext, userMessage, Optional.empty());
418447
if (userMessage.parts().isPresent()) {
419448
for (Part part : userMessage.parts().get()) {
@@ -497,11 +526,16 @@ public Maybe<Event> onEventCallback(InvocationContext invocationContext, Event e
497526

498527
@Override
499528
public Maybe<Content> beforeRunCallback(InvocationContext invocationContext) {
500-
traceManager.ensureInvocationSpan(invocationContext);
529+
getTraceManager(invocationContext.invocationId()).ensureInvocationSpan(invocationContext);
501530
return Maybe.fromAction(
502531
() -> logEvent("INVOCATION_STARTING", invocationContext, null, Optional.empty()));
503532
}
504533

534+
@VisibleForTesting
535+
protected BatchProcessor removeProcessor(String invocationId) {
536+
return batchProcessors.remove(invocationId);
537+
}
538+
505539
@Override
506540
public Completable afterRunCallback(InvocationContext invocationContext) {
507541
return Completable.fromAction(
@@ -511,16 +545,24 @@ public Completable afterRunCallback(InvocationContext invocationContext) {
511545
invocationContext,
512546
null,
513547
getCompletedEventData(invocationContext));
514-
batchProcessor.flush();
515-
traceManager.clearStack();
548+
BatchProcessor processor = removeProcessor(invocationContext.invocationId());
549+
if (processor != null) {
550+
processor.flush();
551+
processor.close();
552+
}
553+
TraceManager traceManager = traceManagers.remove(invocationContext.invocationId());
554+
if (traceManager != null) {
555+
traceManager.clearStack();
556+
}
516557
});
517558
}
518559

519560
@Override
520561
public Maybe<Content> beforeAgentCallback(BaseAgent agent, CallbackContext callbackContext) {
521562
return Maybe.fromAction(
522563
() -> {
523-
traceManager.pushSpan("agent:" + agent.name());
564+
getTraceManager(callbackContext.invocationContext().invocationId())
565+
.pushSpan("agent:" + agent.name());
524566
logEvent("AGENT_STARTING", callbackContext.invocationContext(), null, Optional.empty());
525567
});
526568
}
@@ -609,7 +651,8 @@ public Maybe<LlmResponse> beforeModelCallback(
609651
.setModel(req.model().orElse(""))
610652
.setExtraAttributes(attributes)
611653
.build();
612-
traceManager.pushSpan("llm_request");
654+
getTraceManager(callbackContext.invocationContext().invocationId())
655+
.pushSpan("llm_request");
613656
logEvent("LLM_REQUEST", callbackContext.invocationContext(), req, Optional.of(eventData));
614657
});
615658
}
@@ -619,6 +662,8 @@ public Maybe<LlmResponse> afterModelCallback(
619662
CallbackContext callbackContext, LlmResponse llmResponse) {
620663
return Maybe.fromAction(
621664
() -> {
665+
TraceManager traceManager =
666+
getTraceManager(callbackContext.invocationContext().invocationId());
622667
// TODO(b/495809488): Add formatting of the content
623668
ParsedContent parsedContent =
624669
JsonFormatter.parse(llmResponse.content().orElse(null), config.maxContentLength());
@@ -715,6 +760,8 @@ public Maybe<LlmResponse> onModelErrorCallback(
715760
CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) {
716761
return Maybe.fromAction(
717762
() -> {
763+
TraceManager traceManager =
764+
getTraceManager(callbackContext.invocationContext().invocationId());
718765
InvocationContext invocationContext = callbackContext.invocationContext();
719766
Optional<RecordData> popped = traceManager.popSpan();
720767
String spanId = popped.map(RecordData::spanId).orElse(null);
@@ -749,7 +796,7 @@ public Maybe<Map<String, Object>> beforeToolCallback(
749796
ImmutableMap<String, Object> contentMap =
750797
ImmutableMap.of(
751798
"tool_origin", getToolOrigin(tool), "tool", tool.name(), "args", res.node());
752-
traceManager.pushSpan("tool");
799+
getTraceManager(toolContext.invocationContext().invocationId()).pushSpan("tool");
753800
logEvent("TOOL_STARTING", toolContext.invocationContext(), contentMap, Optional.empty());
754801
});
755802
}
@@ -762,6 +809,8 @@ public Maybe<Map<String, Object>> afterToolCallback(
762809
Map<String, Object> result) {
763810
return Maybe.fromAction(
764811
() -> {
812+
TraceManager traceManager =
813+
getTraceManager(toolContext.invocationContext().invocationId());
765814
Optional<RecordData> popped = traceManager.popSpan();
766815
TruncationResult truncationResult = smartTruncate(result, config.maxContentLength());
767816
ImmutableMap<String, Object> contentMap =
@@ -799,6 +848,8 @@ public Maybe<Map<String, Object>> onToolErrorCallback(
799848
BaseTool tool, Map<String, Object> toolArgs, ToolContext toolContext, Throwable error) {
800849
return Maybe.fromAction(
801850
() -> {
851+
TraceManager traceManager =
852+
getTraceManager(toolContext.invocationContext().invocationId());
802853
Optional<RecordData> popped = traceManager.popSpan();
803854
TruncationResult truncationResult = smartTruncate(toolArgs, config.maxContentLength());
804855
String toolOrigin = getToolOrigin(tool);

core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginE2ETest.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,15 +103,22 @@ protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) {
103103
protected StreamWriter createWriter(BigQueryLoggerConfig config) {
104104
return mockWriter;
105105
}
106+
107+
// Override afterRunCallback to avoid removing batch processor in the middle of the test.
108+
@Override
109+
protected BatchProcessor removeProcessor(String invocationId) {
110+
return null;
111+
}
106112
};
107113

108114
when(mockWriter.append(any(ArrowRecordBatch.class)))
109115
.thenAnswer(
110116
invocation -> {
111117
ArrowRecordBatch recordedBatch = invocation.getArgument(0);
118+
BatchProcessor batchProcessor = plugin.getBatchProcessors().iterator().next();
112119
try (VectorSchemaRoot root =
113120
VectorSchemaRoot.create(
114-
BigQuerySchema.getArrowSchema(), plugin.batchProcessor.allocator)) {
121+
BigQuerySchema.getArrowSchema(), batchProcessor.allocator)) {
115122
VectorLoader loader = new VectorLoader(root);
116123
loader.load(recordedBatch);
117124
for (int i = 0; i < root.getRowCount(); i++) {
@@ -150,8 +157,9 @@ public void runAgent_logsAgentStartingAndCompleted() throws Exception {
150157

151158
// Ensure everything is flushed. The BatchProcessor flushes asynchronously sometimes,
152159
// but the direct flush() call should help. We wait up to 2 seconds for all 5 expected events.
160+
BatchProcessor batchProcessor = plugin.getBatchProcessors().iterator().next();
153161
for (int i = 0; i < 20 && capturedRows.size() < 5; i++) {
154-
plugin.batchProcessor.flush();
162+
batchProcessor.flush();
155163
if (capturedRows.size() < 5) {
156164
Thread.sleep(100);
157165
}

0 commit comments

Comments
 (0)