|
25 | 25 | from splunklib.ai import Agent |
26 | 26 | from splunklib.ai.engines.langchain import LOCAL_TOOL_PREFIX |
27 | 27 | from splunklib.ai.messages import ( |
| 28 | + AIMessage, |
28 | 29 | HumanMessage, |
| 30 | + ToolCall, |
29 | 31 | ToolFailureResult, |
30 | 32 | ToolMessage, |
31 | 33 | ToolResult, |
32 | 34 | ) |
| 35 | +from splunklib.ai.middleware import ( |
| 36 | + ModelMiddlewareHandler, |
| 37 | + ModelRequest, |
| 38 | + ModelResponse, |
| 39 | + model_middleware, |
| 40 | +) |
33 | 41 | from splunklib.ai.tool_settings import ( |
34 | 42 | LocalToolSettings, |
35 | 43 | RemoteToolSettings, |
36 | 44 | ToolAllowlist, |
37 | 45 | ToolSettings, |
38 | 46 | ) |
39 | 47 | from splunklib.ai.tools import ( |
| 48 | + ToolType, |
40 | 49 | _get_splunk_username, # pyright: ignore[reportPrivateUsage] |
41 | 50 | locate_app, |
42 | 51 | ) |
@@ -589,6 +598,67 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, Any]: |
589 | 598 | response = result.final_message.content |
590 | 599 | assert "31.5" in response, "Invalid LLM response" |
591 | 600 |
|
| 601 | + @patch( |
| 602 | + "splunklib.ai.agent._testing_local_tools_path", |
| 603 | + os.path.join(os.path.dirname(__file__), "testdata", "temperature_as_dict.py"), |
| 604 | + ) |
| 605 | + @patch("splunklib.ai.agent._testing_app_id", "app_id") |
| 606 | + @pytest.mark.asyncio |
| 607 | + async def test_supports_plain_dicts_as_tool_outputs(self) -> None: |
| 608 | + """Regression test for DVPL-13022""" |
| 609 | + pytest.importorskip("langchain_openai") |
| 610 | + |
| 611 | + messages: list[AIMessage] = [ |
| 612 | + AIMessage( |
| 613 | + content="", |
| 614 | + calls=[ |
| 615 | + ToolCall( |
| 616 | + name="temperature", |
| 617 | + args={"city": "Krakow"}, |
| 618 | + id="call_hSdIJSuUZOh2IiBsqfrzhA7d", |
| 619 | + type=ToolType.LOCAL, |
| 620 | + ) |
| 621 | + ], |
| 622 | + ), |
| 623 | + AIMessage(content="The temperature in Krakow is 22°C.", calls=[]), |
| 624 | + ] |
| 625 | + |
| 626 | + responses = (m for m in messages) |
| 627 | + |
| 628 | + @model_middleware |
| 629 | + async def middleware( |
| 630 | + req: ModelRequest, handler: ModelMiddlewareHandler |
| 631 | + ) -> ModelResponse: |
| 632 | + return ModelResponse(message=next(responses)) |
| 633 | + |
| 634 | + async with Agent( |
| 635 | + model=(await self.model()), |
| 636 | + system_prompt="You must use the available tools to perform requested operations", |
| 637 | + service=self.service, |
| 638 | + tool_settings=ToolSettings(local=True, remote=None), |
| 639 | + middleware=[middleware], |
| 640 | + ) as agent: |
| 641 | + result = await agent.invoke( |
| 642 | + [ |
| 643 | + HumanMessage( |
| 644 | + content=( |
| 645 | + "What is the weather like today in Krakow? Use the provided tools to check the temperature." |
| 646 | + + "Return a short response, containing the tool response." |
| 647 | + ), |
| 648 | + ) |
| 649 | + ] |
| 650 | + ) |
| 651 | + |
| 652 | + tool_message = next( |
| 653 | + filter(lambda m: m.role == "tool", result.messages), None |
| 654 | + ) |
| 655 | + assert isinstance(tool_message, ToolMessage), "Invalid tool message" |
| 656 | + assert tool_message, "No tool message found in response" |
| 657 | + assert tool_message.name == "temperature", "Invalid tool name" |
| 658 | + |
| 659 | + response = result.final_message.content |
| 660 | + assert "22" in response, "Invalid LLM response" |
| 661 | + |
592 | 662 |
|
593 | 663 | class TestHandlingToolNameCollision(AITestCase): |
594 | 664 | @patch( |
|
0 commit comments