|
1 | 1 | """Tool node factory wiring directly to LangGraph's ToolNode.""" |
2 | 2 |
|
3 | 3 | from collections.abc import Sequence |
| 4 | +from inspect import signature |
| 5 | +from typing import Any, Awaitable, Callable, Literal |
4 | 6 |
|
| 7 | +from langchain_core.messages.ai import AIMessage |
| 8 | +from langchain_core.messages.tool import ToolCall, ToolMessage |
5 | 9 | from langchain_core.tools import BaseTool |
6 | | -from langgraph.prebuilt import ToolNode |
| 10 | +from langgraph._internal._runnable import RunnableCallable |
| 11 | +from langgraph.types import Command |
| 12 | +from pydantic import BaseModel |
7 | 13 |
|
| 14 | +# the type safety can be improved with generics |
| 15 | +ToolWrapperType = Callable[ |
| 16 | + [BaseTool, ToolCall, Any], dict[str, Any] | Command[Any] | None |
| 17 | +] |
| 18 | +AsyncToolWrapperType = Callable[ |
| 19 | + [BaseTool, ToolCall, Any], |
| 20 | + Awaitable[dict[str, Any] | Command[Any] | None], |
| 21 | +] |
| 22 | +OutputType = dict[Literal["messages"], list[ToolMessage]] | Command[Any] | None |
8 | 23 |
|
9 | | -def create_tool_node(tools: Sequence[BaseTool]) -> dict[str, ToolNode]: |
| 24 | + |
| 25 | +class UiPathToolNode(RunnableCallable): |
| 26 | + """ |
| 27 | + A ToolNode that can be used in a React agent graph. |
| 28 | + It extracts the tool call from the state messages and invokes the tool. |
| 29 | + It supports optional synchronous and asynchronous wrappers for custom processing. |
| 30 | + Generic over the state model. |
| 31 | + Args: |
| 32 | + tool: The tool to invoke. |
| 33 | + wrapper: An optional synchronous wrapper for custom processing. |
| 34 | + awrapper: An optional asynchronous wrapper for custom processing. |
| 35 | +
|
| 36 | + Returns: |
| 37 | + A dict with ToolMessage or a Command. |
| 38 | + """ |
| 39 | + |
| 40 | + def __init__( |
| 41 | + self, |
| 42 | + tool: BaseTool, |
| 43 | + wrapper: ToolWrapperType | None = None, |
| 44 | + awrapper: AsyncToolWrapperType | None = None, |
| 45 | + ): |
| 46 | + super().__init__(func=self._func, afunc=self._afunc, name=tool.name) |
| 47 | + self.tool = tool |
| 48 | + self.wrapper = wrapper |
| 49 | + self.awrapper = awrapper |
| 50 | + |
| 51 | + def _func(self, state: Any) -> OutputType: |
| 52 | + call = self._extract_tool_call(state) |
| 53 | + if call is None: |
| 54 | + return None |
| 55 | + if self.wrapper: |
| 56 | + filtered_state = self._filter_state(state, self.wrapper) |
| 57 | + result = self.wrapper(self.tool, call, filtered_state) |
| 58 | + else: |
| 59 | + result = self.tool.invoke(call["args"]) |
| 60 | + |
| 61 | + return self._process_result(call, result) |
| 62 | + |
| 63 | + async def _afunc(self, state: Any) -> OutputType: |
| 64 | + call = self._extract_tool_call(state) |
| 65 | + if call is None: |
| 66 | + return None |
| 67 | + if self.awrapper: |
| 68 | + filtered_state = self._filter_state(state, self.awrapper) |
| 69 | + result = await self.awrapper(self.tool, call, filtered_state) |
| 70 | + else: |
| 71 | + result = await self.tool.ainvoke(call["args"]) |
| 72 | + |
| 73 | + return self._process_result(call, result) |
| 74 | + |
| 75 | + def _extract_tool_call(self, state: Any) -> ToolCall | None: |
| 76 | + """Extract the tool call from the state messages.""" |
| 77 | + |
| 78 | + if not hasattr(state, "messages"): |
| 79 | + raise ValueError("State does not have messages key") |
| 80 | + |
| 81 | + last_message = state.messages[-1] |
| 82 | + if not isinstance(last_message, AIMessage): |
| 83 | + raise ValueError("Last message in message stack is not an AIMessage.") |
| 84 | + |
| 85 | + for tool_call in last_message.tool_calls: |
| 86 | + if tool_call["name"] == self.tool.name: |
| 87 | + return tool_call |
| 88 | + return None |
| 89 | + |
| 90 | + def _process_result( |
| 91 | + self, call: ToolCall, result: dict[str, Any] | Command[Any] | None |
| 92 | + ) -> OutputType: |
| 93 | + """Process the tool result into a message format or return a Command.""" |
| 94 | + if isinstance(result, Command): |
| 95 | + return result |
| 96 | + else: |
| 97 | + message = ToolMessage( |
| 98 | + content=str(result), name=call["name"], tool_call_id=call["id"] |
| 99 | + ) |
| 100 | + return {"messages": [message]} |
| 101 | + |
| 102 | + def _filter_state( |
| 103 | + self, state: Any, wrapper: ToolWrapperType | AsyncToolWrapperType |
| 104 | + ) -> BaseModel: |
| 105 | + """Filter the state to the expected model type.""" |
| 106 | + model_type = list(signature(wrapper).parameters.values())[2].annotation |
| 107 | + if not issubclass(model_type, BaseModel): |
| 108 | + raise ValueError( |
| 109 | + "Wrapper state parameter must be a pydantic BaseModel subclass." |
| 110 | + ) |
| 111 | + return model_type.model_validate(state, from_attributes=True) |
| 112 | + |
| 113 | + |
| 114 | +class ToolWrapperMixin: |
| 115 | + wrapper: ToolWrapperType | None = None |
| 116 | + awrapper: AsyncToolWrapperType | None = None |
| 117 | + |
| 118 | + def set_tool_wrappers( |
| 119 | + self, |
| 120 | + wrapper: ToolWrapperType | None = None, |
| 121 | + awrapper: AsyncToolWrapperType | None = None, |
| 122 | + ) -> None: |
| 123 | + """Define wrappers for the tool execution.""" |
| 124 | + self.wrapper = wrapper |
| 125 | + self.awrapper = awrapper |
| 126 | + |
| 127 | + |
| 128 | +def create_tool_node(tools: Sequence[BaseTool]) -> dict[str, UiPathToolNode]: |
10 | 129 | """Create individual ToolNode for each tool. |
11 | 130 |
|
12 | 131 | Args: |
13 | 132 | tools: Sequence of tools to create nodes for. |
| 133 | + agentState: The type of the agent state model. |
14 | 134 |
|
15 | 135 | Returns: |
16 | | - Dict mapping tool.name -> ToolNode([tool]). |
| 136 | + Dict mapping tool.name -> ReactToolNode([tool]). |
17 | 137 | Each tool gets its own dedicated node for middleware composition. |
18 | 138 |
|
19 | 139 | Note: |
20 | 140 | handle_tool_errors=False delegates error handling to LangGraph's error boundary. |
21 | 141 | """ |
22 | | - return {tool.name: ToolNode([tool], handle_tool_errors=False) for tool in tools} |
| 142 | + dict_mapping: dict[str, UiPathToolNode] = {} |
| 143 | + for tool in tools: |
| 144 | + if isinstance(tool, ToolWrapperMixin): |
| 145 | + dict_mapping[tool.name] = UiPathToolNode( |
| 146 | + tool, wrapper=tool.wrapper, awrapper=tool.awrapper |
| 147 | + ) |
| 148 | + else: |
| 149 | + dict_mapping[tool.name] = UiPathToolNode(tool, wrapper=None, awrapper=None) |
| 150 | + return dict_mapping |
0 commit comments