|
35 | 35 | import static org.mockito.Mockito.when; |
36 | 36 |
|
37 | 37 | import com.google.adk.agents.BaseAgent; |
| 38 | +import com.google.adk.agents.CallbackContext; |
38 | 39 | import com.google.adk.agents.InvocationContext; |
39 | 40 | import com.google.adk.agents.LiveRequestQueue; |
40 | 41 | import com.google.adk.agents.LlmAgent; |
@@ -614,6 +615,77 @@ public void callbackContextData_preservedAcrossInvocation() { |
614 | 615 | assertThat(contextCaptor.getValue().callbackContextData()).containsEntry(testKey, testValue); |
615 | 616 | } |
616 | 617 |
|
| 618 | + @Test |
| 619 | + public void runAsync_duringMultiTurnExecution_emittedEventsAreVisibleInSubsequentTurn() { |
| 620 | + // Setup LLM to return a function call, and then a final response |
| 621 | + TestLlm testLlmForRace = |
| 622 | + createTestLlm( |
| 623 | + createLlmResponse( |
| 624 | + Content.builder() |
| 625 | + .role("model") |
| 626 | + .parts( |
| 627 | + Part.builder() |
| 628 | + .functionCall( |
| 629 | + FunctionCall.builder() |
| 630 | + .name(echoTool.name()) |
| 631 | + .args(ImmutableMap.of("args_name", "args_value")) |
| 632 | + .build()) |
| 633 | + .build()) |
| 634 | + .build()), |
| 635 | + createLlmResponse(createContent("done"))); |
| 636 | + |
| 637 | + LlmAgent agentForRace = |
| 638 | + createTestAgentBuilder(testLlmForRace).tools(ImmutableList.of(echoTool)).build(); |
| 639 | + |
| 640 | + Runner runnerForRace = |
| 641 | + Runner.builder() |
| 642 | + .app( |
| 643 | + App.builder() |
| 644 | + .name("test") |
| 645 | + .rootAgent(agentForRace) |
| 646 | + .plugins(ImmutableList.of(plugin)) |
| 647 | + .build()) |
| 648 | + .build(); |
| 649 | + |
| 650 | + Session sessionForRace = |
| 651 | + runnerForRace.sessionService().createSession("test", "user").blockingGet(); |
| 652 | + |
| 653 | + // Use a mock plugin to check session events in beforeModelCallback |
| 654 | + // It should be called for the second turn (after the function call) |
| 655 | + AtomicInteger callCount = new AtomicInteger(0); |
| 656 | + when(plugin.beforeModelCallback(any(), any())) |
| 657 | + .thenAnswer( |
| 658 | + invocation -> { |
| 659 | + CallbackContext context = invocation.getArgument(0); |
| 660 | + int count = callCount.incrementAndGet(); |
| 661 | + if (count == 2) { |
| 662 | + // This is the second turn, after the function call |
| 663 | + // Check if the session contains the function call event |
| 664 | + List<Event> events = context.events(); |
| 665 | + boolean hasFunctionCall = |
| 666 | + events.stream() |
| 667 | + .flatMap( |
| 668 | + e -> |
| 669 | + e |
| 670 | + .content() |
| 671 | + .flatMap(Content::parts) |
| 672 | + .orElse(ImmutableList.of()) |
| 673 | + .stream()) |
| 674 | + .anyMatch(p -> p.functionCall().isPresent()); |
| 675 | + assertThat(hasFunctionCall).isTrue(); |
| 676 | + } |
| 677 | + return Maybe.empty(); |
| 678 | + }); |
| 679 | + |
| 680 | + var unused = |
| 681 | + runnerForRace |
| 682 | + .runAsync("user", sessionForRace.id(), createContent("start")) |
| 683 | + .toList() |
| 684 | + .blockingGet(); |
| 685 | + |
| 686 | + assertThat(callCount.get()).isEqualTo(2); |
| 687 | + } |
| 688 | + |
617 | 689 | @Test |
618 | 690 | public void runAsync_withSessionKey_success() { |
619 | 691 | var events = |
|
0 commit comments