Skip to content

Commit 5e664c6

Browse files
committed
fix review comments
1 parent 8ab7335 commit 5e664c6

7 files changed

Lines changed: 203 additions & 3 deletions

File tree

src/agents/items.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,9 @@ class ToolApprovalItem(RunItemBase[Any]):
500500
tool_namespace: str | None = None
501501
"""Optional Responses API namespace for function-tool approvals."""
502502

503+
tool_origin: ToolOrigin | None = None
504+
"""Optional metadata describing where the approved tool call came from."""
505+
503506
tool_lookup_key: FunctionToolLookupKey | None = field(
504507
default=None,
505508
kw_only=True,

src/agents/run_internal/run_loop.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,14 @@
6767
RawResponsesStreamEvent,
6868
RunItemStreamEvent,
6969
)
70-
from ..tool import FunctionTool, Tool, dispose_resolved_computers, get_function_tool_origin
70+
from ..tool import (
71+
FunctionTool,
72+
Tool,
73+
ToolOrigin,
74+
ToolOriginType,
75+
dispose_resolved_computers,
76+
get_function_tool_origin,
77+
)
7178
from ..tracing import Span, SpanError, agent_span, get_current_trace
7279
from ..tracing.model_tracing import get_model_tracing_impl
7380
from ..tracing.span_data import AgentSpanData
@@ -1374,6 +1381,10 @@ async def rewind_model_request() -> None:
13741381
if metadata is not None:
13751382
tool_description = metadata.description
13761383
tool_title = metadata.title
1384+
tool_origin = ToolOrigin(
1385+
type=ToolOriginType.MCP,
1386+
mcp_server_name=output_item.server_label,
1387+
)
13771388
elif matched_tool is not None:
13781389
tool_description = getattr(matched_tool, "description", None)
13791390
tool_title = getattr(matched_tool, "_mcp_title", None)

