Skip to content

Commit d4f204a

Browse files
Merge branch 'main' into fix-escalation-tool-calls-extraction
2 parents a23ecc3 + 839f74d commit d4f204a

12 files changed

Lines changed: 688 additions & 67 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "uipath-langchain"
3-
version = "0.1.34"
3+
version = "0.1.35"
44
description = "Python SDK that enables developers to build and deploy LangGraph agents to the UiPath Cloud Platform"
55
readme = { file = "README.md", content-type = "text/markdown" }
66
requires-python = ">=3.11"

src/uipath_langchain/agent/react/guardrails/guardrails_subgraph.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from functools import partial
2-
from typing import Any, Callable, Sequence
2+
from typing import Any, Callable, Mapping, Sequence
33

4+
from langgraph._internal._runnable import RunnableCallable
45
from langgraph.constants import END, START
56
from langgraph.graph import StateGraph
6-
from langgraph.prebuilt import ToolNode
77
from uipath.platform.guardrails import (
88
BaseGuardrail,
99
BuiltInValidatorGuardrail,
@@ -221,13 +221,13 @@ def create_llm_guardrails_subgraph(
221221

222222

223223
def create_tools_guardrails_subgraph(
224-
tool_nodes: dict[str, ToolNode],
224+
tool_nodes: Mapping[str, RunnableCallable],
225225
guardrails: Sequence[tuple[BaseGuardrail, GuardrailAction]] | None,
226-
) -> dict[str, ToolNode]:
226+
) -> dict[str, RunnableCallable]:
227227
"""Create tool nodes with guardrails.
228228
Args:
229229
"""
230-
result: dict[str, ToolNode] = {}
230+
result: dict[str, RunnableCallable] = {}
231231
for tool_name, tool_node in tool_nodes.items():
232232
subgraph = create_tool_guardrails_subgraph(
233233
(tool_name, tool_node),

src/uipath_langchain/agent/tools/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .tool_factory import (
88
create_tools_from_resources,
99
)
10-
from .tool_node import create_tool_node
10+
from .tool_node import ToolWrapperMixin, UiPathToolNode, create_tool_node
1111

1212
__all__ = [
1313
"create_tools_from_resources",
@@ -16,4 +16,6 @@
1616
"create_process_tool",
1717
"create_integration_tool",
1818
"create_mcp_tools",
19+
"UiPathToolNode",
20+
"ToolWrapperMixin",
1921
]

src/uipath_langchain/agent/tools/integration_tool.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,21 @@
44
from typing import Any
55

66
from jsonschema_pydantic_converter import transform as create_model
7-
from langchain.tools import ToolRuntime
87
from langchain_core.tools import StructuredTool
98
from uipath.agent.models.agent import AgentIntegrationToolResourceConfig
109
from uipath.eval.mocks import mockable
1110
from uipath.platform import UiPath
1211
from uipath.platform.connections import ActivityMetadata, ActivityParameterLocationInfo
1312

14-
from .static_args import handle_static_args
13+
from uipath_langchain.agent.tools.tool_node import ToolWrapperMixin
14+
from uipath_langchain.agent.wrappers.static_args_wrapper import get_static_args_wrapper
15+
1516
from .structured_tool_with_output_type import StructuredToolWithOutputType
16-
from .utils import sanitize_tool_name
17+
from .utils import sanitize_dict_for_serialization, sanitize_tool_name
18+
19+
20+
class StructuredToolWithStaticArgs(StructuredToolWithOutputType, ToolWrapperMixin):
21+
pass
1722

1823

1924
def remove_asterisk_from_properties(fields: dict[str, Any]) -> dict[str, Any]:
@@ -150,32 +155,27 @@ def create_integration_tool(
150155
input_schema=input_model.model_json_schema(),
151156
output_schema=output_model.model_json_schema(),
152157
)
153-
async def integration_tool_fn(runtime: ToolRuntime, **kwargs: Any):
158+
async def integration_tool_fn(**kwargs: Any):
154159
try:
155-
# we manually validating here and not passing input_model to StructuredTool
156-
# because langchain itself will block their own injected arguments (like runtime) if the model is strict
157-
val_args = input_model.model_validate(kwargs)
158-
args = handle_static_args(
159-
resource=resource,
160-
runtime=runtime,
161-
input_args=val_args.model_dump(),
162-
)
163160
result = await sdk.connections.invoke_activity_async(
164161
activity_metadata=activity_metadata,
165162
connection_id=connection_id,
166-
activity_input=args,
163+
activity_input=sanitize_dict_for_serialization(kwargs),
167164
)
168165
except Exception:
169166
raise
170167

171168
return result
172169

173-
tool = StructuredToolWithOutputType(
170+
wrapper = get_static_args_wrapper(resource)
171+
172+
tool = StructuredToolWithStaticArgs(
174173
name=tool_name,
175174
description=resource.description,
176-
args_schema=resource.input_schema,
175+
args_schema=input_model,
177176
coroutine=integration_tool_fn,
178177
output_type=output_model,
179178
)
179+
tool.set_tool_wrappers(awrapper=wrapper)
180180

181181
return tool
Lines changed: 18 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,27 @@
11
"""Handles static arguments for tool calls."""
22

3-
from typing import Any, Dict, List
3+
from typing import Any
44

55
from jsonpath_ng import parse # type: ignore[import-untyped]
6-
from langchain.tools import ToolRuntime
6+
from pydantic import BaseModel
77
from uipath.agent.models.agent import (
88
AgentIntegrationToolParameter,
99
AgentIntegrationToolResourceConfig,
1010
BaseAgentResourceConfig,
1111
)
1212

13+
from .utils import sanitize_dict_for_serialization
14+
1315

1416
def resolve_static_args(
1517
resource: BaseAgentResourceConfig,
16-
agent_input: Dict[str, Any],
17-
) -> Dict[str, Any]:
18+
agent_input: dict[str, Any],
19+
) -> dict[str, Any]:
1820
"""Resolves static arguments for a given resource with a given input.
1921
2022
Args:
2123
resource: The agent resource configuration.
22-
input: Othe input arguments passed to the agent.
24+
input: The input arguments passed to the agent.
2325
2426
Returns:
2527
A dictionary of expanded arguments to be used in the tool call.
@@ -35,9 +37,9 @@ def resolve_static_args(
3537

3638

3739
def resolve_integration_static_args(
38-
parameters: List[AgentIntegrationToolParameter],
39-
agent_input: Dict[str, Any],
40-
) -> Dict[str, Any]:
40+
parameters: list[AgentIntegrationToolParameter],
41+
agent_input: dict[str, Any],
42+
) -> dict[str, Any]:
4143
"""Resolves static arguments for an integration tool resource.
4244
4345
Args:
@@ -48,7 +50,7 @@ def resolve_integration_static_args(
4850
A dictionary of expanded static arguments for the integration tool.
4951
"""
5052

51-
static_args: Dict[str, Any] = {}
53+
static_args: dict[str, Any] = {}
5254
for param in parameters:
5355
value = None
5456

@@ -75,34 +77,10 @@ def resolve_integration_static_args(
7577
return static_args
7678

7779

78-
def sanitize_for_serialization(args: Dict[str, Any]) -> Dict[str, Any]:
79-
"""Convert Pydantic models in args to dicts."""
80-
converted_args: Dict[str, Any] = {}
81-
for key, value in args.items():
82-
# handle Pydantic model
83-
if hasattr(value, "model_dump"):
84-
converted_args[key] = value.model_dump()
85-
86-
elif isinstance(value, list):
87-
# handle list of Pydantic models
88-
converted_list = []
89-
for item in value:
90-
if hasattr(item, "model_dump"):
91-
converted_list.append(item.model_dump())
92-
else:
93-
converted_list.append(item)
94-
converted_args[key] = converted_list
95-
96-
# handle regular value or unexpected type
97-
else:
98-
converted_args[key] = value
99-
return converted_args
100-
101-
10280
def apply_static_args(
103-
static_args: Dict[str, Any],
104-
kwargs: Dict[str, Any],
105-
) -> Dict[str, Any]:
81+
static_args: dict[str, Any],
82+
kwargs: dict[str, Any],
83+
) -> dict[str, Any]:
10684
"""Applies static arguments to the given input arguments.
10785
10886
Args:
@@ -113,7 +91,7 @@ def apply_static_args(
11391
Merged input arguments with static arguments applied.
11492
"""
11593

116-
sanitized_args = sanitize_for_serialization(kwargs)
94+
sanitized_args = sanitize_dict_for_serialization(kwargs)
11795
for json_path, value in static_args.items():
11896
expr = parse(json_path)
11997
expr.update_or_create(sanitized_args, value)
@@ -122,8 +100,8 @@ def apply_static_args(
122100

123101

124102
def handle_static_args(
125-
resource: BaseAgentResourceConfig, runtime: ToolRuntime, input_args: Dict[str, Any]
126-
) -> Dict[str, Any]:
103+
resource: BaseAgentResourceConfig, state: BaseModel, input_args: dict[str, Any]
104+
) -> dict[str, Any]:
127105
"""Resolves and applies static arguments for a tool call.
128106
Args:
129107
resource: The agent resource configuration.
@@ -133,6 +111,6 @@ def handle_static_args(
133111
A dictionary of input arguments with static arguments applied.
134112
"""
135113

136-
static_args = resolve_static_args(resource, dict(runtime.state))
114+
static_args = resolve_static_args(resource, dict(state))
137115
merged_args = apply_static_args(static_args, input_args)
138116
return merged_args
Lines changed: 132 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,150 @@
11
"""Tool node factory wiring directly to LangGraph's ToolNode."""
22

33
from collections.abc import Sequence
4+
from inspect import signature
5+
from typing import Any, Awaitable, Callable, Literal
46

7+
from langchain_core.messages.ai import AIMessage
8+
from langchain_core.messages.tool import ToolCall, ToolMessage
59
from langchain_core.tools import BaseTool
6-
from langgraph.prebuilt import ToolNode
10+
from langgraph._internal._runnable import RunnableCallable
11+
from langgraph.types import Command
12+
from pydantic import BaseModel
713

14+
# the type safety can be improved with generics
15+
ToolWrapperType = Callable[
16+
[BaseTool, ToolCall, Any], dict[str, Any] | Command[Any] | None
17+
]
18+
AsyncToolWrapperType = Callable[
19+
[BaseTool, ToolCall, Any],
20+
Awaitable[dict[str, Any] | Command[Any] | None],
21+
]
22+
OutputType = dict[Literal["messages"], list[ToolMessage]] | Command[Any] | None
823

9-
def create_tool_node(tools: Sequence[BaseTool]) -> dict[str, ToolNode]:
24+
25+
class UiPathToolNode(RunnableCallable):
26+
"""
27+
A ToolNode that can be used in a React agent graph.
28+
It extracts the tool call from the state messages and invokes the tool.
29+
It supports optional synchronous and asynchronous wrappers for custom processing.
30+
Generic over the state model.
31+
Args:
32+
tool: The tool to invoke.
33+
wrapper: An optional synchronous wrapper for custom processing.
34+
awrapper: An optional asynchronous wrapper for custom processing.
35+
36+
Returns:
37+
A dict with ToolMessage or a Command.
38+
"""
39+
40+
def __init__(
41+
self,
42+
tool: BaseTool,
43+
wrapper: ToolWrapperType | None = None,
44+
awrapper: AsyncToolWrapperType | None = None,
45+
):
46+
super().__init__(func=self._func, afunc=self._afunc, name=tool.name)
47+
self.tool = tool
48+
self.wrapper = wrapper
49+
self.awrapper = awrapper
50+
51+
def _func(self, state: Any) -> OutputType:
52+
call = self._extract_tool_call(state)
53+
if call is None:
54+
return None
55+
if self.wrapper:
56+
filtered_state = self._filter_state(state, self.wrapper)
57+
result = self.wrapper(self.tool, call, filtered_state)
58+
else:
59+
result = self.tool.invoke(call["args"])
60+
61+
return self._process_result(call, result)
62+
63+
async def _afunc(self, state: Any) -> OutputType:
64+
call = self._extract_tool_call(state)
65+
if call is None:
66+
return None
67+
if self.awrapper:
68+
filtered_state = self._filter_state(state, self.awrapper)
69+
result = await self.awrapper(self.tool, call, filtered_state)
70+
else:
71+
result = await self.tool.ainvoke(call["args"])
72+
73+
return self._process_result(call, result)
74+
75+
def _extract_tool_call(self, state: Any) -> ToolCall | None:
76+
"""Extract the tool call from the state messages."""
77+
78+
if not hasattr(state, "messages"):
79+
raise ValueError("State does not have messages key")
80+
81+
last_message = state.messages[-1]
82+
if not isinstance(last_message, AIMessage):
83+
raise ValueError("Last message in message stack is not an AIMessage.")
84+
85+
for tool_call in last_message.tool_calls:
86+
if tool_call["name"] == self.tool.name:
87+
return tool_call
88+
return None
89+
90+
def _process_result(
91+
self, call: ToolCall, result: dict[str, Any] | Command[Any] | None
92+
) -> OutputType:
93+
"""Process the tool result into a message format or return a Command."""
94+
if isinstance(result, Command):
95+
return result
96+
else:
97+
message = ToolMessage(
98+
content=str(result), name=call["name"], tool_call_id=call["id"]
99+
)
100+
return {"messages": [message]}
101+
102+
def _filter_state(
103+
self, state: Any, wrapper: ToolWrapperType | AsyncToolWrapperType
104+
) -> BaseModel:
105+
"""Filter the state to the expected model type."""
106+
model_type = list(signature(wrapper).parameters.values())[2].annotation
107+
if not issubclass(model_type, BaseModel):
108+
raise ValueError(
109+
"Wrapper state parameter must be a pydantic BaseModel subclass."
110+
)
111+
return model_type.model_validate(state, from_attributes=True)
112+
113+
114+
class ToolWrapperMixin:
115+
wrapper: ToolWrapperType | None = None
116+
awrapper: AsyncToolWrapperType | None = None
117+
118+
def set_tool_wrappers(
119+
self,
120+
wrapper: ToolWrapperType | None = None,
121+
awrapper: AsyncToolWrapperType | None = None,
122+
) -> None:
123+
"""Define wrappers for the tool execution."""
124+
self.wrapper = wrapper
125+
self.awrapper = awrapper
126+
127+
128+
def create_tool_node(tools: Sequence[BaseTool]) -> dict[str, UiPathToolNode]:
10129
"""Create individual ToolNode for each tool.
11130
12131
Args:
13132
tools: Sequence of tools to create nodes for.
133+
agentState: The type of the agent state model.
14134
15135
Returns:
16-
Dict mapping tool.name -> ToolNode([tool]).
136+
Dict mapping tool.name -> ReactToolNode([tool]).
17137
Each tool gets its own dedicated node for middleware composition.
18138
19139
Note:
20140
handle_tool_errors=False delegates error handling to LangGraph's error boundary.
21141
"""
22-
return {tool.name: ToolNode([tool], handle_tool_errors=False) for tool in tools}
142+
dict_mapping: dict[str, UiPathToolNode] = {}
143+
for tool in tools:
144+
if isinstance(tool, ToolWrapperMixin):
145+
dict_mapping[tool.name] = UiPathToolNode(
146+
tool, wrapper=tool.wrapper, awrapper=tool.awrapper
147+
)
148+
else:
149+
dict_mapping[tool.name] = UiPathToolNode(tool, wrapper=None, awrapper=None)
150+
return dict_mapping

0 commit comments

Comments
 (0)