Skip to content

Commit 9031cad

Browse files
google-genai-botcopybara-github
authored andcommitted
fix: Fix ADK Runner race condition for sequential tool execution
Ensure that events are appended to the session and processed sequentially before proceeding to the next step in BaseLlmFlow. PiperOrigin-RevId: 899568665
1 parent 78766c1 commit 9031cad

4 files changed

Lines changed: 184 additions & 27 deletions

File tree

core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -461,14 +461,31 @@ public Flowable<Event> run(InvocationContext invocationContext) {
461461

462462
private Flowable<Event> run(
463463
Context spanContext, InvocationContext invocationContext, int stepsCompleted) {
464-
Flowable<Event> currentStepEvents = runOneStep(spanContext, invocationContext).cache();
464+
Flowable<Event> currentStepEvents = runOneStep(spanContext, invocationContext);
465+
466+
Flowable<Event> processedEvents =
467+
currentStepEvents
468+
.concatMap(
469+
event ->
470+
invocationContext
471+
.sessionService()
472+
.appendEvent(invocationContext.session(), event)
473+
.flatMap(
474+
registeredEvent ->
475+
invocationContext
476+
.pluginManager()
477+
.onEventCallback(invocationContext, registeredEvent)
478+
.defaultIfEmpty(registeredEvent))
479+
.toFlowable())
480+
.cache();
481+
465482
if (stepsCompleted + 1 >= maxSteps) {
466483
logger.debug("Ending flow execution because max steps reached.");
467-
return currentStepEvents;
484+
return processedEvents;
468485
}
469486

470-
return currentStepEvents.concatWith(
471-
currentStepEvents
487+
return processedEvents.concatWith(
488+
processedEvents
472489
.toList()
473490
.flatMapPublisher(
474491
eventList -> {

core/src/main/java/com/google/adk/flows/llmflows/Functions.java

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,10 @@ public final class Functions {
7171
private static final Logger logger = LoggerFactory.getLogger(Functions.class);
7272

7373
/** Generates a unique ID for a function call. */
74-
public static String generateClientFunctionCallId() {
75-
return AF_FUNCTION_CALL_ID_PREFIX + UUID.randomUUID();
74+
public static String generateClientFunctionCallId(FunctionCall functionCall) {
75+
String source =
76+
functionCall.name().orElse("") + functionCall.args().orElse(ImmutableMap.of()).toString();
77+
return AF_FUNCTION_CALL_ID_PREFIX + UUID.nameUUIDFromBytes(source.getBytes()).toString();
7678
}
7779

7880
/**
@@ -101,7 +103,7 @@ public static void populateClientFunctionCallId(Event modelResponseEvent) {
101103
FunctionCall functionCall = part.functionCall().get();
102104
if (functionCall.id().isEmpty() || functionCall.id().get().isEmpty()) {
103105
FunctionCall updatedFunctionCall =
104-
functionCall.toBuilder().id(generateClientFunctionCallId()).build();
106+
functionCall.toBuilder().id(generateClientFunctionCallId(functionCall)).build();
105107
newParts.add(part.toBuilder().functionCall(updatedFunctionCall).build());
106108
modified = true;
107109
} else {
@@ -621,7 +623,7 @@ private static Event buildResponseEvent(
621623
.build();
622624

623625
return Event.builder()
624-
.id(Event.generateEventId())
626+
.id(toolContext.functionCallId().orElseGet(Event::generateEventId))
625627
.invocationId(invocationContext.invocationId())
626628
.author(invocationContext.agent().name())
627629
.branch(invocationContext.branch().orElse(null))
@@ -657,17 +659,17 @@ public static Optional<Event> generateRequestConfirmationEvent(
657659
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue))
658660
.entrySet()) {
659661

660-
FunctionCall requestConfirmationFunctionCall =
662+
FunctionCall.Builder builder =
661663
FunctionCall.builder()
662664
.name(REQUEST_CONFIRMATION_FUNCTION_CALL_NAME)
663665
.args(
664666
ImmutableMap.of(
665667
"originalFunctionCall",
666668
functionCallsById.get(entry.getKey()),
667669
"toolConfirmation",
668-
entry.getValue()))
669-
.id(generateClientFunctionCallId())
670-
.build();
670+
entry.getValue()));
671+
FunctionCall requestConfirmationFunctionCall =
672+
builder.id(generateClientFunctionCallId(builder.build())).build();
671673

672674
longRunningToolIds.add(requestConfirmationFunctionCall.id().get());
673675
parts.add(Part.builder().functionCall(requestConfirmationFunctionCall).build());
@@ -680,8 +682,15 @@ public static Optional<Event> generateRequestConfirmationEvent(
680682
var contentBuilder = Content.builder().parts(parts);
681683
functionResponseEvent.content().flatMap(Content::role).ifPresent(contentBuilder::role);
682684

685+
String deterministicId =
686+
"req-conf-"
687+
+ functionResponseEvent.actions().requestedToolConfirmations().keySet().stream()
688+
.sorted()
689+
.collect(java.util.stream.Collectors.joining("-"));
690+
683691
return Optional.of(
684692
Event.builder()
693+
.id(deterministicId)
685694
.invocationId(invocationContext.invocationId())
686695
.author(invocationContext.agent().name())
687696
.branch(invocationContext.branch().orElse(null))

core/src/main/java/com/google/adk/runner/Runner.java

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,12 @@
6868
import java.util.concurrent.ConcurrentHashMap;
6969
import java.util.concurrent.ConcurrentMap;
7070
import org.jspecify.annotations.Nullable;
71+
import org.slf4j.Logger;
72+
import org.slf4j.LoggerFactory;
7173

7274
/** The main class for the GenAI Agents runner. */
7375
public class Runner {
76+
private static final Logger logger = LoggerFactory.getLogger(Runner.class);
7477
private final BaseAgent agent;
7578
private final String appName;
7679
private final BaseArtifactService artifactService;
@@ -570,19 +573,28 @@ private Flowable<Event> runAgentWithUpdatedSession(
570573
.agent()
571574
.runAsync(contextWithUpdatedSession)
572575
.concatMap(
573-
agentEvent ->
574-
this.sessionService
575-
.appendEvent(updatedSession, agentEvent)
576-
.flatMap(
577-
registeredEvent -> {
578-
// TODO: remove this hack after deprecating runAsync with Session.
579-
copySessionStates(updatedSession, initialContext.session());
580-
return contextWithUpdatedSession
581-
.pluginManager()
582-
.onEventCallback(contextWithUpdatedSession, registeredEvent)
583-
.defaultIfEmpty(registeredEvent);
584-
})
585-
.toFlowable());
576+
agentEvent -> {
577+
// TODO: remove this hack after deprecating runAsync with Session.
578+
copySessionStates(updatedSession, initialContext.session());
579+
580+
// TODO: b/502182243 - Investigate if appendEvent should be made idempotent in
581+
// SessionService to avoid this check.
582+
if (updatedSession.events().stream()
583+
.anyMatch(e -> e.id() != null && e.id().equals(agentEvent.id()))) {
584+
logger.debug("Event {} already in session, skipping append", agentEvent.id());
585+
return io.reactivex.rxjava3.core.Flowable.just(agentEvent);
586+
}
587+
return this.sessionService
588+
.appendEvent(updatedSession, agentEvent)
589+
.flatMap(
590+
registeredEvent -> {
591+
return contextWithUpdatedSession
592+
.pluginManager()
593+
.onEventCallback(contextWithUpdatedSession, registeredEvent)
594+
.defaultIfEmpty(registeredEvent);
595+
})
596+
.toFlowable();
597+
});
586598

587599
// If beforeRunCallback returns content, emit it and skip agent
588600
Context capturedContext = Context.current();

core/src/test/java/com/google/adk/runner/RunnerTest.java

Lines changed: 121 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,12 @@
4646
import com.google.adk.artifacts.BaseArtifactService;
4747
import com.google.adk.events.Event;
4848
import com.google.adk.flows.llmflows.Functions;
49+
import com.google.adk.models.LlmRequest;
4950
import com.google.adk.models.LlmResponse;
5051
import com.google.adk.plugins.BasePlugin;
5152
import com.google.adk.sessions.BaseSessionService;
53+
import com.google.adk.sessions.GetSessionConfig;
54+
import com.google.adk.sessions.InMemorySessionService;
5255
import com.google.adk.sessions.Session;
5356
import com.google.adk.sessions.SessionKey;
5457
import com.google.adk.summarizer.EventsCompactionConfig;
@@ -80,6 +83,7 @@
8083
import java.time.Instant;
8184
import java.util.ArrayList;
8285
import java.util.List;
86+
import java.util.Map;
8387
import java.util.Objects;
8488
import java.util.Optional;
8589
import java.util.UUID;
@@ -588,12 +592,22 @@ public void onToolErrorCallback_error() {
588592
@Test
589593
public void onEventCallback_success() {
590594
when(plugin.onEventCallback(any(), any()))
591-
.thenReturn(Maybe.just(TestUtils.createEvent("form plugin")));
595+
.thenAnswer(
596+
invocation -> {
597+
Event event = invocation.getArgument(1);
598+
return Maybe.just(
599+
Event.builder()
600+
.id(event.id())
601+
.invocationId(event.invocationId())
602+
.author("model")
603+
.content(createContent("from plugin"))
604+
.build());
605+
});
592606

593607
List<Event> events =
594608
runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet();
595609

596-
assertThat(simplifyEvents(events)).containsExactly("author: content for event form plugin");
610+
assertThat(simplifyEvents(events)).containsExactly("model: from plugin");
597611

598612
verify(plugin).onEventCallback(any(), any());
599613
}
@@ -1686,4 +1700,109 @@ public void runner_executesSaveArtifactFlow() {
16861700
// agent was run
16871701
assertThat(simplifyEvents(events.values())).containsExactly("test agent: from llm");
16881702
}
1703+
1704+
@Test
1705+
public void runAsync_ensuresSequentialConsistencyForTools() {
1706+
// Arrange
1707+
TestLlm testLlm =
1708+
createTestLlm(
1709+
createFunctionCallLlmResponse("call_1", "tool1", ImmutableMap.of("arg", "value1")),
1710+
createTextLlmResponse("Final response"));
1711+
1712+
LlmAgent agent =
1713+
createTestAgentBuilder(testLlm)
1714+
.tools(
1715+
ImmutableList.of(
1716+
FunctionTool.create(RaceConditionTools.class, "tool1"),
1717+
FunctionTool.create(RaceConditionTools.class, "tool2")))
1718+
.build();
1719+
1720+
BaseSessionService delegate = new InMemorySessionService();
1721+
BaseSessionService delayedSessionService = createDelayedSessionService(delegate, 0);
1722+
1723+
Runner runner =
1724+
Runner.builder()
1725+
.app(App.builder().name("test").rootAgent(agent).build())
1726+
.sessionService(delayedSessionService)
1727+
.build();
1728+
Session session = runner.sessionService().createSession("test", "user").blockingGet();
1729+
1730+
// Act
1731+
var unused =
1732+
runner
1733+
.runAsync("user", session.id(), Content.fromParts(Part.fromText("start")))
1734+
.toList()
1735+
.blockingGet();
1736+
1737+
// Assert
1738+
ImmutableList<LlmRequest> requests = ImmutableList.copyOf(testLlm.getRequests());
1739+
assertThat(requests).hasSize(2);
1740+
1741+
// Second request should contain the result of tool1
1742+
LlmRequest secondRequest = requests.get(1);
1743+
List<Content> history = secondRequest.contents();
1744+
1745+
boolean foundToolResponse =
1746+
history.stream()
1747+
.flatMap(content -> content.parts().stream().flatMap(List::stream))
1748+
.filter(part -> part.functionResponse().isPresent())
1749+
.map(part -> part.functionResponse().get())
1750+
.anyMatch(
1751+
response ->
1752+
response.name().orElse("").equals("tool1")
1753+
&& response
1754+
.response()
1755+
.map(
1756+
r ->
1757+
java.util.Objects.equals(
1758+
r, ImmutableMap.of("result", "result_value1")))
1759+
.orElse(false));
1760+
1761+
assertThat(foundToolResponse).isTrue();
1762+
}
1763+
1764+
@SuppressWarnings({"unchecked", "deprecation"})
1765+
private static BaseSessionService createDelayedSessionService(
1766+
BaseSessionService delegate, long delayMs) {
1767+
BaseSessionService delayedSessionService = mock(BaseSessionService.class);
1768+
when(delayedSessionService.createSession(anyString(), anyString(), any(Map.class), anyString()))
1769+
.thenAnswer(
1770+
inv ->
1771+
delegate.createSession(
1772+
(String) inv.getArgument(0),
1773+
(String) inv.getArgument(1),
1774+
(Map<String, Object>) inv.getArgument(2),
1775+
(String) inv.getArgument(3)));
1776+
when(delayedSessionService.createSession(anyString(), anyString()))
1777+
.thenAnswer(
1778+
inv ->
1779+
delegate.createSession((String) inv.getArgument(0), (String) inv.getArgument(1)));
1780+
when(delayedSessionService.getSession(anyString(), anyString(), anyString(), any()))
1781+
.thenAnswer(
1782+
inv ->
1783+
delegate.getSession(
1784+
(String) inv.getArgument(0),
1785+
(String) inv.getArgument(1),
1786+
(String) inv.getArgument(2),
1787+
(Optional<GetSessionConfig>) inv.getArgument(3)));
1788+
when(delayedSessionService.appendEvent(any(), any()))
1789+
.thenAnswer(
1790+
inv ->
1791+
delegate
1792+
.appendEvent(inv.getArgument(0), inv.getArgument(1))
1793+
.delay(delayMs, MILLISECONDS));
1794+
return delayedSessionService;
1795+
}
1796+
1797+
public static class RaceConditionTools {
1798+
private RaceConditionTools() {}
1799+
1800+
public static ImmutableMap<String, Object> tool1(String arg) {
1801+
return ImmutableMap.of("result", "result_" + arg);
1802+
}
1803+
1804+
public static ImmutableMap<String, Object> tool2(String input) {
1805+
return ImmutableMap.of("status", "received_" + input);
1806+
}
1807+
}
16891808
}

0 commit comments

Comments
 (0)