Skip to content

Commit a364e30

Browse files
google-genai-botcopybara-github
authored andcommitted
test: Adding tests to make sure that tracing works across threads
This testing found an improvement in how span propogation works in Function Calling. I would like to do a significant refactoring to cleanup all of the sprinkled tracing code. This step is necessary to confirm proper behavior before the refactor. PiperOrigin-RevId: 868623517
1 parent 4f2b5de commit a364e30

3 files changed

Lines changed: 360 additions & 118 deletions

File tree

core/src/main/java/com/google/adk/flows/llmflows/Functions.java

Lines changed: 68 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -246,34 +246,41 @@ private static Function<FunctionCall, Maybe<Event>> getFunctionCallMapper(
246246
Map<String, ToolConfirmation> toolConfirmations,
247247
boolean isLive) {
248248
Context parentContext = Context.current();
249-
return functionCall -> {
250-
BaseTool tool = tools.get(functionCall.name().get());
251-
ToolContext toolContext =
252-
ToolContext.builder(invocationContext)
253-
.functionCallId(functionCall.id().orElse(""))
254-
.toolConfirmation(functionCall.id().map(toolConfirmations::get).orElse(null))
255-
.build();
256-
257-
Map<String, Object> functionArgs = functionCall.args().orElse(new HashMap<>());
258-
259-
Maybe<Map<String, Object>> maybeFunctionResult =
260-
maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext)
261-
.switchIfEmpty(
262-
Maybe.defer(
263-
() -> {
264-
try (Scope scope = parentContext.makeCurrent()) {
265-
return isLive
266-
? processFunctionLive(
267-
invocationContext, tool, toolContext, functionCall, functionArgs)
268-
: callTool(tool, functionArgs, toolContext);
269-
}
270-
}));
271-
272-
try (Scope scope = parentContext.makeCurrent()) {
273-
return postProcessFunctionResult(
274-
maybeFunctionResult, invocationContext, tool, functionArgs, toolContext, isLive);
275-
}
276-
};
249+
return functionCall ->
250+
Maybe.defer(
251+
() -> {
252+
try (Scope scope = parentContext.makeCurrent()) {
253+
BaseTool tool = tools.get(functionCall.name().get());
254+
ToolContext toolContext =
255+
ToolContext.builder(invocationContext)
256+
.functionCallId(functionCall.id().orElse(""))
257+
.toolConfirmation(
258+
functionCall.id().map(toolConfirmations::get).orElse(null))
259+
.build();
260+
261+
Map<String, Object> functionArgs = functionCall.args().orElse(new HashMap<>());
262+
263+
Maybe<Map<String, Object>> maybeFunctionResult =
264+
maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext)
265+
.switchIfEmpty(
266+
Maybe.defer(
267+
() -> {
268+
try (Scope innerScope = parentContext.makeCurrent()) {
269+
return isLive
270+
? processFunctionLive(
271+
invocationContext,
272+
tool,
273+
toolContext,
274+
functionCall,
275+
functionArgs)
276+
: callTool(tool, functionArgs, toolContext);
277+
}
278+
}));
279+
280+
return postProcessFunctionResult(
281+
maybeFunctionResult, invocationContext, tool, functionArgs, toolContext);
282+
}
283+
});
277284
}
278285

