Skip to content

Commit 69680bb

Browse files
google-genai-botcopybara-github
authored andcommitted
fix: Fix ADK Runner race condition for sequential tool execution
Ensure that events are appended to the session and processed sequentially before proceeding to the next step in BaseLlmFlow. PiperOrigin-RevId: 899609964
1 parent 9031cad commit 69680bb

4 files changed

Lines changed: 27 additions & 184 deletions

File tree

core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -461,31 +461,14 @@ public Flowable<Event> run(InvocationContext invocationContext) {
461461

462462
private Flowable<Event> run(
463463
Context spanContext, InvocationContext invocationContext, int stepsCompleted) {
464-
Flowable<Event> currentStepEvents = runOneStep(spanContext, invocationContext);
465-
466-
Flowable<Event> processedEvents =
467-
currentStepEvents
468-
.concatMap(
469-
event ->
470-
invocationContext
471-
.sessionService()
472-
.appendEvent(invocationContext.session(), event)
473-
.flatMap(
474-
registeredEvent ->
475-
invocationContext
476-
.pluginManager()
477-
.onEventCallback(invocationContext, registeredEvent)
478-
.defaultIfEmpty(registeredEvent))
479-
.toFlowable())
480-
.cache();
481-
464+
Flowable<Event> currentStepEvents = runOneStep(spanContext, invocationContext).cache();
482465
if (stepsCompleted + 1 >= maxSteps) {
483466
logger.debug("Ending flow execution because max steps reached.");
484-
return processedEvents;
467+
return currentStepEvents;
485468
}
486469

487-
return processedEvents.concatWith(
488-
processedEvents
470+
return currentStepEvents.concatWith(
471+
currentStepEvents
489472
.toList()
490473
.flatMapPublisher(
491474
eventList -> {

core/src/main/java/com/google/adk/flows/llmflows/Functions.java

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,8 @@ public final class Functions {
7171
private static final Logger logger = LoggerFactory.getLogger(Functions.class);
7272

7373
/** Generates a unique ID for a function call. */
74-
public static String generateClientFunctionCallId(FunctionCall functionCall) {
75-
String source =
76-
functionCall.name().orElse("") + functionCall.args().orElse(ImmutableMap.of()).toString();
77-
return AF_FUNCTION_CALL_ID_PREFIX + UUID.nameUUIDFromBytes(source.getBytes()).toString();
74+
public static String generateClientFunctionCallId() {
75+
return AF_FUNCTION_CALL_ID_PREFIX + UUID.randomUUID();
7876
}
7977

8078
/**
@@ -103,7 +101,7 @@ public static void populateClientFunctionCallId(Event modelResponseEvent) {
103101
FunctionCall functionCall = part.functionCall().get();
104102
if (functionCall.id().isEmpty() || functionCall.id().get().isEmpty()) {
105103
FunctionCall updatedFunctionCall =
106-
functionCall.toBuilder().id(generateClientFunctionCallId(functionCall)).build();
104+
functionCall.toBuilder().id(generateClientFunctionCallId()).build();
107105
newParts.add(part.toBuilder().functionCall(updatedFunctionCall).build());
108106
modified = true;
109107
} else {
@@ -623,7 +621,7 @@ private static Event buildResponseEvent(
623621
.build();
624622

625623
return Event.builder()
626-
.id(toolContext.functionCallId().orElseGet(Event::generateEventId))
624+
.id(Event.generateEventId())
627625
.invocationId(invocationContext.invocationId())
628626
.author(invocationContext.agent().name())
629627
.branch(invocationContext.branch().orElse(null))
@@ -659,17 +657,17 @@ public static Optional<Event> generateRequestConfirmationEvent(
659657
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue))
660658
.entrySet()) {
661659

662-
FunctionCall.Builder builder =
660+
FunctionCall requestConfirmationFunctionCall =
663661
FunctionCall.builder()
664662
.name(REQUEST_CONFIRMATION_FUNCTION_CALL_NAME)
665663
.args(
666664
ImmutableMap.of(
667665
"originalFunctionCall",
668666
functionCallsById.get(entry.getKey()),
669667
"toolConfirmation",
670-
entry.getValue()));
671-
FunctionCall requestConfirmationFunctionCall =
672-
builder.id(generateClientFunctionCallId(builder.build())).build();
668+
entry.getValue()))
669+
.id(generateClientFunctionCallId())
670+
.build();
673671

674672
longRunningToolIds.add(requestConfirmationFunctionCall.id().get());
675673
parts.add(Part.builder().functionCall(requestConfirmationFunctionCall).build());
@@ -682,15 +680,8 @@ public static Optional<Event> generateRequestConfirmationEvent(
682680
var contentBuilder = Content.builder().parts(parts);
683681
functionResponseEvent.content().flatMap(Content::role).ifPresent(contentBuilder::role);
684682

685-
String deterministicId =
686-
"req-conf-"
687-
+ functionResponseEvent.actions().requestedToolConfirmations().keySet().stream()
688-
.sorted()
689-
.collect(java.util.stream.Collectors.joining("-"));
690-
691683
return Optional.of(
692684
Event.builder()
693-
.id(deterministicId)
694685
.invocationId(invocationContext.invocationId())
695686
.author(invocationContext.agent().name())
696687
.branch(invocationContext.branch().orElse(null))

core/src/main/java/com/google/adk/runner/Runner.java

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,9 @@
6868
import java.util.concurrent.ConcurrentHashMap;
6969
import java.util.concurrent.ConcurrentMap;
7070
import org.jspecify.annotations.Nullable;
71-
import org.slf4j.Logger;
72-
import org.slf4j.LoggerFactory;
7371

7472
/** The main class for the GenAI Agents runner. */
7573
public class Runner {
76-
private static final Logger logger = LoggerFactory.getLogger(Runner.class);
7774
private final BaseAgent agent;
7875
private final String appName;
7976
private final BaseArtifactService artifactService;
@@ -573,28 +570,19 @@ private Flowable<Event> runAgentWithUpdatedSession(
573570
.agent()
574571
.runAsync(contextWithUpdatedSession)
575572
.concatMap(
576-
agentEvent -> {
577-
// TODO: remove this hack after deprecating runAsync with Session.
578-
copySessionStates(updatedSession, initialContext.session());
579-
580-
// TODO: b/502182243 - Investigate if appendEvent should be made idempotent in
581-
// SessionService to avoid this check.
582-
if (updatedSession.events().stream()
583-
.anyMatch(e -> e.id() != null && e.id().equals(agentEvent.id()))) {
584-
logger.debug("Event {} already in session, skipping append", agentEvent.id());
585-
return io.reactivex.rxjava3.core.Flowable.just(agentEvent);
586-
}
587-
return this.sessionService
588-
.appendEvent(updatedSession, agentEvent)
589-
.flatMap(
590-
registeredEvent -> {
591-
return contextWithUpdatedSession
592-
.pluginManager()
593-
.onEventCallback(contextWithUpdatedSession, registeredEvent)
594-
.defaultIfEmpty(registeredEvent);
595-
})
596-
.toFlowable();
597-
});
573+
agentEvent ->
574+
this.sessionService
575+
.appendEvent(updatedSession, agentEvent)
576+
.flatMap(
577+
registeredEvent -> {
578+
// TODO: remove this hack after deprecating runAsync with Session.
579+
copySessionStates(updatedSession, initialContext.session());
580+
return contextWithUpdatedSession
581+
.pluginManager()
582+
.onEventCallback(contextWithUpdatedSession, registeredEvent)
583+
.defaultIfEmpty(registeredEvent);
584+
})
585+
.toFlowable());
598586

599587
// If beforeRunCallback returns content, emit it and skip agent
600588
Context capturedContext = Context.current();

core/src/test/java/com/google/adk/runner/RunnerTest.java

Lines changed: 2 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,9 @@
4646
import com.google.adk.artifacts.BaseArtifactService;
4747
import com.google.adk.events.Event;
4848
import com.google.adk.flows.llmflows.Functions;
49-
import com.google.adk.models.LlmRequest;
5049
import com.google.adk.models.LlmResponse;
5150
import com.google.adk.plugins.BasePlugin;
5251
import com.google.adk.sessions.BaseSessionService;
53-
import com.google.adk.sessions.GetSessionConfig;
54-
import com.google.adk.sessions.InMemorySessionService;
5552
import com.google.adk.sessions.Session;
5653
import com.google.adk.sessions.SessionKey;
5754
import com.google.adk.summarizer.EventsCompactionConfig;
@@ -83,7 +80,6 @@
8380
import java.time.Instant;
8481
import java.util.ArrayList;
8582
import java.util.List;
86-
import java.util.Map;
8783
import java.util.Objects;
8884
import java.util.Optional;
8985
import java.util.UUID;
@@ -592,22 +588,12 @@ public void onToolErrorCallback_error() {
592588
@Test
593589
public void onEventCallback_success() {
594590
when(plugin.onEventCallback(any(), any()))
595-
.thenAnswer(
596-
invocation -> {
597-
Event event = invocation.getArgument(1);
598-
return Maybe.just(
599-
Event.builder()
600-
.id(event.id())
601-
.invocationId(event.invocationId())
602-
.author("model")
603-
.content(createContent("from plugin"))
604-
.build());
605-
});
591+
.thenReturn(Maybe.just(TestUtils.createEvent("form plugin")));
606592

607593
List<Event> events =
608594
runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet();
609595

610-
assertThat(simplifyEvents(events)).containsExactly("model: from plugin");
596+
assertThat(simplifyEvents(events)).containsExactly("author: content for event form plugin");
611597

612598
verify(plugin).onEventCallback(any(), any());
613599
}
@@ -1700,109 +1686,4 @@ public void runner_executesSaveArtifactFlow() {
17001686
// agent was run
17011687
assertThat(simplifyEvents(events.values())).containsExactly("test agent: from llm");
17021688
}
1703-
1704-
@Test
1705-
public void runAsync_ensuresSequentialConsistencyForTools() {
1706-
// Arrange
1707-
TestLlm testLlm =
1708-
createTestLlm(
1709-
createFunctionCallLlmResponse("call_1", "tool1", ImmutableMap.of("arg", "value1")),
1710-
createTextLlmResponse("Final response"));
1711-
1712-
LlmAgent agent =
1713-
createTestAgentBuilder(testLlm)
1714-
.tools(
1715-
ImmutableList.of(
1716-
FunctionTool.create(RaceConditionTools.class, "tool1"),
1717-
FunctionTool.create(RaceConditionTools.class, "tool2")))
1718-
.build();
1719-
1720-
BaseSessionService delegate = new InMemorySessionService();
1721-
BaseSessionService delayedSessionService = createDelayedSessionService(delegate, 0);
1722-
1723-
Runner runner =
1724-
Runner.builder()
1725-
.app(App.builder().name("test").rootAgent(agent).build())
1726-
.sessionService(delayedSessionService)
1727-
.build();
1728-
Session session = runner.sessionService().createSession("test", "user").blockingGet();
1729-
1730-
// Act
1731-
var unused =
1732-
runner
1733-
.runAsync("user", session.id(), Content.fromParts(Part.fromText("start")))
1734-
.toList()
1735-
.blockingGet();
1736-
1737-
// Assert
1738-
ImmutableList<LlmRequest> requests = ImmutableList.copyOf(testLlm.getRequests());
1739-
assertThat(requests).hasSize(2);
1740-
1741-
// Second request should contain the result of tool1
1742-
LlmRequest secondRequest = requests.get(1);
1743-
List<Content> history = secondRequest.contents();
1744-
1745-
boolean foundToolResponse =
1746-
history.stream()
1747-
.flatMap(content -> content.parts().stream().flatMap(List::stream))
1748-
.filter(part -> part.functionResponse().isPresent())
1749-
.map(part -> part.functionResponse().get())
1750-
.anyMatch(
1751-
response ->
1752-
response.name().orElse("").equals("tool1")
1753-
&& response
1754-
.response()
1755-
.map(
1756-
r ->
1757-
java.util.Objects.equals(
1758-
r, ImmutableMap.of("result", "result_value1")))
1759-
.orElse(false));
1760-
1761-
assertThat(foundToolResponse).isTrue();
1762-
}
1763-
1764-
@SuppressWarnings({"unchecked", "deprecation"})
1765-
private static BaseSessionService createDelayedSessionService(
1766-
BaseSessionService delegate, long delayMs) {
1767-
BaseSessionService delayedSessionService = mock(BaseSessionService.class);
1768-
when(delayedSessionService.createSession(anyString(), anyString(), any(Map.class), anyString()))
1769-
.thenAnswer(
1770-
inv ->
1771-
delegate.createSession(
1772-
(String) inv.getArgument(0),
1773-
(String) inv.getArgument(1),
1774-
(Map<String, Object>) inv.getArgument(2),
1775-
(String) inv.getArgument(3)));
1776-
when(delayedSessionService.createSession(anyString(), anyString()))
1777-
.thenAnswer(
1778-
inv ->
1779-
delegate.createSession((String) inv.getArgument(0), (String) inv.getArgument(1)));
1780-
when(delayedSessionService.getSession(anyString(), anyString(), anyString(), any()))
1781-
.thenAnswer(
1782-
inv ->
1783-
delegate.getSession(
1784-
(String) inv.getArgument(0),
1785-
(String) inv.getArgument(1),
1786-
(String) inv.getArgument(2),
1787-
(Optional<GetSessionConfig>) inv.getArgument(3)));
1788-
when(delayedSessionService.appendEvent(any(), any()))
1789-
.thenAnswer(
1790-
inv ->
1791-
delegate
1792-
.appendEvent(inv.getArgument(0), inv.getArgument(1))
1793-
.delay(delayMs, MILLISECONDS));
1794-
return delayedSessionService;
1795-
}
1796-
1797-
public static class RaceConditionTools {
1798-
private RaceConditionTools() {}
1799-
1800-
public static ImmutableMap<String, Object> tool1(String arg) {
1801-
return ImmutableMap.of("result", "result_" + arg);
1802-
}
1803-
1804-
public static ImmutableMap<String, Object> tool2(String input) {
1805-
return ImmutableMap.of("status", "received_" + input);
1806-
}
1807-
}
18081689
}

0 commit comments

Comments
 (0)