|
31 | 31 | from google.adk.auth.credential_service.in_memory_credential_service import ( |
32 | 32 | InMemoryCredentialService, |
33 | 33 | ) |
34 | | -from google.adk.cli.adk_web_server import AdkWebServer |
| 34 | +from google.adk.cli.adk_web_server import AdkWebServer, RunAgentRequest |
35 | 35 | from google.adk.cli.utils.base_agent_loader import BaseAgentLoader |
36 | 36 | from google.adk.evaluation.local_eval_set_results_manager import ( |
37 | 37 | LocalEvalSetResultsManager, |
@@ -149,6 +149,79 @@ async def lifespan(app: FastAPI): |
149 | 149 |
|
150 | 150 | self.app = self.server.get_fast_api_app(lifespan=lifespan) |
151 | 151 |
|
| 152 | + @self.app.post("/run_sse") |
| 153 | + async def run_agent_sse(req: RunAgentRequest) -> StreamingResponse: |
| 154 | + logger.info("Overriding run_agent_sse endpoint...") |
| 155 | + # SSE endpoint |
| 156 | + session = await self.server.session_service.get_session( |
| 157 | + app_name=req.app_name, |
| 158 | + user_id=req.user_id, |
| 159 | + session_id=req.session_id, |
| 160 | + ) |
| 161 | + if not session: |
| 162 | + raise HTTPException(status_code=404, detail="Session not found") |
| 163 | + |
| 164 | + # Convert the events to properly formatted SSE |
| 165 | + async def event_generator(): |
| 166 | + try: |
| 167 | + stream_mode = ( |
| 168 | + StreamingMode.SSE |
| 169 | + if req.streaming |
| 170 | + else StreamingMode.NONE |
| 171 | + ) |
| 172 | + runner = await self.server.get_runner_async(req.app_name) |
| 173 | + async with Aclosing( |
| 174 | + runner.run_async( |
| 175 | + user_id=req.user_id, |
| 176 | + session_id=req.session_id, |
| 177 | + new_message=req.new_message, |
| 178 | + state_delta=req.state_delta, |
| 179 | + run_config=RunConfig(streaming_mode=stream_mode), |
| 180 | + invocation_id=req.invocation_id, |
| 181 | + ) |
| 182 | + ) as agen: |
| 183 | + async for event in agen: |
| 184 | + # ADK Web renders artifacts from `actions.artifactDelta` |
| 185 | + # during part processing *and* during action processing |
| 186 | + # 1) the original event with `artifactDelta` cleared (content) |
| 187 | + # 2) a content-less "action-only" event carrying `artifactDelta` |
| 188 | + events_to_stream = [event] |
| 189 | + if ( |
| 190 | + event.actions.artifact_delta |
| 191 | + and event.content |
| 192 | + and event.content.parts |
| 193 | + ): |
| 194 | + content_event = event.model_copy(deep=True) |
| 195 | + content_event.actions.artifact_delta = {} |
| 196 | + artifact_event = event.model_copy(deep=True) |
| 197 | + artifact_event.content = None |
| 198 | + events_to_stream = [ |
| 199 | + content_event, |
| 200 | + artifact_event, |
| 201 | + ] |
| 202 | + for event_to_stream in events_to_stream: |
| 203 | + sse_event = event_to_stream.model_dump_json( |
| 204 | + exclude_none=True, |
| 205 | + by_alias=True, |
| 206 | + ) |
| 207 | + logger.debug( |
| 208 | + "Generated event in agent run streaming: %s", |
| 209 | + sse_event, |
| 210 | + ) |
| 211 | + yield f"data: {sse_event}\n\n" |
| 212 | + except Exception as e: |
| 213 | + logger.exception("Error in event_generator: %s", e) |
| 214 | + telemetry.trace_agent_server_finish( |
| 215 | + path="/invoke", func_result="", exception=e |
| 216 | + ) |
| 217 | + yield f"data: {json.dumps({'error': str(e)})}\n\n" |
| 218 | + # Returns a streaming response with the proper media type for SSE |
| 219 | + |
| 220 | + return StreamingResponse( |
| 221 | + event_generator(), |
| 222 | + media_type="text/event-stream", |
| 223 | + ) |
| 224 | + |
152 | 225 | # Attach ASGI middleware for unified telemetry across all routes |
153 | 226 | self.app.add_middleware(AgentkitTelemetryHTTPMiddleware) |
154 | 227 |
|
|
0 commit comments