Skip to content

Commit 5d8f85b

Browse files
Mateusz Krawieccopybara-github
authored andcommitted
fix: allow using legacy "transferToAgent(agentName)" to maintain backwards compatibility
PiperOrigin-RevId: 840661960
1 parent c642631 commit 5d8f85b

11 files changed

Lines changed: 150 additions & 50 deletions

File tree

core/src/main/java/com/google/adk/flows/llmflows/AgentTransfer.java

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,27 @@ public Single<RequestProcessor.RequestProcessingResult> processRequest(
5454
request.toBuilder()
5555
.appendInstructions(
5656
ImmutableList.of(buildTargetAgentsInstructions(agent, transferTargets)));
57+
58+
// Note: this tool is not exposed to the LLM in GenerateContent request. It is there only to
59+
// serve as a backwards-compatible instance for users who depend on the exact name of
60+
// "transferToAgent".
61+
builder.appendTools(ImmutableList.of(createTransferToAgentTool("legacyTransferToAgent")));
62+
63+
FunctionTool agentTransferTool = createTransferToAgentTool("transferToAgent");
64+
agentTransferTool.processLlmRequest(builder, ToolContext.builder(context).build());
65+
return Single.just(
66+
RequestProcessor.RequestProcessingResult.create(builder.build(), ImmutableList.of()));
67+
}
68+
69+
private FunctionTool createTransferToAgentTool(String methodName) {
5770
Method transferToAgentMethod;
5871
try {
5972
transferToAgentMethod =
60-
AgentTransfer.class.getMethod("transferToAgent", String.class, ToolContext.class);
73+
AgentTransfer.class.getMethod(methodName, String.class, ToolContext.class);
6174
} catch (NoSuchMethodException e) {
6275
throw new IllegalStateException(e);
6376
}
64-
FunctionTool agentTransferTool = FunctionTool.create(transferToAgentMethod);
65-
agentTransferTool.processLlmRequest(builder, ToolContext.builder(context).build());
66-
return Single.just(
67-
RequestProcessor.RequestProcessingResult.create(builder.build(), ImmutableList.of()));
77+
return FunctionTool.create(transferToAgentMethod);
6878
}
6979

7080
/** Builds a string with the target agent’s name and description. */
@@ -159,4 +169,18 @@ public static void transferToAgent(
159169
EventActions eventActions = toolContext.eventActions();
160170
toolContext.setActions(eventActions.toBuilder().transferToAgent(agentName).build());
161171
}
172+
173+
/**
174+
* Backwards compatible transferToAgent that uses camel-case naming instead of the ADK's
175+
* snake_case convention.
176+
*
177+
* <p>It exists only to support users who already use literal "transferToAgent" function call to
178+
* instruct ADK to transfer the question to another agent.
179+
*/
180+
@Schema(name = "transferToAgent")
181+
public static void legacyTransferToAgent(
182+
@Schema(name = "agentName") String agentName,
183+
@Schema(optional = true) ToolContext toolContext) {
184+
transferToAgent(agentName, toolContext);
185+
}
162186
}

