Skip to content

Commit 951bc5d

Browse files
feng-95innsd
authored andcommitted
fix: override run_sse endpoint
(cherry picked from commit 71cb03ab3cad01bef50986fef7ddf162f8f3d2ef)
1 parent 3c30661 commit 951bc5d

1 file changed

Lines changed: 74 additions & 1 deletion

File tree

agentkit/apps/agent_server_app/agent_server_app.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from google.adk.auth.credential_service.in_memory_credential_service import (
3232
InMemoryCredentialService,
3333
)
34-
from google.adk.cli.adk_web_server import AdkWebServer
34+
from google.adk.cli.adk_web_server import AdkWebServer, RunAgentRequest
3535
from google.adk.cli.utils.base_agent_loader import BaseAgentLoader
3636
from google.adk.evaluation.local_eval_set_results_manager import (
3737
LocalEvalSetResultsManager,
@@ -149,6 +149,79 @@ async def lifespan(app: FastAPI):
149149

150150
self.app = self.server.get_fast_api_app(lifespan=lifespan)
151151

152+
@self.app.post("/run_sse")
153+
async def run_agent_sse(req: RunAgentRequest) -> StreamingResponse:
154+
logger.info("Overriding run_agent_sse endpoint...")
155+
# SSE endpoint
156+
session = await self.server.session_service.get_session(
157+
app_name=req.app_name,
158+
user_id=req.user_id,
159+
session_id=req.session_id,
160+
)
161+
if not session:
162+
raise HTTPException(status_code=404, detail="Session not found")
163+
164+
# Convert the events to properly formatted SSE
165+
async def event_generator():
166+
try:
167+
stream_mode = (
168+
StreamingMode.SSE
169+
if req.streaming
170+
else StreamingMode.NONE
171+
)
172+
runner = await self.server.get_runner_async(req.app_name)
173+
async with Aclosing(
174+
runner.run_async(
175+
user_id=req.user_id,
176+
session_id=req.session_id,
177+
new_message=req.new_message,
178+
state_delta=req.state_delta,
179+
run_config=RunConfig(streaming_mode=stream_mode),
180+
invocation_id=req.invocation_id,
181+
)
182+
) as agen:
183+
async for event in agen:
184+
# ADK Web renders artifacts from `actions.artifactDelta`
185+
# during part processing *and* during action processing
186+
# 1) the original event with `artifactDelta` cleared (content)
187+
# 2) a content-less "action-only" event carrying `artifactDelta`
188+
events_to_stream = [event]
189+
if (
190+
event.actions.artifact_delta
191+
and event.content
192+
and event.content.parts
193+
):
194+
content_event = event.model_copy(deep=True)
195+
content_event.actions.artifact_delta = {}
196+
artifact_event = event.model_copy(deep=True)
197+
artifact_event.content = None
198+
events_to_stream = [
199+
content_event,
200+
artifact_event,
201+
]
202+
for event_to_stream in events_to_stream:
203+
sse_event = event_to_stream.model_dump_json(
204+
exclude_none=True,
205+
by_alias=True,
206+
)
207+
logger.debug(
208+
"Generated event in agent run streaming: %s",
209+
sse_event,
210+
)
211+
yield f"data: {sse_event}\n\n"
212+
except Exception as e:
213+
logger.exception("Error in event_generator: %s", e)
214+
telemetry.trace_agent_server_finish(
215+
path="/invoke", func_result="", exception=e
216+
)
217+
yield f"data: {json.dumps({'error': str(e)})}\n\n"
218+
# Returns a streaming response with the proper media type for SSE
219+
220+
return StreamingResponse(
221+
event_generator(),
222+
media_type="text/event-stream",
223+
)
224+
152225
# Attach ASGI middleware for unified telemetry across all routes
153226
self.app.add_middleware(AgentkitTelemetryHTTPMiddleware)
154227

0 commit comments

Comments
 (0)