44import json
55import weakref
66from collections .abc import Sequence
7- from typing import Any , TypeVar
7+ from typing import Any , TypeVar , cast
88
99import pytest
1010from mcp import Tool as MCPTool
11+ from openai .types .responses .response_output_item import McpCall , McpListTools , McpListToolsTool
1112from pydantic import BaseModel
1213
1314from agents import (
1415 Agent ,
16+ HostedMCPTool ,
1517 ModelResponse ,
1618 RunConfig ,
1719 RunContextWrapper ,
2527 Usage ,
2628 function_tool ,
2729)
30+ from agents .items import MCPListToolsItem , ToolApprovalItem
2831from agents .mcp import MCPUtil
2932from agents .run_internal import run_loop
3033from agents .run_internal .run_loop import get_output_schema
@@ -48,6 +51,22 @@ class StructuredOutputPayload(BaseModel):
4851 status : str
4952
5053
54+ def _make_hosted_mcp_list_tools (server_label : str , tool_name : str ) -> McpListTools :
55+ return McpListTools (
56+ id = f"list_{ server_label } " ,
57+ server_label = server_label ,
58+ tools = [
59+ McpListToolsTool (
60+ name = tool_name ,
61+ input_schema = {},
62+ description = "Search the docs." ,
63+ annotations = {"title" : "Search Docs" },
64+ )
65+ ],
66+ type = "mcp_list_tools" ,
67+ )
68+
69+
5170@pytest .mark .asyncio
5271async def test_runner_attaches_function_tool_origin_to_call_and_output_items () -> None :
5372 model = FakeModel ()
@@ -184,6 +203,104 @@ async def test_streamed_tool_call_item_includes_local_mcp_origin() -> None:
184203 )
185204
186205
206+ def test_process_model_response_attaches_hosted_mcp_tool_origin () -> None :
207+ agent = Agent (name = "hosted-mcp" )
208+ hosted_tool = HostedMCPTool (
209+ tool_config = cast (
210+ Any ,
211+ {
212+ "type" : "mcp" ,
213+ "server_label" : "docs_server" ,
214+ "server_url" : "https://example.com/mcp" ,
215+ },
216+ )
217+ )
218+ existing_items = [
219+ MCPListToolsItem (
220+ agent = agent ,
221+ raw_item = _make_hosted_mcp_list_tools ("docs_server" , "search_docs" ),
222+ )
223+ ]
224+ response = ModelResponse (
225+ output = [
226+ McpCall (
227+ id = "mcp_call_1" ,
228+ arguments = "{}" ,
229+ name = "search_docs" ,
230+ server_label = "docs_server" ,
231+ type = "mcp_call" ,
232+ status = "completed" ,
233+ )
234+ ],
235+ usage = Usage (),
236+ response_id = "resp_hosted_mcp" ,
237+ )
238+
239+ processed = run_loop .process_model_response (
240+ agent = agent ,
241+ all_tools = [hosted_tool ],
242+ response = response ,
243+ output_schema = None ,
244+ handoffs = [],
245+ existing_items = existing_items ,
246+ )
247+
248+ tool_call_item = _first_item (processed .new_items , ToolCallItem )
249+ assert tool_call_item .tool_origin == ToolOrigin (
250+ type = ToolOriginType .MCP ,
251+ mcp_server_name = "docs_server" ,
252+ )
253+
254+
255+ @pytest .mark .asyncio
256+ async def test_streamed_tool_call_item_includes_hosted_mcp_origin () -> None :
257+ model = FakeModel ()
258+ hosted_tool = HostedMCPTool (
259+ tool_config = cast (
260+ Any ,
261+ {
262+ "type" : "mcp" ,
263+ "server_label" : "docs_server" ,
264+ "server_url" : "https://example.com/mcp" ,
265+ },
266+ )
267+ )
268+ agent = Agent (name = "stream-hosted-mcp" , model = model , tools = [hosted_tool ])
269+ model .add_multiple_turn_outputs (
270+ [
271+ [
272+ _make_hosted_mcp_list_tools ("docs_server" , "search_docs" ),
273+ McpCall (
274+ id = "mcp_call_stream_1" ,
275+ arguments = "{}" ,
276+ name = "search_docs" ,
277+ server_label = "docs_server" ,
278+ type = "mcp_call" ,
279+ status = "completed" ,
280+ ),
281+ ],
282+ [get_text_message ("done" )],
283+ ]
284+ )
285+
286+ result = Runner .run_streamed (agent , input = "hello" )
287+ seen_tool_item : ToolCallItem | None = None
288+ async for event in result .stream_events ():
289+ if (
290+ event .type == "run_item_stream_event"
291+ and isinstance (event .item , ToolCallItem )
292+ and isinstance (event .item .raw_item , McpCall )
293+ ):
294+ seen_tool_item = event .item
295+ break
296+
297+ assert seen_tool_item is not None
298+ assert seen_tool_item .tool_origin == ToolOrigin (
299+ type = ToolOriginType .MCP ,
300+ mcp_server_name = "docs_server" ,
301+ )
302+
303+
187304def test_local_mcp_tool_origin_does_not_retain_server_object () -> None :
188305 server = FakeMCPServer (server_name = "docs_server" )
189306 function_tool = MCPUtil .to_function_tool (
@@ -332,3 +449,52 @@ async def test_run_state_from_json_reads_legacy_1_5_without_tool_origin() -> Non
332449 restored_item = _first_item (restored ._generated_items , ToolCallItem )
333450 assert restored_item .description == "Legacy tool"
334451 assert restored_item .tool_origin is None
452+
453+
454+ @pytest .mark .asyncio
455+ async def test_run_state_roundtrip_preserves_tool_origin_on_approval_interruptions () -> None :
456+ agent = Agent (name = "approval-origin" )
457+ state : RunState [Any , Agent [Any ]] = make_run_state (agent )
458+ state ._generated_items .append (
459+ ToolApprovalItem (
460+ agent = agent ,
461+ raw_item = make_tool_call (call_id = "call_approval" , name = "approval_tool" ),
462+ tool_name = "approval_tool" ,
463+ tool_origin = ToolOrigin (type = ToolOriginType .FUNCTION ),
464+ )
465+ )
466+
467+ restored = await roundtrip_state (agent , state )
468+
469+ approval_item = _first_item (restored ._generated_items , ToolApprovalItem )
470+ assert approval_item .tool_origin == ToolOrigin (type = ToolOriginType .FUNCTION )
471+
472+
473+ @pytest .mark .asyncio
474+ async def test_run_state_from_json_reads_legacy_1_6_approval_without_tool_origin () -> None :
475+ agent = Agent (name = "approval-origin-legacy" )
476+ state : RunState [Any , Agent [Any ]] = make_run_state (agent )
477+ state ._generated_items .append (
478+ ToolApprovalItem (
479+ agent = agent ,
480+ raw_item = make_tool_call (call_id = "call_legacy_approval" , name = "approval_tool" ),
481+ tool_name = "approval_tool" ,
482+ tool_origin = ToolOrigin (type = ToolOriginType .FUNCTION ),
483+ )
484+ )
485+
486+ restored = await roundtrip_state (
487+ agent ,
488+ state ,
489+ mutate_json = lambda data : {
490+ ** data ,
491+ "$schemaVersion" : "1.6" ,
492+ "generated_items" : [
493+ {key : value for key , value in item .items () if key != "tool_origin" }
494+ for item in data ["generated_items" ]
495+ ],
496+ },
497+ )
498+
499+ approval_item = _first_item (restored ._generated_items , ToolApprovalItem )
500+ assert approval_item .tool_origin is None
0 commit comments