src/agents/run_internal/tool_execution.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -995,6 +995,7 @@ async def resolve_approval_status(
995995
context_wrapper: RunContextWrapper[Any],
996996
tool_namespace: str | None = None,
997997
tool_lookup_key: FunctionToolLookupKey | None = None,
998+
tool_origin: ToolOrigin | None = None,
998999
on_approval: Callable[[RunContextWrapper[Any], ToolApprovalItem], Any] | None = None,
9991000
) -> tuple[bool | None, ToolApprovalItem]:
10001001
"""Build approval item, run on_approval hook if needed, and return latest approval status."""
@@ -1003,6 +1004,7 @@ async def resolve_approval_status(
10031004
raw_item=raw_item,
10041005
tool_name=tool_name,
10051006
tool_namespace=tool_namespace,
1007+
tool_origin=tool_origin,
10061008
tool_lookup_key=tool_lookup_key,
10071009
)
10081010
approval_status = context_wrapper.get_approval_status(
@@ -1506,6 +1508,7 @@ async def _maybe_execute_tool_approval(
15061508
raw_item=raw_tool_call,
15071509
tool_name=func_tool.name,
15081510
tool_namespace=tool_namespace,
1511+
tool_origin=get_function_tool_origin(func_tool),
15091512
tool_lookup_key=tool_lookup_key,
15101513
_allow_bare_name_alias=should_allow_bare_name_approval_alias(
15111514
func_tool,

src/agents/run_internal/tool_planning.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
ToolCallOutputItem,
2323
)
2424
from ..run_context import RunContextWrapper
25-
from ..tool import FunctionTool, MCPToolApprovalRequest
25+
from ..tool import FunctionTool, MCPToolApprovalRequest, get_function_tool_origin
2626
from ..tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult
2727
from .run_steps import (
2828
ToolRunApplyPatchCall,
@@ -410,11 +410,17 @@ async def _collect_runs_by_approval(
410410
if approval_status is True:
411411
approved_runs.append(run)
412412
else:
413+
function_tool = get_mapping_or_attr(run, "function_tool")
413414
pending_item = existing_pending or ToolApprovalItem(
414415
agent=agent,
415416
raw_item=get_mapping_or_attr(run, "tool_call"),
416417
tool_name=tool_name,
417418
tool_namespace=get_tool_call_namespace(get_mapping_or_attr(run, "tool_call")),
419+
tool_origin=(
420+
get_function_tool_origin(function_tool)
421+
if isinstance(function_tool, FunctionTool)
422+
else None
423+
),
418424
tool_lookup_key=get_function_tool_lookup_key_for_call(
419425
get_mapping_or_attr(run, "tool_call")
420426
),

src/agents/run_internal/turn_resolution.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@
7878
LocalShellTool,
7979
ShellTool,
8080
Tool,
81+
ToolOrigin,
82+
ToolOriginType,
8183
get_function_tool_origin,
8284
)
8385
from ..tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult
@@ -1031,6 +1033,7 @@ def _add_unmatched_pending(approval: ToolApprovalItem) -> None:
10311033
raw_item=run.tool_call,
10321034
tool_name=run.function_tool.name,
10331035
tool_namespace=get_tool_call_namespace(run.tool_call),
1036+
tool_origin=get_function_tool_origin(run.function_tool),
10341037
tool_lookup_key=get_function_tool_lookup_key_for_call(run.tool_call),
10351038
_allow_bare_name_alias=should_allow_bare_name_approval_alias(
10361039
run.function_tool,
@@ -1520,6 +1523,10 @@ def _dump_output_item(raw_item: Any) -> dict[str, Any]:
15201523
agent=agent,
15211524
description=metadata.description if metadata is not None else None,
15221525
title=metadata.title if metadata is not None else None,
1526+
tool_origin=ToolOrigin(
1527+
type=ToolOriginType.MCP,
1528+
mcp_server_name=output.server_label,
1529+
),
15231530
)
15241531
)
15251532
tools_used.append("mcp")

src/agents/run_state.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,6 +1214,8 @@ def _serialize_tool_approval_interruption(
12141214
interruption_dict["tool_name"] = interruption.tool_name
12151215
if interruption.tool_namespace is not None:
12161216
interruption_dict["tool_namespace"] = interruption.tool_namespace
1217+
if interruption.tool_origin is not None:
1218+
interruption_dict["tool_origin"] = interruption.tool_origin.to_json_dict()
12171219
tool_lookup_key = serialize_function_tool_lookup_key(
12181220
getattr(interruption, "tool_lookup_key", None)
12191221
)
@@ -1885,6 +1887,7 @@ def _deserialize_tool_approval_item(
18851887

18861888
tool_name = item_data.get("tool_name")
18871889
tool_namespace = item_data.get("tool_namespace")
1890+
tool_origin = _deserialize_tool_origin(item_data.get("tool_origin"))
18881891
tool_lookup_key = deserialize_function_tool_lookup_key(item_data.get("tool_lookup_key"))
18891892
allow_bare_name_alias = item_data.get("allow_bare_name_alias") is True
18901893
raw_item = _deserialize_tool_approval_raw_item(raw_item_data)
@@ -1893,6 +1896,7 @@ def _deserialize_tool_approval_item(
18931896
raw_item=raw_item,
18941897
tool_name=tool_name,
18951898
tool_namespace=tool_namespace,
1899+
tool_origin=tool_origin,
18961900
tool_lookup_key=tool_lookup_key,
18971901
_allow_bare_name_alias=allow_bare_name_alias,
18981902
)

tests/test_tool_origin.py

Lines changed: 167 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44
import json
55
import weakref
66
from collections.abc import Sequence
7-
from typing import Any, TypeVar
7+
from typing import Any, TypeVar, cast
88

99
import pytest
1010
from mcp import Tool as MCPTool
11+
from openai.types.responses.response_output_item import McpCall, McpListTools, McpListToolsTool
1112
from pydantic import BaseModel
1213

1314
from agents import (
1415
Agent,
16+
HostedMCPTool,
1517
ModelResponse,
1618
RunConfig,
1719
RunContextWrapper,
@@ -25,6 +27,7 @@
2527
Usage,
2628
function_tool,
2729
)
30+
from agents.items import MCPListToolsItem, ToolApprovalItem
2831
from agents.mcp import MCPUtil
2932
from agents.run_internal import run_loop
3033
from agents.run_internal.run_loop import get_output_schema
@@ -48,6 +51,22 @@ class StructuredOutputPayload(BaseModel):
4851
status: str
4952

5053

54+
def _make_hosted_mcp_list_tools(server_label: str, tool_name: str) -> McpListTools:
55+
return McpListTools(
56+
id=f"list_{server_label}",
57+
server_label=server_label,
58+
tools=[
59+
McpListToolsTool(
60+
name=tool_name,
61+
input_schema={},
62+
description="Search the docs.",
63+
annotations={"title": "Search Docs"},
64+
)
65+
],
66+
type="mcp_list_tools",
67+
)
68+
69+
5170
@pytest.mark.asyncio
5271
async def test_runner_attaches_function_tool_origin_to_call_and_output_items() -> None:
5372
model = FakeModel()
@@ -184,6 +203,104 @@ async def test_streamed_tool_call_item_includes_local_mcp_origin() -> None:
184203
)
185204

186205

206+
def test_process_model_response_attaches_hosted_mcp_tool_origin() -> None:
207+
agent = Agent(name="hosted-mcp")
208+
hosted_tool = HostedMCPTool(
209+
tool_config=cast(
210+
Any,
211+
{
212+
"type": "mcp",
213+
"server_label": "docs_server",
214+
"server_url": "https://example.com/mcp",
215+
},
216+
)
217+
)
218+
existing_items = [
219+
MCPListToolsItem(
220+
agent=agent,
221+
raw_item=_make_hosted_mcp_list_tools("docs_server", "search_docs"),
222+
)
223+
]
224+
response = ModelResponse(
225+
output=[
226+
McpCall(
227+
id="mcp_call_1",
228+
arguments="{}",
229+
name="search_docs",
230+
server_label="docs_server",
231+
type="mcp_call",
232+
status="completed",
233+
)
234+
],
235+
usage=Usage(),
236+
response_id="resp_hosted_mcp",
237+
)
238+
239+
processed = run_loop.process_model_response(
240+
agent=agent,
241+
all_tools=[hosted_tool],
242+
response=response,
243+
output_schema=None,
244+
handoffs=[],
245+
existing_items=existing_items,
246+
)
247+
248+
tool_call_item = _first_item(processed.new_items, ToolCallItem)
249+
assert tool_call_item.tool_origin == ToolOrigin(
250+
type=ToolOriginType.MCP,
251+
mcp_server_name="docs_server",
252+
)
253+
254+
255+
@pytest.mark.asyncio
256+
async def test_streamed_tool_call_item_includes_hosted_mcp_origin() -> None:
257+
model = FakeModel()
258+
hosted_tool = HostedMCPTool(
259+
tool_config=cast(
260+
Any,
261+
{
262+
"type": "mcp",
263+
"server_label": "docs_server",
264+
"server_url": "https://example.com/mcp",
265+
},
266+
)
267+
)
268+
agent = Agent(name="stream-hosted-mcp", model=model, tools=[hosted_tool])
269+
model.add_multiple_turn_outputs(
270+
[
271+
[
272+
_make_hosted_mcp_list_tools("docs_server", "search_docs"),
273+
McpCall(
274+
id="mcp_call_stream_1",
275+
arguments="{}",
276+
name="search_docs",
277+
server_label="docs_server",
278+
type="mcp_call",
279+
status="completed",
280+
),
281+
],
282+
[get_text_message("done")],
283+
]
284+
)
285+
286+
result = Runner.run_streamed(agent, input="hello")
287+
seen_tool_item: ToolCallItem | None = None
288+
async for event in result.stream_events():
289+
if (
290+
event.type == "run_item_stream_event"
291+
and isinstance(event.item, ToolCallItem)
292+
and isinstance(event.item.raw_item, McpCall)
293+
):
294+
seen_tool_item = event.item
295+
break
296+
297+
assert seen_tool_item is not None
298+
assert seen_tool_item.tool_origin == ToolOrigin(
299+
type=ToolOriginType.MCP,
300+
mcp_server_name="docs_server",
301+
)
302+
303+
187304
def test_local_mcp_tool_origin_does_not_retain_server_object() -> None:
188305
server = FakeMCPServer(server_name="docs_server")
189306
function_tool = MCPUtil.to_function_tool(
@@ -332,3 +449,52 @@ async def test_run_state_from_json_reads_legacy_1_5_without_tool_origin() -> Non
332449
restored_item = _first_item(restored._generated_items, ToolCallItem)
333450
assert restored_item.description == "Legacy tool"
334451
assert restored_item.tool_origin is None
452+
453+
454+
@pytest.mark.asyncio
455+
async def test_run_state_roundtrip_preserves_tool_origin_on_approval_interruptions() -> None:
456+
agent = Agent(name="approval-origin")
457+
state: RunState[Any, Agent[Any]] = make_run_state(agent)
458+
state._generated_items.append(
459+
ToolApprovalItem(
460+
agent=agent,
461+
raw_item=make_tool_call(call_id="call_approval", name="approval_tool"),
462+
tool_name="approval_tool",
463+
tool_origin=ToolOrigin(type=ToolOriginType.FUNCTION),
464+
)
465+
)
466+
467+
restored = await roundtrip_state(agent, state)
468+
469+
approval_item = _first_item(restored._generated_items, ToolApprovalItem)
470+
assert approval_item.tool_origin == ToolOrigin(type=ToolOriginType.FUNCTION)
471+
472+
473+
@pytest.mark.asyncio
474+
async def test_run_state_from_json_reads_legacy_1_6_approval_without_tool_origin() -> None:
475+
agent = Agent(name="approval-origin-legacy")
476+
state: RunState[Any, Agent[Any]] = make_run_state(agent)
477+
state._generated_items.append(
478+
ToolApprovalItem(
479+
agent=agent,
480+
raw_item=make_tool_call(call_id="call_legacy_approval", name="approval_tool"),
481+
tool_name="approval_tool",
482+
tool_origin=ToolOrigin(type=ToolOriginType.FUNCTION),
483+
)
484+
)
485+
486+
restored = await roundtrip_state(
487+
agent,
488+
state,
489+
mutate_json=lambda data: {
490+
**data,
491+
"$schemaVersion": "1.6",
492+
"generated_items": [
493+
{key: value for key, value in item.items() if key != "tool_origin"}
494+
for item in data["generated_items"]
495+
],
496+
},
497+
)
498+
499+
approval_item = _first_item(restored._generated_items, ToolApprovalItem)
500+
assert approval_item.tool_origin is None

0 commit comments

Comments
 (0)