core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -231,18 +231,15 @@ private Flowable<LlmResponse> callLlm(
231231
.runOnModelErrorCallback(
232232
new CallbackContext(
233233
context, eventForCallbackUsage.actions()),
234-
llmRequestBuilder,
234+
llmRequest,
235235
exception)
236236
.switchIfEmpty(Single.error(exception))
237237
.toFlowable())
238238
.doOnNext(
239239
llmResp -> {
240240
try (Scope innerScope = llmCallSpan.makeCurrent()) {
241241
Telemetry.traceCallLlm(
242-
context,
243-
eventForCallbackUsage.id(),
244-
llmRequestBuilder.build(),
245-
llmResp);
242+
context, eventForCallbackUsage.id(), llmRequest, llmResp);
246243
}
247244
})
248245
.doOnError(
@@ -272,7 +269,7 @@ private Single<Optional<LlmResponse>> handleBeforeModelCallback(
272269
CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions());
273270

274271
Maybe<LlmResponse> pluginResult =
275-
context.pluginManager().runBeforeModelCallback(callbackContext, llmRequestBuilder);
272+
context.pluginManager().runBeforeModelCallback(callbackContext, llmRequestBuilder.build());
276273

277274
LlmAgent agent = (LlmAgent) context.agent();
278275

core/src/main/java/com/google/adk/plugins/BasePlugin.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,12 +122,11 @@ public Maybe<Content> afterAgentCallback(BaseAgent agent, CallbackContext callba
122122
* Callback executed before a request is sent to the model.
123123
*
124124
* @param callbackContext The context for the current agent call.
125-
* @param llmRequest The mutable request builder, allowing modification of the request before it
126-
* is sent to the model.
125+
* @param llmRequest The prepared request object to be sent to the model.
127126
* @return An optional LlmResponse to trigger an early exit. Returning Empty to proceed normally.
128127
*/
129128
public Maybe<LlmResponse> beforeModelCallback(
130-
CallbackContext callbackContext, LlmRequest.Builder llmRequest) {
129+
CallbackContext callbackContext, LlmRequest llmRequest) {
131130
return Maybe.empty();
132131
}
133132

@@ -148,13 +147,13 @@ public Maybe<LlmResponse> afterModelCallback(
148147
* Callback executed when a model call encounters an error.
149148
*
150149
* @param callbackContext The context for the current agent call.
151-
* @param llmRequest The mutable request builder for the request that failed.
150+
* @param llmRequest The request that was sent to the model.
152151
* @param error The exception that was raised.
153152
* @return An optional LlmResponse to use instead of propagating the error. Returning Empty to
154153
* allow the original error to be raised.
155154
*/
156155
public Maybe<LlmResponse> onModelErrorCallback(
157-
CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) {
156+
CallbackContext callbackContext, LlmRequest llmRequest, Throwable error) {
158157
return Maybe.empty();
159158
}
160159

core/src/main/java/com/google/adk/plugins/LoggingPlugin.java

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -151,15 +151,14 @@ public Maybe<Content> afterAgentCallback(BaseAgent agent, CallbackContext callba
151151

152152
@Override
153153
public Maybe<LlmResponse> beforeModelCallback(
154-
CallbackContext callbackContext, LlmRequest.Builder llmRequest) {
154+
CallbackContext callbackContext, LlmRequest llmRequest) {
155155
return Maybe.fromAction(
156156
() -> {
157-
LlmRequest request = llmRequest.build();
158157
log("🧠 LLM REQUEST");
159-
log(" Model: " + request.model().orElse("default"));
158+
log(" Model: " + llmRequest.model().orElse("default"));
160159
log(" Agent: " + callbackContext.agentName());
161160

162-
request
161+
llmRequest
163162
.getFirstSystemInstruction()
164163
.ifPresent(
165164
sysInstruction -> {
@@ -171,8 +170,8 @@ public Maybe<LlmResponse> beforeModelCallback(
171170
log(" System Instruction: '" + truncatedInstruction + "'");
172171
});
173172

174-
if (!request.tools().isEmpty()) {
175-
String toolNames = String.join(", ", request.tools().keySet());
173+
if (!llmRequest.tools().isEmpty()) {
174+
String toolNames = String.join(", ", llmRequest.tools().keySet());
176175
log(" Available Tools: [" + toolNames + "]");
177176
}
178177
});
@@ -212,7 +211,7 @@ public Maybe<LlmResponse> afterModelCallback(
212211

213212
@Override
214213
public Maybe<LlmResponse> onModelErrorCallback(
215-
CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) {
214+
CallbackContext callbackContext, LlmRequest llmRequest, Throwable error) {
216215
return Maybe.fromAction(
217216
() -> {
218217
log("🧠 LLM ERROR");

core/src/main/java/com/google/adk/plugins/PluginManager.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ public Maybe<Content> runAfterAgentCallback(BaseAgent agent, CallbackContext cal
127127
}
128128

129129
public Maybe<LlmResponse> runBeforeModelCallback(
130-
CallbackContext callbackContext, LlmRequest.Builder llmRequest) {
130+
CallbackContext callbackContext, LlmRequest llmRequest) {
131131
return runMaybeCallbacks(
132132
plugin -> plugin.beforeModelCallback(callbackContext, llmRequest), "beforeModelCallback");
133133
}
@@ -139,7 +139,7 @@ public Maybe<LlmResponse> runAfterModelCallback(
139139
}
140140

141141
public Maybe<LlmResponse> runOnModelErrorCallback(
142-
CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) {
142+
CallbackContext callbackContext, LlmRequest llmRequest, Throwable error) {
143143
return runMaybeCallbacks(
144144
plugin -> plugin.onModelErrorCallback(callbackContext, llmRequest, error),
145145
"onModelErrorCallback");

core/src/test/java/com/google/adk/flows/llmflows/AgentTransferTest.java

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import com.google.adk.agents.RunConfig;
3030
import com.google.adk.agents.SequentialAgent;
3131
import com.google.adk.events.Event;
32+
import com.google.adk.models.LlmRequest;
3233
import com.google.adk.runner.InMemoryRunner;
3334
import com.google.adk.runner.Runner;
3435
import com.google.adk.sessions.Session;
@@ -411,6 +412,85 @@ public void testAutoToLoop() {
411412
assertThat(simplifyEvents(actualEvents)).containsExactly("root_agent: response5");
412413
}
413414

415+
@Test
416+
public void testLegacyTransferToAgent() {
417+
Content transferCallContent =
418+
Content.fromParts(
419+
Part.fromFunctionCall("transferToAgent", ImmutableMap.of("agentName", "sub_agent_1")));
420+
Content response1 = Content.fromParts(Part.fromText("response1"));
421+
Content response2 = Content.fromParts(Part.fromText("response2"));
422+
423+
TestLlm testLlm =
424+
createTestLlm(
425+
Flowable.just(createLlmResponse(transferCallContent)),
426+
Flowable.just(createLlmResponse(response1)),
427+
Flowable.just(createLlmResponse(response2)));
428+
429+
LlmAgent subAgent1 = createTestAgentBuilder(testLlm).name("sub_agent_1").build();
430+
LlmAgent rootAgent =
431+
createTestAgentBuilder(testLlm)
432+
.name("root_agent")
433+
.subAgents(ImmutableList.of(subAgent1))
434+
.build();
435+
InvocationContext invocationContext = createInvocationContext(rootAgent);
436+
437+
Runner runner = getRunnerAndCreateSession(rootAgent, invocationContext.session());
438+
List<Event> actualEvents = runRunner(runner, invocationContext);
439+
440+
assertThat(simplifyEvents(actualEvents))
441+
.containsExactly(
442+
"root_agent: FunctionCall(name=transferToAgent, args={agentName=sub_agent_1})",
443+
"root_agent: FunctionResponse(name=transferToAgent, response={})",
444+
"sub_agent_1: response1")
445+
.inOrder();
446+
447+
actualEvents = runRunner(runner, invocationContext);
448+
449+
assertThat(simplifyEvents(actualEvents)).containsExactly("sub_agent_1: response2");
450+
}
451+
452+
@Test
453+
public void testAgentTransferDoesNotExposeLegacyTransferToAgent() {
454+
Content transferCallContent =
455+
Content.fromParts(
456+
Part.fromFunctionCall("transferToAgent", ImmutableMap.of("agentName", "sub_agent_1")));
457+
Content response1 = Content.fromParts(Part.fromText("response1"));
458+
Content response2 = Content.fromParts(Part.fromText("response2"));
459+
TestLlm testLlm =
460+
createTestLlm(
461+
Flowable.just(createLlmResponse(transferCallContent)),
462+
Flowable.just(createLlmResponse(response1)),
463+
Flowable.just(createLlmResponse(response2)));
464+
LlmAgent subAgent1 = createTestAgentBuilder(testLlm).name("sub_agent_1").build();
465+
LlmAgent rootAgent =
466+
createTestAgentBuilder(testLlm)
467+
.name("root_agent")
468+
.subAgents(ImmutableList.of(subAgent1))
469+
.build();
470+
InvocationContext invocationContext = createInvocationContext(rootAgent);
471+
AgentTransfer processor = new AgentTransfer();
472+
LlmRequest request = LlmRequest.builder().build();
473+
474+
var processed = processor.processRequest(invocationContext, request);
475+
476+
assertThat(processed.blockingGet().updatedRequest().config().get().tools()).isPresent();
477+
assertThat(processed.blockingGet().updatedRequest().config().get().tools().get()).hasSize(1);
478+
assertThat(
479+
processed
480+
.blockingGet()
481+
.updatedRequest()
482+
.config()
483+
.get()
484+
.tools()
485+
.get()
486+
.get(0)
487+
.functionDeclarations()
488+
.get()
489+
.get(0)
490+
.name())
491+
.hasValue("transfer_to_agent");
492+
}
493+
414494
private Runner getRunnerAndCreateSession(LlmAgent agent, Session session) {
415495
Runner runner = new InMemoryRunner(agent, session.appName());
416496
// Ensure the session exists before running the agent.

core/src/test/java/com/google/adk/plugins/BasePluginTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ private static class TestPlugin extends BasePlugin {
4343
private final CallbackContext callbackContext = Mockito.mock(CallbackContext.class);
4444
private final Content content = Content.builder().build();
4545
private final Event event = Mockito.mock(Event.class);
46-
private final LlmRequest.Builder llmRequestBuilder = LlmRequest.builder();
46+
private final LlmRequest llmRequest = LlmRequest.builder().build();
4747
private final LlmResponse llmResponse = LlmResponse.builder().build();
4848
private final ToolContext toolContext = Mockito.mock(ToolContext.class);
4949

@@ -79,7 +79,7 @@ public void afterAgentCallback_returnsEmptyMaybe() {
7979

8080
@Test
8181
public void beforeModelCallback_returnsEmptyMaybe() {
82-
plugin.beforeModelCallback(callbackContext, llmRequestBuilder).test().assertResult();
82+
plugin.beforeModelCallback(callbackContext, llmRequest).test().assertResult();
8383
}
8484

8585
@Test
@@ -90,7 +90,7 @@ public void afterModelCallback_returnsEmptyMaybe() {
9090
@Test
9191
public void onModelErrorCallback_returnsEmptyMaybe() {
9292
plugin
93-
.onModelErrorCallback(callbackContext, llmRequestBuilder, new RuntimeException())
93+
.onModelErrorCallback(callbackContext, llmRequest, new RuntimeException())
9494
.test()
9595
.assertResult();
9696
}

core/src/test/java/com/google/adk/plugins/LoggingPluginTest.java

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ public class LoggingPluginTest {
6969
.actions(EventActions.builder().build())
7070
.longRunningToolIds(Optional.empty())
7171
.build();
72-
private final LlmRequest.Builder llmRequestBuilder =
73-
LlmRequest.builder().model("default").contents(ImmutableList.of());
72+
private final LlmRequest llmRequest =
73+
LlmRequest.builder().model("default").contents(ImmutableList.of()).build();
7474
private final LlmResponse llmResponse = LlmResponse.builder().build();
7575
private final ImmutableMap<String, Object> toolArgs = ImmutableMap.of();
7676
private final ImmutableMap<String, Object> toolResult = ImmutableMap.of();
@@ -175,10 +175,7 @@ public void afterAgentCallback_runsWithoutError() {
175175

176176
@Test
177177
public void beforeModelCallback_runsWithoutError() {
178-
loggingPlugin
179-
.beforeModelCallback(mockCallbackContext, llmRequestBuilder)
180-
.test()
181-
.assertComplete();
178+
loggingPlugin.beforeModelCallback(mockCallbackContext, llmRequest).test().assertComplete();
182179
}
183180

184181
@Test
@@ -187,7 +184,8 @@ public void beforeModelCallback_longSystemInstruction() {
187184
.beforeModelCallback(
188185
mockCallbackContext,
189186
LlmRequest.builder()
190-
.appendInstructions(ImmutableList.of("all work and no play".repeat(1000))))
187+
.appendInstructions(ImmutableList.of("all work and no play".repeat(1000)))
188+
.build())
191189
.test()
192190
.assertComplete();
193191
}
@@ -196,7 +194,8 @@ public void beforeModelCallback_longSystemInstruction() {
196194
public void beforeModelCallback_tools() {
197195
loggingPlugin
198196
.beforeModelCallback(
199-
mockCallbackContext, LlmRequest.builder().appendTools(ImmutableList.of(mockTool)))
197+
mockCallbackContext,
198+
LlmRequest.builder().appendTools(ImmutableList.of(mockTool)).build())
200199
.test()
201200
.assertComplete();
202201
}
@@ -232,7 +231,7 @@ public void afterModelCallback_usageMetadata() {
232231
@Test
233232
public void onModelErrorCallback_runsWithoutError() {
234233
loggingPlugin
235-
.onModelErrorCallback(mockCallbackContext, llmRequestBuilder, throwable)
234+
.onModelErrorCallback(mockCallbackContext, llmRequest, throwable)
236235
.test()
237236
.assertComplete();
238237
}

core/src/test/java/com/google/adk/plugins/PluginManagerTest.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -236,18 +236,18 @@ public void runAfterAgentCallback_singlePlugin() {
236236
@Test
237237
public void runBeforeModelCallback_singlePlugin() {
238238
CallbackContext mockCallbackContext = mock(CallbackContext.class);
239-
LlmRequest.Builder llmRequestBuilder = LlmRequest.builder();
239+
LlmRequest llmRequest = LlmRequest.builder().build();
240240
LlmResponse llmResponse = LlmResponse.builder().build();
241241

242242
when(plugin1.beforeModelCallback(any(), any())).thenReturn(Maybe.just(llmResponse));
243243
pluginManager.registerPlugin(plugin1);
244244

245245
pluginManager
246-
.runBeforeModelCallback(mockCallbackContext, llmRequestBuilder)
246+
.runBeforeModelCallback(mockCallbackContext, llmRequest)
247247
.test()
248248
.assertResult(llmResponse);
249249

250-
verify(plugin1).beforeModelCallback(mockCallbackContext, llmRequestBuilder);
250+
verify(plugin1).beforeModelCallback(mockCallbackContext, llmRequest);
251251
}
252252

253253
@Test
@@ -269,19 +269,19 @@ public void runAfterModelCallback_singlePlugin() {
269269
@Test
270270
public void runOnModelErrorCallback_singlePlugin() {
271271
CallbackContext mockCallbackContext = mock(CallbackContext.class);
272-
LlmRequest.Builder llmRequestBuilder = LlmRequest.builder();
272+
LlmRequest llmRequest = LlmRequest.builder().build();
273273
Throwable mockThrowable = mock(Throwable.class);
274274
LlmResponse llmResponse = LlmResponse.builder().build();
275275

276276
when(plugin1.onModelErrorCallback(any(), any(), any())).thenReturn(Maybe.just(llmResponse));
277277
pluginManager.registerPlugin(plugin1);
278278

279279
pluginManager
280-
.runOnModelErrorCallback(mockCallbackContext, llmRequestBuilder, mockThrowable)
280+
.runOnModelErrorCallback(mockCallbackContext, llmRequest, mockThrowable)
281281
.test()
282282
.assertResult(llmResponse);
283283

284-
verify(plugin1).onModelErrorCallback(mockCallbackContext, llmRequestBuilder, mockThrowable);
284+
verify(plugin1).onModelErrorCallback(mockCallbackContext, llmRequest, mockThrowable);
285285
}
286286

287287
@Test

0 commit comments

Comments
 (0)