Skip to content

Commit 190d8a0

Browse files
committed
fix review comments
1 parent 20d5469 commit 190d8a0

7 files changed

Lines changed: 205 additions & 4 deletions

File tree

src/agents/items.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,9 @@ class ToolApprovalItem(RunItemBase[Any]):
496496
tool_namespace: str | None = None
497497
"""Optional Responses API namespace for function-tool approvals."""
498498

499+
tool_origin: ToolOrigin | None = None
500+
"""Optional metadata describing where the approved tool call came from."""
501+
499502
tool_lookup_key: FunctionToolLookupKey | None = field(
500503
default=None,
501504
kw_only=True,

src/agents/run_internal/run_loop.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,14 @@
6969
RawResponsesStreamEvent,
7070
RunItemStreamEvent,
7171
)
72-
from ..tool import FunctionTool, Tool, dispose_resolved_computers, get_function_tool_origin
72+
from ..tool import (
73+
FunctionTool,
74+
Tool,
75+
ToolOrigin,
76+
ToolOriginType,
77+
dispose_resolved_computers,
78+
get_function_tool_origin,
79+
)
7380
from ..tracing import Span, SpanError, agent_span, get_current_trace, task_span, turn_span
7481
from ..tracing.model_tracing import get_model_tracing_impl
7582
from ..tracing.span_data import AgentSpanData, TaskSpanData
@@ -1563,6 +1570,10 @@ async def rewind_model_request() -> None:
15631570
if metadata is not None:
15641571
tool_description = metadata.description
15651572
tool_title = metadata.title
1573+
tool_origin = ToolOrigin(
1574+
type=ToolOriginType.MCP,
1575+
mcp_server_name=output_item.server_label,
1576+
)
15661577
elif matched_tool is not None:
15671578
tool_description = getattr(matched_tool, "description", None)
15681579
tool_title = getattr(matched_tool, "_mcp_title", None)

