Skip to content

Commit b002cc0

Browse files
google-genai-bot authored and copybara-github committed
refactor: Extracting GeminiUtils from Gemini and adding tests
The tests found a bug in sanitizeRequest: calling config.labels(null) throws an NPE. Other than that fix and a tweak to the name of sanitizeRequest, the code is a copy of Gemini.java.
PiperOrigin-RevId: 803476406
1 parent d46673e commit b002cc0

3 files changed

Lines changed: 603 additions & 129 deletions

File tree

core/src/main/java/com/google/adk/models/Gemini.java

Lines changed: 9 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,13 @@
2020
import static com.google.common.collect.ImmutableList.toImmutableList;
2121

2222
import com.google.adk.Version;
23-
import com.google.common.collect.ImmutableList;
2423
import com.google.common.collect.ImmutableMap;
2524
import com.google.common.collect.Iterables;
2625
import com.google.errorprone.annotations.CanIgnoreReturnValue;
2726
import com.google.genai.Client;
2827
import com.google.genai.ResponseStream;
29-
import com.google.genai.types.Blob;
3028
import com.google.genai.types.Candidate;
3129
import com.google.genai.types.Content;
32-
import com.google.genai.types.FileData;
3330
import com.google.genai.types.FinishReason;
3431
import com.google.genai.types.GenerateContentConfig;
3532
import com.google.genai.types.GenerateContentResponse;
@@ -226,60 +223,10 @@ public Gemini build() {
226223
private LlmRequest sanitizeRequest(LlmRequest llmRequest) {
227224
if (apiClient.vertexAI()) {
228225
return llmRequest;
226+
} else {
227+
// Using API key from Google AI Studio to call model doesn't support labels.
228+
return GeminiUtil.sanitizeRequestForGeminiApi(llmRequest);
229229
}
230-
LlmRequest.Builder requestBuilder = llmRequest.toBuilder();
231-
232-
// Using API key from Google AI Studio to call model doesn't support labels.
233-
llmRequest
234-
.config()
235-
.ifPresent(
236-
config -> {
237-
if (config.labels().isPresent()) {
238-
requestBuilder.config(config.toBuilder().labels(null).build());
239-
}
240-
});
241-
242-
if (llmRequest.contents().isEmpty()) {
243-
return requestBuilder.build();
244-
}
245-
246-
// This backend does not support the display_name parameter for file uploads,
247-
// so it must be removed to prevent request failures.
248-
ImmutableList<Content> updatedContents =
249-
llmRequest.contents().stream()
250-
.map(
251-
content -> {
252-
if (content.parts().isEmpty() || content.parts().get().isEmpty()) {
253-
return content;
254-
}
255-
256-
ImmutableList<Part> updatedParts =
257-
content.parts().get().stream()
258-
.map(
259-
part -> {
260-
Part.Builder partBuilder = part.toBuilder();
261-
if (part.inlineData().flatMap(Blob::displayName).isPresent()) {
262-
Blob blob = part.inlineData().get();
263-
Blob.Builder newBlobBuilder = Blob.builder();
264-
blob.data().ifPresent(newBlobBuilder::data);
265-
blob.mimeType().ifPresent(newBlobBuilder::mimeType);
266-
partBuilder.inlineData(newBlobBuilder.build());
267-
}
268-
if (part.fileData().flatMap(FileData::displayName).isPresent()) {
269-
FileData fileData = part.fileData().get();
270-
FileData.Builder newFileDataBuilder = FileData.builder();
271-
fileData.fileUri().ifPresent(newFileDataBuilder::fileUri);
272-
fileData.mimeType().ifPresent(newFileDataBuilder::mimeType);
273-
partBuilder.fileData(newFileDataBuilder.build());
274-
}
275-
return partBuilder.build();
276-
})
277-
.collect(toImmutableList());
278-
279-
return content.toBuilder().parts(updatedParts).build();
280-
})
281-
.collect(toImmutableList());
282-
return requestBuilder.contents(updatedContents).build();
283230
}
284231

285232
@Override
@@ -293,7 +240,7 @@ public Flowable<LlmResponse> generateContent(LlmRequest llmRequest, boolean stre
293240
Stream.concat(contents.stream(), Stream.of(userContent)).collect(toImmutableList());
294241
}
295242

296-
List<Content> finalContents = stripThoughts(contents);
243+
List<Content> finalContents = GeminiUtil.stripThoughts(contents);
297244
GenerateContentConfig config = llmRequest.config().orElse(null);
298245
String effectiveModelName = llmRequest.model().orElse(model());
299246

@@ -320,7 +267,8 @@ public Flowable<LlmResponse> generateContent(LlmRequest llmRequest, boolean stre
320267

321268
List<LlmResponse> responsesToEmit = new ArrayList<>();
322269
LlmResponse currentProcessedLlmResponse = LlmResponse.create(rawResponse);
323-
String currentTextChunk = getTextFromLlmResponse(currentProcessedLlmResponse);
270+
String currentTextChunk =
271+
GeminiUtil.getTextFromLlmResponse(currentProcessedLlmResponse);
324272

325273
if (!currentTextChunk.isEmpty()) {
326274
accumulatedText.append(currentTextChunk);
@@ -329,17 +277,13 @@ public Flowable<LlmResponse> generateContent(LlmRequest llmRequest, boolean stre
329277
responsesToEmit.add(partialResponse);
330278
} else {
331279
if (accumulatedText.length() > 0
332-
&& shouldEmitAccumulatedText(currentProcessedLlmResponse)) {
280+
&& GeminiUtil.shouldEmitAccumulatedText(currentProcessedLlmResponse)) {
333281
LlmResponse aggregatedTextResponse =
334282
LlmResponse.builder()
335283
.content(
336284
Content.builder()
337285
.role("model")
338-
.parts(
339-
ImmutableList.of(
340-
Part.builder()
341-
.text(accumulatedText.toString())
342-
.build()))
286+
.parts(Part.fromText(accumulatedText.toString()))
343287
.build())
344288
.build();
345289
responsesToEmit.add(aggregatedTextResponse);
@@ -376,11 +320,7 @@ && shouldEmitAccumulatedText(currentProcessedLlmResponse)) {
376320
.content(
377321
Content.builder()
378322
.role("model")
379-
.parts(
380-
ImmutableList.of(
381-
Part.builder()
382-
.text(accumulatedText.toString())
383-
.build()))
323+
.parts(Part.fromText(accumulatedText.toString()))
384324
.build())
385325
.build();
386326
return Flowable.just(finalAggregatedTextResponse);
@@ -400,52 +340,6 @@ && shouldEmitAccumulatedText(currentProcessedLlmResponse)) {
400340
}
401341
}
402342

403-
/**
404-
* Extracts text content from the first part of an LlmResponse, if available.
405-
*
406-
* @param llmResponse The LlmResponse to extract text from.
407-
* @return The text content, or an empty string if not found.
408-
*/
409-
private String getTextFromLlmResponse(LlmResponse llmResponse) {
410-
return llmResponse
411-
.content()
412-
.flatMap(Content::parts)
413-
.filter(parts -> !parts.isEmpty())
414-
.map(parts -> parts.get(0))
415-
.flatMap(Part::text)
416-
.orElse("");
417-
}
418-
419-
/**
420-
* Determines if accumulated text should be emitted based on the current LlmResponse. We flush if
421-
* current response is not a text continuation (e.g., no content, no parts, or the first part is
422-
* not inline_data, meaning it's something else or just empty, thereby warranting a flush of
423-
* preceding text).
424-
*
425-
* @param currentLlmResponse The current LlmResponse being processed.
426-
* @return True if accumulated text should be emitted, false otherwise.
427-
*/
428-
private boolean shouldEmitAccumulatedText(LlmResponse currentLlmResponse) {
429-
Optional<Content> contentOpt = currentLlmResponse.content();
430-
if (contentOpt.isEmpty()) {
431-
return true;
432-
}
433-
434-
Optional<List<Part>> partsOpt = contentOpt.get().parts();
435-
if (partsOpt.isEmpty() || partsOpt.get().isEmpty()) {
436-
return true;
437-
}
438-
439-
// If content and parts are present, and parts list is not empty, we want to yield accumulated
440-
// text only if `text` is present AND (`not llm_response.content` OR `not
441-
// llm_response.content.parts` OR `not llm_response.content.parts[0].inline_data`)
442-
// This means we flush if the first part does NOT have inline_data.
443-
// If it *has* inline_data, the condition below is false,
444-
// and we would not flush based on this specific sub-condition.
445-
Part firstPart = partsOpt.get().get(0);
446-
return firstPart.inlineData().isEmpty();
447-
}
448-
449343
@Override
450344
public BaseLlmConnection connect(LlmRequest llmRequest) {
451345
llmRequest = sanitizeRequest(llmRequest);
@@ -458,18 +352,4 @@ public BaseLlmConnection connect(LlmRequest llmRequest) {
458352

459353
return new GeminiLlmConnection(apiClient, effectiveModelName, liveConnectConfig);
460354
}
461-
462-
/** Removes any `Part` that contains only a `thought` from the content list. */
463-
private List<Content> stripThoughts(List<Content> originalContents) {
464-
List<Content> updatedContents = new ArrayList<>();
465-
for (Content content : originalContents) {
466-
ImmutableList<Part> nonThoughtParts =
467-
content.parts().orElse(ImmutableList.of()).stream()
468-
// Keep if thought is not present OR if thought is present but false
469-
.filter(part -> part.thought().map(isThought -> !isThought).orElse(true))
470-
.collect(toImmutableList());
471-
updatedContents.add(content.toBuilder().parts(nonThoughtParts).build());
472-
}
473-
return updatedContents;
474-
}
475355
}
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package com.google.adk.models;
17+
18+
import static com.google.common.collect.ImmutableList.toImmutableList;
19+
20+
import com.google.common.collect.ImmutableList;
21+
import com.google.common.collect.ImmutableMap;
22+
import com.google.genai.types.Blob;
23+
import com.google.genai.types.Content;
24+
import com.google.genai.types.FileData;
25+
import com.google.genai.types.Part;
26+
import java.util.ArrayList;
27+
import java.util.List;
28+
import java.util.Optional;
29+
30+
/** Request / Response utilities for {@link Gemini}. */
31+
public final class GeminiUtil {
32+
33+
private GeminiUtil() {}
34+
35+
/**
36+
* Sanitizes the request to ensure it is compatible with the Gemini API backend. Required as there
37+
* are some parameters that if included in the request will raise a runtime error if sent to the
38+
* wrong backend (e.g. image names only work on Vertex AI).
39+
*
40+
* @param llmRequest The request to sanitize.
41+
* @return The sanitized request.
42+
*/
43+
public static LlmRequest sanitizeRequestForGeminiApi(LlmRequest llmRequest) {
44+
LlmRequest.Builder requestBuilder = llmRequest.toBuilder();
45+
llmRequest
46+
.config()
47+
.ifPresent(
48+
config -> {
49+
if (config.labels().isPresent()) {
50+
requestBuilder.config(config.toBuilder().labels(ImmutableMap.of()).build());
51+
}
52+
});
53+
54+
if (llmRequest.contents().isEmpty()) {
55+
return requestBuilder.build();
56+
}
57+
58+
// This backend does not support the display_name parameter for file uploads,
59+
// so it must be removed to prevent request failures.
60+
ImmutableList<Content> updatedContents =
61+
llmRequest.contents().stream()
62+
.map(
63+
content -> {
64+
if (content.parts().isEmpty() || content.parts().get().isEmpty()) {
65+
return content;
66+
}
67+
68+
ImmutableList<Part> updatedParts =
69+
content.parts().get().stream()
70+
.map(
71+
part -> {
72+
Part.Builder partBuilder = part.toBuilder();
73+
if (part.inlineData().flatMap(Blob::displayName).isPresent()) {
74+
Blob blob = part.inlineData().get();
75+
Blob.Builder newBlobBuilder = Blob.builder();
76+
blob.data().ifPresent(newBlobBuilder::data);
77+
blob.mimeType().ifPresent(newBlobBuilder::mimeType);
78+
partBuilder.inlineData(newBlobBuilder.build());
79+
}
80+
if (part.fileData().flatMap(FileData::displayName).isPresent()) {
81+
FileData fileData = part.fileData().get();
82+
FileData.Builder newFileDataBuilder = FileData.builder();
83+
fileData.fileUri().ifPresent(newFileDataBuilder::fileUri);
84+
fileData.mimeType().ifPresent(newFileDataBuilder::mimeType);
85+
partBuilder.fileData(newFileDataBuilder.build());
86+
}
87+
return partBuilder.build();
88+
})
89+
.collect(toImmutableList());
90+
91+
return content.toBuilder().parts(updatedParts).build();
92+
})
93+
.collect(toImmutableList());
94+
return requestBuilder.contents(updatedContents).build();
95+
}
96+
97+
/**
98+
* Extracts text content from the first part of an LlmResponse, if available.
99+
*
100+
* @param llmResponse The LlmResponse to extract text from.
101+
* @return The text content, or an empty string if not found.
102+
*/
103+
public static String getTextFromLlmResponse(LlmResponse llmResponse) {
104+
return llmResponse
105+
.content()
106+
.flatMap(Content::parts)
107+
.filter(parts -> !parts.isEmpty())
108+
.map(parts -> parts.get(0))
109+
.flatMap(Part::text)
110+
.orElse("");
111+
}
112+
113+
/**
114+
* Determines if accumulated text should be emitted based on the current LlmResponse. We flush if
115+
* current response is not a text continuation (e.g., no content, no parts, or the first part is
116+
* not inline_data, meaning it's something else or just empty, thereby warranting a flush of
117+
* preceding text).
118+
*
119+
* @param currentLlmResponse The current LlmResponse being processed.
120+
* @return True if accumulated text should be emitted, false otherwise.
121+
*/
122+
public static boolean shouldEmitAccumulatedText(LlmResponse currentLlmResponse) {
123+
Optional<Content> contentOpt = currentLlmResponse.content();
124+
if (contentOpt.isEmpty()) {
125+
return true;
126+
}
127+
128+
Optional<List<Part>> partsOpt = contentOpt.get().parts();
129+
if (partsOpt.isEmpty() || partsOpt.get().isEmpty()) {
130+
return true;
131+
}
132+
133+
// If content and parts are present, and parts list is not empty, we want to yield accumulated
134+
// text only if `text` is present AND (`not llm_response.content` OR `not
135+
// llm_response.content.parts` OR `not llm_response.content.parts[0].inline_data`)
136+
// This means we flush if the first part does NOT have inline_data.
137+
// If it *has* inline_data, the condition below is false,
138+
// and we would not flush based on this specific sub-condition.
139+
Part firstPart = partsOpt.get().get(0);
140+
return firstPart.inlineData().isEmpty();
141+
}
142+
143+
/** Removes any `Part` that contains only a `thought` from the content list. */
144+
public static List<Content> stripThoughts(List<Content> originalContents) {
145+
List<Content> updatedContents = new ArrayList<>();
146+
for (Content content : originalContents) {
147+
ImmutableList<Part> nonThoughtParts =
148+
content.parts().orElse(ImmutableList.of()).stream()
149+
// Keep if thought is not present OR if thought is present but false
150+
.filter(part -> part.thought().map(isThought -> !isThought).orElse(true))
151+
.collect(toImmutableList());
152+
updatedContents.add(content.toBuilder().parts(nonThoughtParts).build());
153+
}
154+
return updatedContents;
155+
}
156+
}

0 commit comments

Comments
 (0)