Skip to content

Commit 7486cb5

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Make BigQueryAgentAnalyticsPlugin state per-invocation
This change introduces per-invocation instances of BatchProcessor and TraceManager, managed by ConcurrentHashMaps keyed by invocation ID. This ensures that analytics and tracing data are isolated for each concurrent invocation. BatchProcessors and TraceManagers are created lazily on the first event for a given invocation and are cleaned up when the invocation completes. PiperOrigin-RevId: 897370846
1 parent 14027d1 commit 7486cb5

4 files changed

Lines changed: 354 additions & 146 deletions

File tree

core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java

Lines changed: 48 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import static com.google.adk.plugins.agentanalytics.JsonFormatter.convertToJsonNode;
2222
import static com.google.adk.plugins.agentanalytics.JsonFormatter.smartTruncate;
2323
import static com.google.adk.plugins.agentanalytics.JsonFormatter.toJavaObject;
24-
import static java.util.concurrent.TimeUnit.MILLISECONDS;
2524

2625
import com.google.adk.agents.BaseAgent;
2726
import com.google.adk.agents.CallbackContext;
@@ -41,8 +40,6 @@
4140
import com.google.adk.tools.ToolContext;
4241
import com.google.adk.tools.mcp.AbstractMcpTool;
4342
import com.google.adk.utils.AgentEnums.AgentOrigin;
44-
import com.google.api.gax.core.FixedCredentialsProvider;
45-
import com.google.api.gax.retrying.RetrySettings;
4643
import com.google.auth.oauth2.GoogleCredentials;
4744
import com.google.cloud.bigquery.BigQuery;
4845
import com.google.cloud.bigquery.BigQueryException;
@@ -53,11 +50,7 @@
5350
import com.google.cloud.bigquery.Table;
5451
import com.google.cloud.bigquery.TableId;
5552
import com.google.cloud.bigquery.TableInfo;
56-
import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient;
57-
import com.google.cloud.bigquery.storage.v1.BigQueryWriteSettings;
58-
import com.google.cloud.bigquery.storage.v1.StreamWriter;
5953
import com.google.common.annotations.VisibleForTesting;
60-
import com.google.common.base.VerifyException;
6154
import com.google.common.collect.ImmutableList;
6255
import com.google.common.collect.ImmutableMap;
6356
import com.google.genai.types.Content;
@@ -70,9 +63,6 @@
7063
import java.util.HashMap;
7164
import java.util.Map;
7265
import java.util.Optional;
73-
import java.util.concurrent.Executors;
74-
import java.util.concurrent.ScheduledExecutorService;
75-
import java.util.concurrent.ThreadFactory;
7666
import java.util.concurrent.atomic.AtomicLong;
7767
import java.util.logging.Level;
7868
import java.util.logging.Logger;
@@ -100,11 +90,8 @@ public class BigQueryAgentAnalyticsPlugin extends BasePlugin {
10090

10191
private final BigQueryLoggerConfig config;
10292
private final BigQuery bigQuery;
103-
private final BigQueryWriteClient writeClient;
104-
private final ScheduledExecutorService executor;
10593
private final Object tableEnsuredLock = new Object();
106-
@VisibleForTesting final BatchProcessor batchProcessor;
107-
@VisibleForTesting final TraceManager traceManager;
94+
private final PluginState state;
10895
private volatile boolean tableEnsured = false;
10996

11097
public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config) throws IOException {
@@ -113,28 +100,14 @@ public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config) throws IOExcept
113100