src/agents/run_internal/tool_execution.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,6 +1065,7 @@ async def resolve_approval_status(
10651065
context_wrapper: RunContextWrapper[Any],
10661066
tool_namespace: str | None = None,
10671067
tool_lookup_key: FunctionToolLookupKey | None = None,
1068+
tool_origin: ToolOrigin | None = None,
10681069
on_approval: Callable[[RunContextWrapper[Any], ToolApprovalItem], Any] | None = None,
10691070
) -> tuple[bool | None, ToolApprovalItem]:
10701071
"""Build approval item, run on_approval hook if needed, and return latest approval status."""
@@ -1073,6 +1074,7 @@ async def resolve_approval_status(
10731074
raw_item=raw_item,
10741075
tool_name=tool_name,
10751076
tool_namespace=tool_namespace,
1077+
tool_origin=tool_origin,
10761078
tool_lookup_key=tool_lookup_key,
10771079
)
10781080
approval_status = context_wrapper.get_approval_status(
@@ -1612,6 +1614,7 @@ async def _maybe_execute_tool_approval(
16121614
raw_item=raw_tool_call,
16131615
tool_name=func_tool.name,
16141616
tool_namespace=tool_namespace,
1617+
tool_origin=get_function_tool_origin(func_tool),
16151618
tool_lookup_key=tool_lookup_key,
16161619
_allow_bare_name_alias=should_allow_bare_name_approval_alias(
16171620
func_tool,

src/agents/run_internal/tool_planning.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
ToolCallOutputItem,
2323
)
2424
from ..run_context import RunContextWrapper
25-
from ..tool import FunctionTool, MCPToolApprovalRequest
25+
from ..tool import FunctionTool, MCPToolApprovalRequest, get_function_tool_origin
2626
from ..tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult
2727
from .agent_bindings import AgentBindings
2828
from .run_steps import (
@@ -427,11 +427,17 @@ async def _collect_runs_by_approval(
427427
if approval_status is True:
428428
approved_runs.append(run)
429429
else:
430+
function_tool = get_mapping_or_attr(run, "function_tool")
430431
pending_item = existing_pending or ToolApprovalItem(
431432
agent=agent,
432433
raw_item=get_mapping_or_attr(run, "tool_call"),
433434
tool_name=tool_name,
434435
tool_namespace=get_tool_call_namespace(get_mapping_or_attr(run, "tool_call")),
436+
tool_origin=(
437+
get_function_tool_origin(function_tool)
438+
if isinstance(function_tool, FunctionTool)
439+
else None
440+
),
435441
tool_lookup_key=get_function_tool_lookup_key_for_call(
436442
get_mapping_or_attr(run, "tool_call")
437443
),

src/agents/run_internal/turn_resolution.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@
8080
LocalShellTool,
8181
ShellTool,
8282
Tool,
83+
ToolOrigin,
84+
ToolOriginType,
8385
get_function_tool_origin,
8486
)
8587
from ..tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult
@@ -1158,6 +1160,7 @@ def _add_unmatched_pending(approval: ToolApprovalItem) -> None:
11581160
raw_item=run.tool_call,
11591161
tool_name=run.function_tool.name,
11601162
tool_namespace=get_tool_call_namespace(run.tool_call),
1163+
tool_origin=get_function_tool_origin(run.function_tool),
11611164
tool_lookup_key=get_function_tool_lookup_key_for_call(run.tool_call),
11621165
_allow_bare_name_alias=should_allow_bare_name_approval_alias(
11631166
run.function_tool,
@@ -1669,6 +1672,10 @@ def _dump_output_item(raw_item: Any) -> dict[str, Any]:
16691672
agent=agent,
16701673
description=metadata.description if metadata is not None else None,
16711674
title=metadata.title if metadata is not None else None,
1675+
tool_origin=ToolOrigin(
1676+
type=ToolOriginType.MCP,
1677+
mcp_server_name=output.server_label,
1678+
),
16721679
)
16731680
)
16741681
tools_used.append("mcp")

src/agents/run_state.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,6 +1368,8 @@ def _serialize_tool_approval_interruption(
13681368
interruption_dict["tool_name"] = interruption.tool_name
13691369
if interruption.tool_namespace is not None:
13701370
interruption_dict["tool_namespace"] = interruption.tool_namespace
1371+
if interruption.tool_origin is not None:
1372+
interruption_dict["tool_origin"] = interruption.tool_origin.to_json_dict()
13711373
tool_lookup_key = serialize_function_tool_lookup_key(
13721374
getattr(interruption, "tool_lookup_key", None)
13731375
)
@@ -2137,6 +2139,7 @@ def _deserialize_tool_approval_item(
21372139

21382140
tool_name = item_data.get("tool_name")
21392141
tool_namespace = item_data.get("tool_namespace")
2142+
tool_origin = _deserialize_tool_origin(item_data.get("tool_origin"))
21402143
tool_lookup_key = deserialize_function_tool_lookup_key(item_data.get("tool_lookup_key"))
21412144
allow_bare_name_alias = item_data.get("allow_bare_name_alias") is True
21422145
raw_item = _deserialize_tool_approval_raw_item(raw_item_data)
@@ -2145,6 +2148,7 @@ def _deserialize_tool_approval_item(
21452148
raw_item=raw_item,
21462149
tool_name=tool_name,
21472150
tool_namespace=tool_namespace,
2151+
tool_origin=tool_origin,
21482152
tool_lookup_key=tool_lookup_key,
21492153
_allow_bare_name_alias=allow_bare_name_alias,
21502154
)

tests/test_tool_origin.py

Lines changed: 169 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44
import json
55
import weakref
66
from collections.abc import Sequence
7-
from typing import Any, TypeVar
7+
from typing import Any, TypeVar, cast
88

99
import pytest
1010
from mcp import Tool as MCPTool
11+
from openai.types.responses.response_output_item import McpCall, McpListTools, McpListToolsTool
1112
from pydantic import BaseModel
1213

1314
from agents import (
1415
Agent,
16+
HostedMCPTool,
1517
ModelResponse,
1618
RunConfig,
1719
RunContextWrapper,
@@ -25,8 +27,10 @@
2527
Usage,
2628
function_tool,
2729
)
30+
from agents.items import MCPListToolsItem, ToolApprovalItem
2831
from agents.mcp import MCPUtil
2932
from agents.run_internal import run_loop
33+
from agents.run_internal.agent_bindings import bind_public_agent
3034
from agents.run_internal.run_loop import get_output_schema
3135
from agents.run_internal.tool_execution import execute_function_tool_calls
3236
from tests.fake_model import FakeModel
@@ -48,6 +52,22 @@ class StructuredOutputPayload(BaseModel):
4852
status: str
4953

5054

55+
def _make_hosted_mcp_list_tools(server_label: str, tool_name: str) -> McpListTools:
56+
return McpListTools(
57+
id=f"list_{server_label}",
58+
server_label=server_label,
59+
tools=[
60+
McpListToolsTool(
61+
name=tool_name,
62+
input_schema={},
63+
description="Search the docs.",
64+
annotations={"title": "Search Docs"},
65+
)
66+
],
67+
type="mcp_list_tools",
68+
)
69+
70+
5171
@pytest.mark.asyncio
5272
async def test_runner_attaches_function_tool_origin_to_call_and_output_items() -> None:
5373
model = FakeModel()
@@ -184,6 +204,104 @@ async def test_streamed_tool_call_item_includes_local_mcp_origin() -> None:
184204
)
185205

186206

207+
def test_process_model_response_attaches_hosted_mcp_tool_origin() -> None:
208+
agent = Agent(name="hosted-mcp")
209+
hosted_tool = HostedMCPTool(
210+
tool_config=cast(
211+
Any,
212+
{
213+
"type": "mcp",
214+
"server_label": "docs_server",
215+
"server_url": "https://example.com/mcp",
216+
},
217+
)
218+
)
219+
existing_items = [
220+
MCPListToolsItem(
221+
agent=agent,
222+
raw_item=_make_hosted_mcp_list_tools("docs_server", "search_docs"),
223+
)
224+
]
225+
response = ModelResponse(
226+
output=[
227+
McpCall(
228+
id="mcp_call_1",
229+
arguments="{}",
230+
name="search_docs",
231+
server_label="docs_server",
232+
type="mcp_call",
233+
status="completed",
234+
)
235+
],
236+
usage=Usage(),
237+
response_id="resp_hosted_mcp",
238+
)
239+
240+
processed = run_loop.process_model_response(
241+
agent=agent,
242+
all_tools=[hosted_tool],
243+
response=response,
244+
output_schema=None,
245+
handoffs=[],
246+
existing_items=existing_items,
247+
)
248+
249+
tool_call_item = _first_item(processed.new_items, ToolCallItem)
250+
assert tool_call_item.tool_origin == ToolOrigin(
251+
type=ToolOriginType.MCP,
252+
mcp_server_name="docs_server",
253+
)
254+
255+
256+
@pytest.mark.asyncio
257+
async def test_streamed_tool_call_item_includes_hosted_mcp_origin() -> None:
258+
model = FakeModel()
259+
hosted_tool = HostedMCPTool(
260+
tool_config=cast(
261+
Any,
262+
{
263+
"type": "mcp",
264+
"server_label": "docs_server",
265+
"server_url": "https://example.com/mcp",
266+
},
267+
)
268+
)
269+
agent = Agent(name="stream-hosted-mcp", model=model, tools=[hosted_tool])
270+
model.add_multiple_turn_outputs(
271+
[
272+
[
273+
_make_hosted_mcp_list_tools("docs_server", "search_docs"),
274+
McpCall(
275+
id="mcp_call_stream_1",
276+
arguments="{}",
277+
name="search_docs",
278+
server_label="docs_server",
279+
type="mcp_call",
280+
status="completed",
281+
),
282+
],
283+
[get_text_message("done")],
284+
]
285+
)
286+
287+
result = Runner.run_streamed(agent, input="hello")
288+
seen_tool_item: ToolCallItem | None = None
289+
async for event in result.stream_events():
290+
if (
291+
event.type == "run_item_stream_event"
292+
and isinstance(event.item, ToolCallItem)
293+
and isinstance(event.item.raw_item, McpCall)
294+
):
295+
seen_tool_item = event.item
296+
break
297+
298+
assert seen_tool_item is not None
299+
assert seen_tool_item.tool_origin == ToolOrigin(
300+
type=ToolOriginType.MCP,
301+
mcp_server_name="docs_server",
302+
)
303+
304+
187305
def test_local_mcp_tool_origin_does_not_retain_server_object() -> None:
188306
server = FakeMCPServer(server_name="docs_server")
189307
function_tool = MCPUtil.to_function_tool(
@@ -245,7 +363,7 @@ async def test_json_tool_call_does_not_emit_function_tool_origin() -> None:
245363
assert tool_call_item.tool_origin is None
246364

247365
function_results, _, _ = await execute_function_tool_calls(
248-
agent=agent,
366+
bindings=bind_public_agent(agent),
249367
tool_runs=processed.functions,
250368
hooks=RunHooks(),
251369
context_wrapper=context_wrapper,
@@ -332,3 +450,52 @@ async def test_run_state_from_json_reads_legacy_1_5_without_tool_origin() -> Non
332450
restored_item = _first_item(restored._generated_items, ToolCallItem)
333451
assert restored_item.description == "Legacy tool"
334452
assert restored_item.tool_origin is None
453+
454+
455+
@pytest.mark.asyncio
456+
async def test_run_state_roundtrip_preserves_tool_origin_on_approval_interruptions() -> None:
457+
agent = Agent(name="approval-origin")
458+
state: RunState[Any, Agent[Any]] = make_run_state(agent)
459+
state._generated_items.append(
460+
ToolApprovalItem(
461+
agent=agent,
462+
raw_item=make_tool_call(call_id="call_approval", name="approval_tool"),
463+
tool_name="approval_tool",
464+
tool_origin=ToolOrigin(type=ToolOriginType.FUNCTION),
465+
)
466+
)
467+
468+
restored = await roundtrip_state(agent, state)
469+
470+
approval_item = _first_item(restored._generated_items, ToolApprovalItem)
471+
assert approval_item.tool_origin == ToolOrigin(type=ToolOriginType.FUNCTION)
472+
473+
474+
@pytest.mark.asyncio
475+
async def test_run_state_from_json_reads_legacy_1_6_approval_without_tool_origin() -> None:
476+
agent = Agent(name="approval-origin-legacy")
477+
state: RunState[Any, Agent[Any]] = make_run_state(agent)
478+
state._generated_items.append(
479+
ToolApprovalItem(
480+
agent=agent,
481+
raw_item=make_tool_call(call_id="call_legacy_approval", name="approval_tool"),
482+
tool_name="approval_tool",
483+
tool_origin=ToolOrigin(type=ToolOriginType.FUNCTION),
484+
)
485+
)
486+
487+
restored = await roundtrip_state(
488+
agent,
489+
state,
490+
mutate_json=lambda data: {
491+
**data,
492+
"$schemaVersion": "1.6",
493+
"generated_items": [
494+
{key: value for key, value in item.items() if key != "tool_origin"}
495+
for item in data["generated_items"]
496+
],
497+
},
498+
)
499+
500+
approval_item = _first_item(restored._generated_items, ToolApprovalItem)
501+
assert approval_item.tool_origin is None

0 commit comments

Comments
 (0)