Skip to content

Commit 3091156

Browse files
google-genai-botcopybara-github
authored andcommitted
fix: Fix critical race condition in ADK Runner
PiperOrigin-RevId: 895903521
1 parent d348a2a commit 3091156

2 files changed

Lines changed: 100 additions & 2 deletions

File tree

core/src/main/java/com/google/adk/runner/Runner.java

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
import java.util.ArrayList;
6161
import java.util.Arrays;
6262
import java.util.Collections;
63+
import java.util.HashMap;
6364
import java.util.List;
6465
import java.util.Map;
6566
import java.util.Optional;
@@ -532,12 +533,34 @@ private Flowable<Event> runAgentWithUpdatedSession(
532533
contextWithUpdatedSession
533534
.agent()
534535
.runAsync(contextWithUpdatedSession)
536+
.map(
537+
agentEvent -> {
538+
// We create a temporary shallow copy of the session to pass to the persistence
539+
// service.
540+
// This copy is created BEFORE we add the agentEvent to the in-memory session.
541+
Session sessionForService =
542+
Session.builder(updatedSession.id())
543+
.appName(updatedSession.appName())
544+
.userId(updatedSession.userId())
545+
.state(new HashMap<>(updatedSession.state()))
546+
.events(new ArrayList<>(updatedSession.events()))
547+
.build();
548+
549+
// Unblock the in-memory session synchronously as soon as the event is emitted!
550+
// This allows the agent's internal loop (llmFlow) to see the event immediately
551+
// for its next turn without waiting for previous DB writes to complete.
552+
updatedSession.events().add(agentEvent);
553+
554+
return new EventWithSession(sessionForService, agentEvent);
555+
})
535556
.concatMap(
536-
agentEvent ->
557+
wrapper ->
537558
this.sessionService
538-
.appendEvent(updatedSession, agentEvent)
559+
.appendEvent(wrapper.sessionForService(), wrapper.event())
539560
.flatMap(
540561
registeredEvent -> {
562+
// Sync state changes back from isolated copy to our primary session
563+
copySessionStates(wrapper.sessionForService(), updatedSession);
541564
// TODO: remove this hack after deprecating runAsync with Session.
542565
copySessionStates(updatedSession, initialContext.session());
543566
return contextWithUpdatedSession
@@ -804,5 +827,8 @@ private static EventsCompactionConfig createEventsCompactionConfig(
804827
config.eventRetentionSize());
805828
}
806829

830+
/** A record to wrap the isolated session and the event for sequential persistence. */
831+
private static record EventWithSession(Session sessionForService, Event event) {}
832+
807833
// TODO: run statelessly
808834
}

core/src/test/java/com/google/adk/runner/RunnerTest.java

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import static org.mockito.Mockito.when;
3636

3737
import com.google.adk.agents.BaseAgent;
38+
import com.google.adk.agents.CallbackContext;
3839
import com.google.adk.agents.InvocationContext;
3940
import com.google.adk.agents.LiveRequestQueue;
4041
import com.google.adk.agents.LlmAgent;
@@ -614,6 +615,77 @@ public void callbackContextData_preservedAcrossInvocation() {
614615
assertThat(contextCaptor.getValue().callbackContextData()).containsEntry(testKey, testValue);
615616
}
616617

618+
@Test
619+
public void runAsync_duringMultiTurnExecution_emittedEventsAreVisibleInSubsequentTurn() {
620+
// Setup LLM to return a function call, and then a final response
621+
TestLlm testLlmForRace =
622+
createTestLlm(
623+
createLlmResponse(
624+
Content.builder()
625+
.role("model")
626+
.parts(
627+
Part.builder()
628+
.functionCall(
629+
FunctionCall.builder()
630+
.name(echoTool.name())
631+
.args(ImmutableMap.of("args_name", "args_value"))
632+
.build())
633+
.build())
634+
.build()),
635+
createLlmResponse(createContent("done")));
636+
637+
LlmAgent agentForRace =
638+
createTestAgentBuilder(testLlmForRace).tools(ImmutableList.of(echoTool)).build();
639+
640+
Runner runnerForRace =
641+
Runner.builder()
642+
.app(
643+
App.builder()
644+
.name("test")
645+
.rootAgent(agentForRace)
646+
.plugins(ImmutableList.of(plugin))
647+
.build())
648+
.build();
649+
650+
Session sessionForRace =
651+
runnerForRace.sessionService().createSession("test", "user").blockingGet();
652+
653+
// Use a mock plugin to check session events in beforeModelCallback
654+
// It should be called for the second turn (after the function call)
655+
AtomicInteger callCount = new AtomicInteger(0);
656+
when(plugin.beforeModelCallback(any(), any()))
657+
.thenAnswer(
658+
invocation -> {
659+
CallbackContext context = invocation.getArgument(0);
660+
int count = callCount.incrementAndGet();
661+
if (count == 2) {
662+
// This is the second turn, after the function call
663+
// Check if the session contains the function call event
664+
List<Event> events = context.events();
665+
boolean hasFunctionCall =
666+
events.stream()
667+
.flatMap(
668+
e ->
669+
e
670+
.content()
671+
.flatMap(Content::parts)
672+
.orElse(ImmutableList.of())
673+
.stream())
674+
.anyMatch(p -> p.functionCall().isPresent());
675+
assertThat(hasFunctionCall).isTrue();
676+
}
677+
return Maybe.empty();
678+
});
679+
680+
var unused =
681+
runnerForRace
682+
.runAsync("user", sessionForRace.id(), createContent("start"))
683+
.toList()
684+
.blockingGet();
685+
686+
assertThat(callCount.get()).isEqualTo(2);
687+
}
688+
617689
@Test
618690
public void runAsync_withSessionKey_success() {
619691
var events =

0 commit comments

Comments
 (0)