114101
public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config, BigQuery bigQuery)
115102
throws IOException {
103+
this(config, bigQuery, new PluginState(config));
104+
}
105+
106+
BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config, BigQuery bigQuery, PluginState state) {
116107
super("bigquery_agent_analytics");
117108
this.config = config;
118109
this.bigQuery = bigQuery;
119-
ThreadFactory threadFactory =
120-
r -> new Thread(r, "bq-analytics-plugin-" + threadCounter.getAndIncrement());
121-
this.executor = Executors.newScheduledThreadPool(1, threadFactory);
122-
this.writeClient = createWriteClient(config);
123-
this.traceManager = createTraceManager();
124-
125-
if (config.enabled()) {
126-
StreamWriter writer = createWriter(config);
127-
this.batchProcessor =
128-
new BatchProcessor(
129-
writer,
130-
config.batchSize(),
131-
config.batchFlushInterval(),
132-
config.queueMaxSize(),
133-
executor);
134-
this.batchProcessor.start();
135-
} else {
136-
this.batchProcessor = null;
137-
}
110+
this.state = state;
138111
}
139112

140113
private static BigQuery createBigQuery(BigQueryLoggerConfig config) throws IOException {
@@ -194,7 +167,7 @@ private void ensureTableExists(BigQuery bigQuery, BigQueryLoggerConfig config) {
194167

195168
try {
196169
if (config.createViews()) {
197-
var unused = executor.submit(() -> createAnalyticsViews(bigQuery, config));
170+
var unused = state.getExecutor().submit(() -> createAnalyticsViews(bigQuery, config));
198171
}
199172
} catch (RuntimeException e) {
200173
logger.log(Level.WARNING, "Failed to create/update BigQuery views for table: " + tableId, e);
@@ -209,48 +182,6 @@ private void processBigQueryException(BigQueryException e, String logMessage) {
209182
}
210183
}
211184

212-
protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) throws IOException {
213-
if (config.credentials() != null) {
214-
return BigQueryWriteClient.create(
215-
BigQueryWriteSettings.newBuilder()
216-
.setCredentialsProvider(FixedCredentialsProvider.create(config.credentials()))
217-
.build());
218-
}
219-
return BigQueryWriteClient.create();
220-
}
221-
222-
protected String getStreamName(BigQueryLoggerConfig config) {
223-
return String.format(
224-
"projects/%s/datasets/%s/tables/%s/streams/_default",
225-
config.projectId(), config.datasetId(), config.tableName());
226-
}
227-
228-
protected StreamWriter createWriter(BigQueryLoggerConfig config) {
229-
BigQueryLoggerConfig.RetryConfig retryConfig = config.retryConfig();
230-
RetrySettings retrySettings =
231-
RetrySettings.newBuilder()
232-
.setMaxAttempts(retryConfig.maxRetries())
233-
.setInitialRetryDelay(
234-
org.threeten.bp.Duration.ofMillis(retryConfig.initialDelay().toMillis()))
235-
.setRetryDelayMultiplier(retryConfig.multiplier())
236-
.setMaxRetryDelay(org.threeten.bp.Duration.ofMillis(retryConfig.maxDelay().toMillis()))
237-
.build();
238-
239-
String streamName = getStreamName(config);
240-
try {
241-
return StreamWriter.newBuilder(streamName, writeClient)
242-
.setRetrySettings(retrySettings)
243-
.setWriterSchema(BigQuerySchema.getArrowSchema())
244-
.build();
245-
} catch (Exception e) {
246-
throw new VerifyException("Failed to create StreamWriter for " + streamName, e);
247-
}
248-
}
249-
250-
protected TraceManager createTraceManager() {
251-
return new TraceManager();
252-
}
253-
254185
private void logEvent(
255186
String eventType,
256187
InvocationContext invocationContext,
@@ -265,7 +196,7 @@ private void logEvent(
265196
Object content,
266197
boolean isContentTruncated,
267198
Optional<EventData> eventData) {
268-
if (!config.enabled() || batchProcessor == null) {
199+
if (!config.enabled()) {
269200
return;
270201
}
271202
if (!config.eventAllowlist().isEmpty() && !config.eventAllowlist().contains(eventType)) {
@@ -274,6 +205,8 @@ private void logEvent(
274205
if (config.eventDenylist().contains(eventType)) {
275206
return;
276207
}
208+
String invocationId = invocationContext.invocationId();
209+
BatchProcessor processor = state.getBatchProcessor(invocationId);
277210
// Ensure table exists before logging.
278211
ensureTableExistsOnce();
279212
// Log common fields
@@ -301,11 +234,12 @@ private void logEvent(
301234
row.put("attributes", convertToJsonNode(getAttributes(data, invocationContext)));
302235

303236
addTraceDetails(row, invocationContext, eventData);
304-
batchProcessor.append(row);
237+
processor.append(row);
305238
}
306239

307240
private void addTraceDetails(
308241
Map<String, Object> row, InvocationContext invocationContext, Optional<EventData> eventData) {
242+
TraceManager traceManager = state.getTraceManager(invocationContext.invocationId());
309243
String traceId =
310244
eventData
311245
.flatMap(EventData::traceIdOverride)
@@ -336,7 +270,7 @@ private void addTraceDetails(
336270
private Map<String, Object> getAttributes(
337271
EventData eventData, InvocationContext invocationContext) {
338272
Map<String, Object> attributes = new HashMap<>(eventData.extraAttributes());
339-
273+
TraceManager traceManager = state.getTraceManager(invocationContext.invocationId());
340274
attributes.put("root_agent_name", traceManager.getRootAgentName());
341275
eventData.model().ifPresent(m -> attributes.put("model", m));
342276
eventData.modelVersion().ifPresent(mv -> attributes.put("model_version", mv));
@@ -375,25 +309,17 @@ private Map<String, Object> getAttributes(
375309

376310
@Override
377311
public Completable close() {
378-
if (batchProcessor != null) {
379-
batchProcessor.close();
380-
}
381-
if (writeClient != null) {
382-
writeClient.close();
383-
}
384-
try {
385-
executor.shutdown();
386-
if (!executor.awaitTermination(config.shutdownTimeout().toMillis(), MILLISECONDS)) {
387-
executor.shutdownNow();
388-
}
389-
} catch (InterruptedException e) {
390-
executor.shutdownNow();
391-
Thread.currentThread().interrupt();
392-
}
312+
state.close();
393313
return Completable.complete();
394314
}
395315

316+
@VisibleForTesting
317+
PluginState getState() {
318+
return state;
319+
}
320+
396321
private Optional<EventData> getCompletedEventData(InvocationContext invocationContext) {
322+
TraceManager traceManager = state.getTraceManager(invocationContext.invocationId());
397323
String traceId = traceManager.getTraceId(invocationContext);
398324
// Pop the invocation span from the trace manager.
399325
Optional<RecordData> popped = traceManager.popSpan();
@@ -426,7 +352,9 @@ public Maybe<Content> onUserMessageCallback(
426352
InvocationContext invocationContext, Content userMessage) {
427353
return Maybe.fromAction(
428354
() -> {
429-
traceManager.ensureInvocationSpan(invocationContext);
355+
state
356+
.getTraceManager(invocationContext.invocationId())
357+
.ensureInvocationSpan(invocationContext);
430358
logEvent("USER_MESSAGE_RECEIVED", invocationContext, userMessage, Optional.empty());
431359
if (userMessage.parts().isPresent()) {
432360
for (Part part : userMessage.parts().get()) {
@@ -510,7 +438,7 @@ public Maybe<Event> onEventCallback(InvocationContext invocationContext, Event e
510438

511439
@Override
512440
public Maybe<Content> beforeRunCallback(InvocationContext invocationContext) {
513-
traceManager.ensureInvocationSpan(invocationContext);
441+
state.getTraceManager(invocationContext.invocationId()).ensureInvocationSpan(invocationContext);
514442
return Maybe.fromAction(
515443
() -> logEvent("INVOCATION_STARTING", invocationContext, null, Optional.empty()));
516444
}
@@ -524,16 +452,25 @@ public Completable afterRunCallback(InvocationContext invocationContext) {
524452
invocationContext,
525453
null,
526454
getCompletedEventData(invocationContext));
527-
batchProcessor.flush();
528-
traceManager.clearStack();
455+
BatchProcessor processor = state.removeProcessor(invocationContext.invocationId());
456+
if (processor != null) {
457+
processor.flush();
458+
processor.close();
459+
}
460+
TraceManager traceManager = state.removeTraceManager(invocationContext.invocationId());
461+
if (traceManager != null) {
462+
traceManager.clearStack();
463+
}
529464
});
530465
}
531466

532467
@Override
533468
public Maybe<Content> beforeAgentCallback(BaseAgent agent, CallbackContext callbackContext) {
534469
return Maybe.fromAction(
535470
() -> {
536-
traceManager.pushSpan("agent:" + agent.name());
471+
state
472+
.getTraceManager(callbackContext.invocationContext().invocationId())
473+
.pushSpan("agent:" + agent.name());
537474
logEvent("AGENT_STARTING", callbackContext.invocationContext(), null, Optional.empty());
538475
});
539476
}
@@ -622,7 +559,9 @@ public Maybe<LlmResponse> beforeModelCallback(
622559
.setModel(req.model().orElse(""))
623560
.setExtraAttributes(attributes)
624561
.build();
625-
traceManager.pushSpan("llm_request");
562+
state
563+
.getTraceManager(callbackContext.invocationContext().invocationId())
564+
.pushSpan("llm_request");
626565
logEvent("LLM_REQUEST", callbackContext.invocationContext(), req, Optional.of(eventData));
627566
});
628567
}
@@ -632,6 +571,8 @@ public Maybe<LlmResponse> afterModelCallback(
632571
CallbackContext callbackContext, LlmResponse llmResponse) {
633572
return Maybe.fromAction(
634573
() -> {
574+
TraceManager traceManager =
575+
state.getTraceManager(callbackContext.invocationContext().invocationId());
635576
// TODO(b/495809488): Add formatting of the content
636577
ParsedContent parsedContent =
637578
JsonFormatter.parse(llmResponse.content().orElse(null), config.maxContentLength());
@@ -728,6 +669,8 @@ public Maybe<LlmResponse> onModelErrorCallback(
728669
CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) {
729670
return Maybe.fromAction(
730671
() -> {
672+
TraceManager traceManager =
673+
state.getTraceManager(callbackContext.invocationContext().invocationId());
731674
InvocationContext invocationContext = callbackContext.invocationContext();
732675
Optional<RecordData> popped = traceManager.popSpan();
733676
String spanId = popped.map(RecordData::spanId).orElse(null);
@@ -762,7 +705,7 @@ public Maybe<Map<String, Object>> beforeToolCallback(
762705
ImmutableMap<String, Object> contentMap =
763706
ImmutableMap.of(
764707
"tool_origin", getToolOrigin(tool), "tool", tool.name(), "args", res.node());
765-
traceManager.pushSpan("tool");
708+
state.getTraceManager(toolContext.invocationContext().invocationId()).pushSpan("tool");
766709
logEvent("TOOL_STARTING", toolContext.invocationContext(), contentMap, Optional.empty());
767710
});
768711
}
@@ -775,6 +718,8 @@ public Maybe<Map<String, Object>> afterToolCallback(
775718
Map<String, Object> result) {
776719
return Maybe.fromAction(
777720
() -> {
721+
TraceManager traceManager =
722+
state.getTraceManager(toolContext.invocationContext().invocationId());
778723
Optional<RecordData> popped = traceManager.popSpan();
779724
TruncationResult truncationResult = smartTruncate(result, config.maxContentLength());
780725
ImmutableMap<String, Object> contentMap =
@@ -812,6 +757,8 @@ public Maybe<Map<String, Object>> onToolErrorCallback(
812757
BaseTool tool, Map<String, Object> toolArgs, ToolContext toolContext, Throwable error) {
813758
return Maybe.fromAction(
814759
() -> {
760+
TraceManager traceManager =
761+
state.getTraceManager(toolContext.invocationContext().invocationId());
815762
Optional<RecordData> popped = traceManager.popSpan();
816763
TruncationResult truncationResult = smartTruncate(toolArgs, config.maxContentLength());
817764
String toolOrigin = getToolOrigin(tool);

0 commit comments

Comments
 (0)