279286
/**
@@ -376,42 +383,49 @@ private static Maybe<Event> postProcessFunctionResult(
376383
InvocationContext invocationContext,
377384
BaseTool tool,
378385
Map<String, Object> functionArgs,
379-
ToolContext toolContext,
380-
boolean isLive) {
386+
ToolContext toolContext) {
381387
Context parentContext = Context.current();
382388
return maybeFunctionResult
383389
.map(Optional::of)
384390
.defaultIfEmpty(Optional.empty())
385391
.onErrorResumeNext(
386392
t ->
387-
handleOnToolErrorCallback(invocationContext, tool, functionArgs, toolContext, t)
388-
.map(isLive ? Optional::ofNullable : Optional::of)
393+
Maybe.defer(
394+
() -> {
395+
try (Scope scope = parentContext.makeCurrent()) {
396+
return handleOnToolErrorCallback(
397+
invocationContext, tool, functionArgs, toolContext, t);
398+
}
399+
})
400+
.map(Optional::ofNullable)
389401
.switchIfEmpty(Single.error(t)))
390402
.flatMapMaybe(
391403
optionalInitialResult -> {
392-
Map<String, Object> initialFunctionResult = optionalInitialResult.orElse(null);
393-
394-
Maybe<Map<String, Object>> afterToolResultMaybe =
395-
maybeInvokeAfterToolCall(
396-
invocationContext, tool, functionArgs, toolContext, initialFunctionResult);
397-
398-
return afterToolResultMaybe
399-
.map(Optional::of)
400-
.defaultIfEmpty(Optional.ofNullable(initialFunctionResult))
401-
.flatMapMaybe(
402-
finalOptionalResult -> {
403-
try (Scope scope = parentContext.makeCurrent()) {
404-
Map<String, Object> finalFunctionResult =
405-
finalOptionalResult.orElse(null);
406-
if (tool.longRunning() && finalFunctionResult == null) {
407-
return Maybe.empty();
404+
try (Scope scope = parentContext.makeCurrent()) {
405+
Map<String, Object> initialFunctionResult = optionalInitialResult.orElse(null);
406+
407+
Maybe<Map<String, Object>> afterToolResultMaybe =
408+
maybeInvokeAfterToolCall(
409+
invocationContext, tool, functionArgs, toolContext, initialFunctionResult);
410+
411+
return afterToolResultMaybe
412+
.map(Optional::of)
413+
.defaultIfEmpty(Optional.ofNullable(initialFunctionResult))
414+
.flatMapMaybe(
415+
finalOptionalResult -> {
416+
try (Scope innerScope = parentContext.makeCurrent()) {
417+
Map<String, Object> finalFunctionResult =
418+
finalOptionalResult.orElse(null);
419+
if (tool.longRunning() && finalFunctionResult == null) {
420+
return Maybe.empty();
421+
}
422+
Event functionResponseEvent =
423+
buildResponseEvent(
424+
tool, finalFunctionResult, toolContext, invocationContext);
425+
return Maybe.just(functionResponseEvent);
408426
}
409-
Event functionResponseEvent =
410-
buildResponseEvent(
411-
tool, finalFunctionResult, toolContext, invocationContext);
412-
return Maybe.just(functionResponseEvent);
413-
}
414-
});
427+
});
428+
}
415429
});
416430
}
417431

core/src/test/java/com/google/adk/agents/LlmAgentTest.java

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@
2222
import static com.google.adk.testing.TestUtils.createTestAgent;
2323
import static com.google.adk.testing.TestUtils.createTestAgentBuilder;
2424
import static com.google.adk.testing.TestUtils.createTestLlm;
25+
import static com.google.adk.testing.TestUtils.createTextLlmResponse;
2526
import static com.google.common.collect.Iterables.getOnlyElement;
2627
import static com.google.common.truth.Truth.assertThat;
28+
import static org.junit.Assert.assertEquals;
2729
import static org.junit.Assert.assertThrows;
2830

2931
import com.google.adk.agents.Callbacks.AfterModelCallback;
@@ -39,30 +41,53 @@
3941
import com.google.adk.models.Model;
4042
import com.google.adk.sessions.InMemorySessionService;
4143
import com.google.adk.sessions.Session;
44+
import com.google.adk.telemetry.Tracing;
4245
import com.google.adk.testing.TestLlm;
4346
import com.google.adk.testing.TestUtils.EchoTool;
4447
import com.google.adk.tools.BaseTool;
4548
import com.google.adk.tools.BaseToolset;
4649
import com.google.common.collect.ImmutableList;
4750
import com.google.common.collect.ImmutableMap;
51+
import com.google.errorprone.annotations.CanIgnoreReturnValue;
4852
import com.google.genai.types.Content;
4953
import com.google.genai.types.FunctionDeclaration;
5054
import com.google.genai.types.Part;
5155
import com.google.genai.types.Schema;
56+
import io.opentelemetry.api.trace.Span;
57+
import io.opentelemetry.api.trace.Tracer;
58+
import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule;
59+
import io.opentelemetry.sdk.trace.data.SpanData;
5260
import io.reactivex.rxjava3.core.Flowable;
5361
import io.reactivex.rxjava3.core.Maybe;
5462
import io.reactivex.rxjava3.core.Single;
5563
import java.util.List;
5664
import java.util.Optional;
5765
import java.util.concurrent.ConcurrentHashMap;
5866
import java.util.concurrent.atomic.AtomicBoolean;
67+
import org.junit.After;
68+
import org.junit.Before;
69+
import org.junit.Rule;
5970
import org.junit.Test;
6071
import org.junit.runner.RunWith;
6172
import org.junit.runners.JUnit4;
6273

6374
/** Unit tests for {@link LlmAgent}. */
6475
@RunWith(JUnit4.class)
6576
public final class LlmAgentTest {
77+
@Rule public final OpenTelemetryRule openTelemetryRule = OpenTelemetryRule.create();
78+
79+
private Tracer originalTracer;
80+
81+
@Before
82+
public void setup() {
83+
this.originalTracer = Tracing.getTracer();
84+
Tracing.setTracerForTesting(openTelemetryRule.getOpenTelemetry().getTracer("gcp.vertex.agent"));
85+
}
86+
87+
@After
88+
public void tearDown() {
89+
Tracing.setTracerForTesting(originalTracer);
90+
}
6691

6792
private static class ClosableToolset implements BaseToolset {
6893
final AtomicBoolean closed = new AtomicBoolean(false);
@@ -496,4 +521,121 @@ public void close() {
496521
assertThat(toolset1.closed.get()).isTrue();
497522
assertThat(toolset2.closed.get()).isTrue();
498523
}
524+
525+
@Test
526+
public void runAsync_createsInvokeAgentSpan() throws InterruptedException {
527+
Content modelContent = Content.fromParts(Part.fromText("response"));
528+
TestLlm testLlm = createTestLlm(createLlmResponse(modelContent));
529+
LlmAgent agent = createTestAgent(testLlm);
530+
InvocationContext invocationContext = createInvocationContext(agent);
531+
532+
agent.runAsync(invocationContext).test().await().assertComplete();
533+
534+
List<SpanData> spans = openTelemetryRule.getSpans();
535+
assertThat(spans.stream().anyMatch(s -> s.getName().equals("invoke_agent test agent")))
536+
.isTrue();
537+
}
538+
539+
@Test
540+
public void runAsync_withTools_createsToolSpans() throws InterruptedException {
541+
ImmutableMap<String, Object> echoArgs = ImmutableMap.of("arg", "value");
542+
Content contentWithFunctionCall =
543+
Content.fromParts(Part.fromText("text"), Part.fromFunctionCall("echo_tool", echoArgs));
544+
Content finalResponse = Content.fromParts(Part.fromText("finished"));
545+
TestLlm testLlm =
546+
createTestLlm(createLlmResponse(contentWithFunctionCall), createLlmResponse(finalResponse));
547+
LlmAgent agent = createTestAgentBuilder(testLlm).tools(new EchoTool()).build();
548+
InvocationContext invocationContext = createInvocationContext(agent);
549+
550+
agent.runAsync(invocationContext).test().await().assertComplete();
551+
552+
List<SpanData> spans = openTelemetryRule.getSpans();
553+
SpanData agentSpan = findSpanByName(spans, "invoke_agent test agent");
554+
List<SpanData> llmSpans = findSpansByName(spans, "call_llm");
555+
List<SpanData> toolCallSpans = findSpansByName(spans, "tool_call [echo_tool]");
556+
List<SpanData> toolResponseSpans = findSpansByName(spans, "tool_response [echo_tool]");
557+
558+
assertThat(llmSpans).hasSize(2);
559+
assertThat(toolCallSpans).hasSize(1);
560+
assertThat(toolResponseSpans).hasSize(1);
561+
562+
String agentSpanId = agentSpan.getSpanContext().getSpanId();
563+
llmSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId()));
564+
toolCallSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId()));
565+
toolResponseSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId()));
566+
}
567+
568+
@Test
569+
public void runAsync_afterToolCallback_propagatesContext() throws InterruptedException {
570+
ImmutableMap<String, Object> echoArgs = ImmutableMap.of("arg", "value");
571+
Content contentWithFunctionCall =
572+
Content.fromParts(Part.fromText("text"), Part.fromFunctionCall("echo_tool", echoArgs));
573+
Content finalResponse = Content.fromParts(Part.fromText("finished"));
574+
TestLlm testLlm =
575+
createTestLlm(createLlmResponse(contentWithFunctionCall), createLlmResponse(finalResponse));
576+
577+
AfterToolCallback afterToolCallback =
578+
(invCtx, tool, input, toolCtx, response) -> {
579+
// Verify that the OpenTelemetry context is correctly propagated to the callback.
580+
assertThat(Span.current().getSpanContext().isValid()).isTrue();
581+
return Maybe.empty();
582+
};
583+
584+
LlmAgent agent =
585+
createTestAgentBuilder(testLlm)
586+
.tools(new EchoTool())
587+
.afterToolCallback(ImmutableList.of(afterToolCallback))
588+
.build();
589+
InvocationContext invocationContext = createInvocationContext(agent);
590+
591+
agent.runAsync(invocationContext).test().await().assertComplete();
592+
593+
List<SpanData> spans = openTelemetryRule.getSpans();
594+
findSpanByName(spans, "invoke_agent test agent");
595+
}
596+
597+
@Test
598+
public void runAsync_withSubAgents_createsSpans() throws InterruptedException {
599+
LlmAgent subAgent =
600+
createTestAgentBuilder(createTestLlm(createTextLlmResponse("sub response")))
601+
.name("sub-agent")
602+
.build();
603+
604+
// Force a transfer to sub-agent using a callback
605+
AfterModelCallback transferCallback =
606+
(ctx, response) -> {
607+
ctx.eventActions().setTransferToAgent(subAgent.name());
608+
return Maybe.empty();
609+
};
610+
611+
TestLlm testLlm = createTestLlm(createTextLlmResponse("initial"));
612+
LlmAgent agent =
613+
createTestAgentBuilder(testLlm)
614+
.subAgents(subAgent)
615+
.afterModelCallback(ImmutableList.of(transferCallback))
616+
.build();
617+
InvocationContext invocationContext = createInvocationContext(agent);
618+
619+
agent.runAsync(invocationContext).test().await().assertComplete();
620+
621+
List<SpanData> spans = openTelemetryRule.getSpans();
622+
assertThat(spans.stream().anyMatch(s -> s.getName().equals("invoke_agent test agent")))
623+
.isTrue();
624+
assertThat(spans.stream().anyMatch(s -> s.getName().equals("invoke_agent sub-agent"))).isTrue();
625+
626+
List<SpanData> llmSpans = findSpansByName(spans, "call_llm");
627+
assertThat(llmSpans).hasSize(2); // One for main agent, one for sub agent
628+
}
629+
630+
private List<SpanData> findSpansByName(List<SpanData> spans, String name) {
631+
return spans.stream().filter(s -> s.getName().equals(name)).toList();
632+
}
633+
634+
@CanIgnoreReturnValue
635+
private SpanData findSpanByName(List<SpanData> spans, String name) {
636+
return spans.stream()
637+
.filter(s -> s.getName().equals(name))
638+
.findFirst()
639+
.orElseThrow(() -> new AssertionError("Span not found: " + name));
640+
}
499641
}

0 commit comments

Comments
 (0)