Skip to content

Commit 3a1cb22

Browse files
jaycee-licopybara-github
authored andcommitted
fix: ResponseStream fails to parse error message and multi-line SSE data payloads
PiperOrigin-RevId: 900832738
1 parent 5cc9f78 commit 3a1cb22

6 files changed

Lines changed: 258 additions & 28 deletions

File tree

src/main/java/com/google/genai/ReplayApiClient.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,13 @@ public ApiResponse request(
135135
matchRequest(
136136
currentInteraction.request().orElse(null),
137137
buildRequest(httpMethod, path, requestJson, httpOptions));
138-
return buildResponseFromReplay(currentInteraction.response().orElse(null));
138+
boolean isStream =
139+
currentInteraction
140+
.request()
141+
.flatMap(r -> r.url())
142+
.map(u -> u.contains("streamGenerateContent"))
143+
.orElse(false);
144+
return buildResponseFromReplay(currentInteraction.response().orElse(null), isStream);
139145
} else {
140146
// Note that if the client mode is "api", then the ReplayApiClient will not be used.
141147
throw new IllegalArgumentException("Invalid client mode: " + this.clientMode);
@@ -227,15 +233,16 @@ private void matchRequest(ReplayRequest replayRequest, Request actualRequest) {
227233
}
228234

229235
/** Builds the response from a {@link ReplayResponse}. */
230-
private ReplayApiResponse buildResponseFromReplay(ReplayResponse replayResponse) {
236+
private ReplayApiResponse buildResponseFromReplay(
237+
ReplayResponse replayResponse, boolean isStream) {
231238
if (replayResponse == null) {
232239
throw new IllegalArgumentException("Replay response is null.");
233240
}
234241
JsonNode bodyNode =
235242
JsonSerializable.toJsonNode(replayResponse.bodySegments().orElse(new ArrayList<>()));
236243
Headers headers = Headers.of(replayResponse.headers().orElse(ImmutableMap.of()));
237244
return new ReplayApiResponse(
238-
(ArrayNode) bodyNode, replayResponse.statusCode().orElse(0), headers);
245+
(ArrayNode) bodyNode, replayResponse.statusCode().orElse(0), headers, isStream);
239246
}
240247

241248
private static String formatUrl(String url) {

src/main/java/com/google/genai/ReplayApiResponse.java

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,32 +38,35 @@ public final class ReplayApiResponse extends ApiResponse {
3838
private final Headers headers;
3939
private final ArrayNode bodySegments;
4040

41-
public ReplayApiResponse(ArrayNode bodySegments, int statusCode, Headers headers) {
41+
public ReplayApiResponse(
42+
ArrayNode bodySegments, int statusCode, Headers headers, boolean isStream) {
4243
this.bodySegments = bodySegments;
4344
this.statusCode = statusCode;
4445
this.headers = headers;
4546
if (bodySegments.size() == 0) {
4647
this.body = ResponseBody.create(MediaType.parse("application/json"), "");
47-
} else if (bodySegments.size() == 1) {
48-
// For unary response
49-
this.body =
50-
ResponseBody.create(
51-
JsonSerializable.toJsonString(bodySegments.get(0)),
52-
MediaType.parse("application/json"));
53-
} else {
48+
} else if (isStream || bodySegments.size() > 1) {
5449
// For streaming response
5550
try {
5651
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
57-
byte[] newline = "\n".getBytes(StandardCharsets.UTF_8);
52+
byte[] dataPrefix = "data: ".getBytes(StandardCharsets.UTF_8);
53+
byte[] doubleNewline = "\n\n".getBytes(StandardCharsets.UTF_8);
5854
for (JsonNode segment : bodySegments) {
55+
outputStream.write(dataPrefix);
5956
outputStream.write(JsonSerializable.objectMapper.writeValueAsBytes(segment));
60-
outputStream.write(newline);
57+
outputStream.write(doubleNewline);
6158
}
6259
this.body =
6360
ResponseBody.create(outputStream.toByteArray(), MediaType.parse("application/json"));
6461
} catch (IOException e) {
6562
throw new GenAiIOException("Failed to convert body segments to a JSON string.", e);
6663
}
64+
} else {
65+
// For unary response
66+
this.body =
67+
ResponseBody.create(
68+
JsonSerializable.toJsonString(bodySegments.get(0)),
69+
MediaType.parse("application/json"));
6770
}
6871
}
6972

src/main/java/com/google/genai/ResponseStream.java

Lines changed: 50 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2025 Google LLC
2+
* Copyright 2026 Google LLC
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,7 +17,9 @@
1717
package com.google.genai;
1818

1919
import com.fasterxml.jackson.databind.JsonNode;
20+
import com.fasterxml.jackson.databind.node.ArrayNode;
2021
import com.fasterxml.jackson.databind.node.ObjectNode;
22+
import com.google.genai.errors.ApiException;
2123
import com.google.genai.errors.GenAiIOException;
2224
import java.io.BufferedReader;
2325
import java.io.IOException;
@@ -32,6 +34,7 @@
3234
import java.util.NoSuchElementException;
3335
import java.util.logging.Logger;
3436
import okhttp3.Headers;
37+
import org.jspecify.annotations.Nullable;
3538

3639
/** An iterable of datatype objects. */
3740
public class ResponseStream<T extends JsonSerializable> implements Iterable<T>, AutoCloseable {
@@ -114,6 +117,18 @@ public T next() {
114117
nextJson = readNextJson();
115118
try {
116119
JsonNode currentJsonNode = JsonSerializable.stringToJsonNode(currentJson);
120+
121+
if (currentJsonNode.isObject() && currentJsonNode.has("error")) {
122+
int extractedCode = 500;
123+
JsonNode errorNode = currentJsonNode.get("error");
124+
if (errorNode.has("code") && errorNode.get("code").isInt()) {
125+
extractedCode = errorNode.get("code").asInt();
126+
}
127+
ArrayNode arrayNode = JsonSerializable.objectMapper.createArrayNode();
128+
arrayNode.add(currentJsonNode);
129+
ApiException.throwFromErrorNode(arrayNode, extractedCode);
130+
}
131+
117132
if (responseHeaders != null && currentJsonNode.isObject()) {
118133
ObjectNode rootNode = (ObjectNode) currentJsonNode;
119134
ObjectNode headersNode = JsonSerializable.objectMapper.createObjectNode();
@@ -142,23 +157,47 @@ public T next() {
142157
}
143158
}
144159

145-
private String readNextJson() {
160+
private @Nullable String readNextJson() {
146161
// Streaming API returns in the following format:
147162
// data: {contents: ...}
148163
// \n
149164
// data: {contents: ...}
150165
// \n
151166
// ...
167+
List<String> dataBuffer = new ArrayList<>();
152168
try {
153-
String line = reader.readLine();
154-
if (line == null) {
155-
return null;
156-
} else if (line.length() == 0) {
157-
return readNextJson();
158-
} else if (line.startsWith("data: ")) {
159-
return line.substring("data: ".length());
160-
} else {
161-
return line;
169+
while (true) {
170+
String line = reader.readLine();
171+
if (line == null) {
172+
if (!dataBuffer.isEmpty()) {
173+
return String.join("\n", dataBuffer);
174+
}
175+
return null;
176+
}
177+
if (line.isEmpty()) {
178+
if (!dataBuffer.isEmpty()) {
179+
// Handle multi-line SSE data
180+
return String.join("\n", dataBuffer);
181+
}
182+
continue;
183+
}
184+
if (line.startsWith(":")) {
185+
continue;
186+
}
187+
int colonIndex = line.indexOf(':');
188+
String fieldname = line;
189+
String value = "";
190+
if (colonIndex != -1) {
191+
fieldname = line.substring(0, colonIndex);
192+
value = line.substring(colonIndex + 1);
193+
if (value.startsWith(" ")) {
194+
value = value.substring(1);
195+
}
196+
}
197+
198+
if (fieldname.equals("data")) {
199+
dataBuffer.add(value);
200+
}
162201
}
163202
} catch (IOException e) {
164203
throw new GenAiIOException("Failed to read next JSON object from the stream", e);

src/test/java/com/google/genai/AsyncChatTest.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,16 @@ public class AsyncChatTest {
9696
String jsonChunk3 = responseChunk3.toJson();
9797

9898
String streamData =
99-
"data: " + jsonChunk1 + "\n" + "data: " + jsonChunk2 + "\n" + "data: " + jsonChunk3 + "\n";
100-
String streamData2 = "data: " + jsonChunk1 + "\n" + "data: " + jsonChunk2 + "\n";
99+
"data: "
100+
+ jsonChunk1
101+
+ "\n\n"
102+
+ "data: "
103+
+ jsonChunk2
104+
+ "\n\n"
105+
+ "data: "
106+
+ jsonChunk3
107+
+ "\n\n";
108+
String streamData2 = "data: " + jsonChunk1 + "\n\n" + "data: " + jsonChunk2 + "\n\n";
101109

102110
GenerateContentResponse nonStreamingResponse =
103111
GenerateContentResponse.builder()

src/test/java/com/google/genai/ChatTest.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,16 @@ public static String findTheaters(String movie, String location, String time) {
106106
String jsonChunk3 = responseChunk3.toJson();
107107

108108
String streamData =
109-
"data: " + jsonChunk1 + "\n" + "data: " + jsonChunk2 + "\n" + "data: " + jsonChunk3 + "\n";
110-
String streamData2 = "data: " + jsonChunk1 + "\n" + "data: " + jsonChunk2 + "\n";
109+
"data: "
110+
+ jsonChunk1
111+
+ "\n\n"
112+
+ "data: "
113+
+ jsonChunk2
114+
+ "\n\n"
115+
+ "data: "
116+
+ jsonChunk3
117+
+ "\n\n";
118+
String streamData2 = "data: " + jsonChunk1 + "\n\n" + "data: " + jsonChunk2 + "\n\n";
111119

112120
GenerateContentResponse nonStreamingResponse =
113121
GenerateContentResponse.builder()
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
/*
2+
* Copyright 2026 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.genai;
18+
19+
import static org.junit.jupiter.api.Assertions.assertEquals;
20+
import static org.junit.jupiter.api.Assertions.assertTrue;
21+
22+
import com.fasterxml.jackson.databind.JsonNode;
23+
import com.fasterxml.jackson.databind.node.ObjectNode;
24+
import com.google.genai.types.Candidate;
25+
import com.google.genai.types.GenerateContentResponse;
26+
import java.nio.charset.StandardCharsets;
27+
import java.util.Iterator;
28+
import okhttp3.Headers;
29+
import okhttp3.MediaType;
30+
import okhttp3.ResponseBody;
31+
import org.junit.jupiter.api.Test;
32+
33+
public final class ResponseStreamTest {
34+
35+
public static class DummyConverter {
36+
public JsonNode convert(JsonNode fromObject, ObjectNode parentObject) {
37+
return fromObject;
38+
}
39+
}
40+
41+
@Test
42+
public void testMultiLineSseParsing() throws Exception {
43+
String sseData =
44+
"data: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"line1\\n\"}]}},\n"
45+
+ "data: {\"content\": {\"parts\": [{\"text\": \"line2\"}]}}]}\n"
46+
+ "\n"; // End of event
47+
48+
ResponseBody body =
49+
ResponseBody.create(
50+
sseData.getBytes(StandardCharsets.UTF_8), MediaType.parse("text/event-stream"));
51+
FakeApiResponse response = new FakeApiResponse(Headers.of(), body);
52+
53+
DummyConverter converter = new DummyConverter();
54+
ResponseStream<GenerateContentResponse> responseStream =
55+
new ResponseStream<>(GenerateContentResponse.class, response, converter, "convert");
56+
57+
Iterator<GenerateContentResponse> iterator = responseStream.iterator();
58+
59+
assertTrue(iterator.hasNext());
60+
GenerateContentResponse response1 = iterator.next();
61+
62+
assertTrue(response1.candidates().isPresent());
63+
assertEquals(2, response1.candidates().get().size());
64+
65+
Candidate c1 = response1.candidates().get().get(0);
66+
assertEquals("line1\n", c1.content().get().text());
67+
68+
Candidate c2 = response1.candidates().get().get(1);
69+
assertEquals("line2", c2.content().get().text());
70+
71+
assertTrue(!iterator.hasNext());
72+
}
73+
74+
@Test
75+
public void testIgnoreNonSseLines() throws Exception {
76+
String sseData =
77+
"data: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"valid data\"}]}}]}\n"
78+
+ ": some comment line\n"
79+
+ "ignored field: some value\n"
80+
+ "\n"; // End of event
81+
82+
ResponseBody body =
83+
ResponseBody.create(
84+
sseData.getBytes(StandardCharsets.UTF_8), MediaType.parse("text/event-stream"));
85+
FakeApiResponse response = new FakeApiResponse(Headers.of(), body);
86+
87+
DummyConverter converter = new DummyConverter();
88+
ResponseStream<GenerateContentResponse> responseStream =
89+
new ResponseStream<>(GenerateContentResponse.class, response, converter, "convert");
90+
91+
Iterator<GenerateContentResponse> iterator = responseStream.iterator();
92+
93+
assertTrue(iterator.hasNext());
94+
GenerateContentResponse response1 = iterator.next();
95+
96+
assertEquals("valid data", response1.text());
97+
98+
assertTrue(!iterator.hasNext());
99+
}
100+
101+
@Test
102+
public void testStreamErrorHandling() throws Exception {
103+
String sseData =
104+
"data: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"valid data\"}]}}]}\n"
105+
+ "\n"
106+
+ "data: {\"error\": {\"code\": 429, \"message\": \"Quota exceeded\", \"status\": \"RESOURCE_EXHAUSTED\"}}\n"
107+
+ "\n";
108+
109+
ResponseBody body =
110+
ResponseBody.create(
111+
sseData.getBytes(StandardCharsets.UTF_8), MediaType.parse("text/event-stream"));
112+
FakeApiResponse response = new FakeApiResponse(Headers.of(), body);
113+
114+
DummyConverter converter = new DummyConverter();
115+
ResponseStream<GenerateContentResponse> responseStream =
116+
new ResponseStream<>(GenerateContentResponse.class, response, converter, "convert");
117+
118+
Iterator<GenerateContentResponse> iterator = responseStream.iterator();
119+
120+
assertTrue(iterator.hasNext());
121+
GenerateContentResponse response1 = iterator.next();
122+
assertEquals("valid data", response1.text());
123+
124+
assertTrue(iterator.hasNext());
125+
126+
try {
127+
iterator.next();
128+
org.junit.jupiter.api.Assertions.fail("Expected ApiException was not thrown");
129+
} catch (com.google.genai.errors.ApiException e) {
130+
assertEquals(429, e.code());
131+
assertEquals("RESOURCE_EXHAUSTED", e.status());
132+
assertTrue(e.getMessage().contains("Quota exceeded"));
133+
}
134+
}
135+
136+
@Test
137+
public void testMultipleValidEvents() throws Exception {
138+
String sseData =
139+
"data: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"chunk1\"}]}}]}\n"
140+
+ "\n"
141+
+ "data: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"chunk2\"}]}}]}\n"
142+
+ "\n";
143+
144+
ResponseBody body =
145+
ResponseBody.create(
146+
sseData.getBytes(StandardCharsets.UTF_8), MediaType.parse("text/event-stream"));
147+
FakeApiResponse response = new FakeApiResponse(Headers.of(), body);
148+
149+
DummyConverter converter = new DummyConverter();
150+
ResponseStream<GenerateContentResponse> responseStream =
151+
new ResponseStream<>(GenerateContentResponse.class, response, converter, "convert");
152+
153+
Iterator<GenerateContentResponse> iterator = responseStream.iterator();
154+
155+
assertTrue(iterator.hasNext());
156+
GenerateContentResponse response1 = iterator.next();
157+
assertEquals("chunk1", response1.text());
158+
159+
assertTrue(iterator.hasNext());
160+
GenerateContentResponse response2 = iterator.next();
161+
assertEquals("chunk2", response2.text());
162+
163+
assertTrue(!iterator.hasNext());
164+
}
165+
}

0 commit comments

Comments
 (0)