Skip to content

Commit e345584

Browse files
authored
fix(tool_registry): Support plain dict as return types of local tools (#141)
- Add reproducer tool that returns a dict - Fix the bug
1 parent fd70297 commit e345584

5 files changed

Lines changed: 128 additions & 2 deletions

File tree

splunklib/ai/registry.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import logging
1818
import string
1919
from collections.abc import Callable, Sequence
20-
from dataclasses import asdict, dataclass
20+
from dataclasses import asdict, dataclass, is_dataclass
2121
from logging import Logger
2222
from typing import (
2323
Any,
@@ -317,8 +317,14 @@ async def _call_tool(
317317
if self._tools_wrapped_result.get(name):
318318
res = _WrappedResult(res)
319319

320+
if is_dataclass(res) and not isinstance(res, type):
321+
res = asdict(res)
322+
323+
if not isinstance(res, dict):
324+
raise AssertionError("invalid type of tool response")
325+
320326
return types.CallToolResult(
321-
structuredContent=asdict(res),
327+
structuredContent=res, # pyright: ignore[reportUnknownArgumentType]
322328
content=[],
323329
)
324330
except BaseExceptionGroup as e:
@@ -354,6 +360,7 @@ def _input_schema(self, func: Callable[_P, _R]) -> dict[str, Any]:
354360

355361
return input_schema
356362

363+
# TODO: figure out how to handle custom classes as output type
357364
def _output_schema(self, func: Callable[_P, _R]) -> tuple[dict[str, Any], bool]:
358365
"""
359366
Generates a output schema for the provided func, if necessary wraps the

tests/integration/ai/test_agent_mcp_tools.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,27 @@
2525
from splunklib.ai import Agent
2626
from splunklib.ai.engines.langchain import LOCAL_TOOL_PREFIX
2727
from splunklib.ai.messages import (
28+
AIMessage,
2829
HumanMessage,
30+
ToolCall,
2931
ToolFailureResult,
3032
ToolMessage,
3133
ToolResult,
3234
)
35+
from splunklib.ai.middleware import (
36+
ModelMiddlewareHandler,
37+
ModelRequest,
38+
ModelResponse,
39+
model_middleware,
40+
)
3341
from splunklib.ai.tool_settings import (
3442
LocalToolSettings,
3543
RemoteToolSettings,
3644
ToolAllowlist,
3745
ToolSettings,
3846
)
3947
from splunklib.ai.tools import (
48+
ToolType,
4049
_get_splunk_username, # pyright: ignore[reportPrivateUsage]
4150
locate_app,
4251
)
@@ -589,6 +598,67 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, Any]:
589598
response = result.final_message.content
590599
assert "31.5" in response, "Invalid LLM response"
591600

601+
@patch(
602+
"splunklib.ai.agent._testing_local_tools_path",
603+
os.path.join(os.path.dirname(__file__), "testdata", "temperature_as_dict.py"),
604+
)
605+
@patch("splunklib.ai.agent._testing_app_id", "app_id")
606+
@pytest.mark.asyncio
607+
async def test_supports_plain_dicts_as_tool_outputs(self) -> None:
608+
"""Regression test for DVPL-13022"""
609+
pytest.importorskip("langchain_openai")
610+
611+
messages: list[AIMessage] = [
612+
AIMessage(
613+
content="",
614+
calls=[
615+
ToolCall(
616+
name="temperature",
617+
args={"city": "Krakow"},
618+
id="call_hSdIJSuUZOh2IiBsqfrzhA7d",
619+
type=ToolType.LOCAL,
620+
)
621+
],
622+
),
623+
AIMessage(content="The temperature in Krakow is 22°C.", calls=[]),
624+
]
625+
626+
responses = (m for m in messages)
627+
628+
@model_middleware
629+
async def middleware(
630+
req: ModelRequest, handler: ModelMiddlewareHandler
631+
) -> ModelResponse:
632+
return ModelResponse(message=next(responses))
633+
634+
async with Agent(
635+
model=(await self.model()),
636+
system_prompt="You must use the available tools to perform requested operations",
637+
service=self.service,
638+
tool_settings=ToolSettings(local=True, remote=None),
639+
middleware=[middleware],
640+
) as agent:
641+
result = await agent.invoke(
642+
[
643+
HumanMessage(
644+
content=(
645+
"What is the weather like today in Krakow? Use the provided tools to check the temperature."
646+
+ "Return a short response, containing the tool response."
647+
),
648+
)
649+
]
650+
)
651+
652+
tool_message = next(
653+
filter(lambda m: m.role == "tool", result.messages), None
654+
)
655+
assert isinstance(tool_message, ToolMessage), "Invalid tool message"
656+
assert tool_message, "No tool message found in response"
657+
assert tool_message.name == "temperature", "Invalid tool name"
658+
659+
response = result.final_message.content
660+
assert "22" in response, "Invalid LLM response"
661+
592662

593663
class TestHandlingToolNameCollision(AITestCase):
594664
@patch(

tests/integration/ai/test_registry.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,21 @@ async def test_tool_hello(self):
118118
self.assertEqual(res.structuredContent, {"result": "Hello Stefan"})
119119

120120

121+
class TestTemperatureAsDictRegistry(TestRegistryTestCase):
122+
async def test_tool_temperature_returning_dict(self):
123+
async with self.connect("temperature_as_dict.py") as session:
124+
res = await session.call_tool(
125+
"temperature",
126+
arguments={"city": "Krakow"},
127+
meta={"splunk": {"service": self.serialized_service.model_dump()}},
128+
)
129+
self.assertEqual(res.isError, False)
130+
self.assertEqual(res.content, [])
131+
self.assertEqual(
132+
res.structuredContent, {"city": "Krakow", "temperature": 22}
133+
)
134+
135+
121136
@dataclass
122137
class Log:
123138
level: LoggingLevel
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from typing import Any
2+
3+
from splunklib.ai.registry import ToolContext, ToolRegistry
4+
5+
registry = ToolRegistry()
6+
7+
8+
@registry.tool(name="temperature", tags=["read"])
9+
def temperature(city: str, _ctx: ToolContext) -> dict[str, Any]:
10+
"""A simple tool that returns a temperature for the city."""
11+
12+
return {"city": city, "temperature": 22}
13+
14+
15+
registry.run()

tests/unit/ai/test_registry_unit.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,25 @@ def structured_tool() -> Output:
6666
"type": "object",
6767
}
6868

69+
def test_output_non_wrapped_dict(self) -> None:
70+
r = ToolRegistry()
71+
72+
@r.tool()
73+
def structured_tool() -> dict[str, Any]:
74+
return {"some": "info"}
75+
76+
tool = r._tools[0]
77+
assert tool.name == "structured_tool"
78+
assert tool.inputSchema == {
79+
"properties": {},
80+
"type": "object",
81+
"additionalProperties": False,
82+
}
83+
assert tool.outputSchema == {
84+
"additionalProperties": True,
85+
"type": "object",
86+
}
87+
6988
def test_output_wrapped(self) -> None:
7089
r = ToolRegistry()
7190

0 commit comments

Comments
 (0)