Skip to content

Commit f36fd48

Browse files
Merge pull request #366 from UiPath/fix-escalation-tool-calls-extraction
fix: update escalation action to extract tool calls
2 parents b6b23c2 + 8649e3f commit f36fd48

5 files changed

Lines changed: 399 additions & 140 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "uipath-langchain"
3-
version = "0.1.37"
3+
version = "0.1.38"
44
description = "Python SDK that enables developers to build and deploy LangGraph agents to the UiPath Cloud Platform"
55
readme = { file = "README.md", content-type = "text/markdown" }
66
requires-python = ">=3.11"

src/uipath_langchain/agent/guardrails/actions/escalate_action.py

Lines changed: 154 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from __future__ import annotations
22

3+
import ast
34
import json
45
import re
5-
from typing import Any, Dict, Literal
6+
from typing import Any, Dict, Literal, cast
67

7-
from langchain_core.messages import AIMessage, ToolMessage
8+
from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, ToolMessage
89
from langgraph.types import Command, interrupt
910
from uipath.platform.common import CreateEscalation
1011
from uipath.platform.guardrails import (
@@ -72,20 +73,48 @@ def action_node(
7273
async def _node(
7374
state: AgentGuardrailsGraphState,
7475
) -> Dict[str, Any] | Command[Any]:
75-
input = _extract_escalation_content(
76-
state, scope, execution_stage, guarded_component_name
77-
)
78-
escalation_field = _execution_stage_to_escalation_field(execution_stage)
76+
# Validate message count based on execution stage
77+
_validate_message_count(state, execution_stage)
7978

80-
data = {
79+
# Build base data dictionary with common fields
80+
data: Dict[str, Any] = {
8181
"GuardrailName": guardrail.name,
8282
"GuardrailDescription": guardrail.description,
8383
"Component": scope.name.lower(),
8484
"ExecutionStage": _execution_stage_to_string(execution_stage),
8585
"GuardrailResult": state.guardrail_validation_result,
86-
escalation_field: input,
8786
}
8887

88+
# Add stage-specific fields
89+
if execution_stage == ExecutionStage.PRE_EXECUTION:
90+
# PRE_EXECUTION: Only Inputs field from last message
91+
input_content = _extract_escalation_content(
92+
state.messages[-1],
93+
scope,
94+
execution_stage,
95+
guarded_component_name,
96+
)
97+
data["Inputs"] = input_content
98+
else: # POST_EXECUTION
99+
# Extract Inputs from second-to-last message using PRE_EXECUTION logic
100+
input_content = _extract_escalation_content(
101+
state.messages[-2],
102+
scope,
103+
ExecutionStage.PRE_EXECUTION,
104+
guarded_component_name,
105+
)
106+
107+
# Extract Outputs from last message using POST_EXECUTION logic
108+
output_content = _extract_escalation_content(
109+
state.messages[-1],
110+
scope,
111+
execution_stage,
112+
guarded_component_name,
113+
)
114+
115+
data["Inputs"] = input_content
116+
data["Outputs"] = output_content
117+
89118
escalation_result = interrupt(
90119
CreateEscalation(
91120
app_name=self.app_name,
@@ -114,6 +143,39 @@ async def _node(
114143
return node_name, _node
115144

116145

146+
def _validate_message_count(
147+
state: AgentGuardrailsGraphState,
148+
execution_stage: ExecutionStage,
149+
) -> None:
150+
"""Validate that state has the required number of messages for the execution stage.
151+
152+
Args:
153+
state: The current agent graph state.
154+
execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
155+
156+
Raises:
157+
AgentTerminationException: If the state doesn't have enough messages.
158+
"""
159+
required_messages = 1 if execution_stage == ExecutionStage.PRE_EXECUTION else 2
160+
actual_messages = len(state.messages)
161+
162+
if actual_messages < required_messages:
163+
stage_name = (
164+
"PRE_EXECUTION"
165+
if execution_stage == ExecutionStage.PRE_EXECUTION
166+
else "POST_EXECUTION"
167+
)
168+
detail = f"{stage_name} requires at least {required_messages} message{'s' if required_messages > 1 else ''} in state, but found {actual_messages}."
169+
if execution_stage == ExecutionStage.POST_EXECUTION:
170+
detail += " Cannot extract Inputs from previous message."
171+
172+
raise AgentTerminationException(
173+
code=UiPathErrorCode.EXECUTION_ERROR,
174+
title=f"Invalid state for {stage_name}",
175+
detail=detail,
176+
)
177+
178+
117179
def _get_node_name(
118180
execution_stage: ExecutionStage, guardrail: BaseGuardrail, scope: GuardrailScope
119181
) -> str:
@@ -196,39 +258,54 @@ def _process_llm_escalation_response(
196258
if not reviewed_outputs_json:
197259
return {}
198260

199-
content_list = json.loads(reviewed_outputs_json)
200-
if not content_list:
261+
reviewed_tool_calls_list = json.loads(reviewed_outputs_json)
262+
if not reviewed_tool_calls_list:
201263
return {}
202264

265+
# Track if tool calls were successfully processed
266+
tool_calls_processed = False
267+
203268
# For AI messages, process tool calls if present
204269
if isinstance(last_message, AIMessage):
205270
ai_message: AIMessage = last_message
206-
content_index = 0
207271

208-
if ai_message.tool_calls:
272+
if ai_message.tool_calls and isinstance(reviewed_tool_calls_list, list):
209273
tool_calls = list(ai_message.tool_calls)
210-
for tool_call in tool_calls:
211-
args = tool_call["args"]
274+
275+
# Create a name-to-args mapping from reviewed tool call data
276+
reviewed_tool_calls_map = {}
277+
for reviewed_data in reviewed_tool_calls_list:
212278
if (
213-
isinstance(args, dict)
214-
and "content" in args
215-
and args["content"] is not None
279+
isinstance(reviewed_data, dict)
280+
and "name" in reviewed_data
281+
and "args" in reviewed_data
216282
):
217-
if content_index < len(content_list):
218-
updated_content = json.loads(
219-
content_list[content_index]
220-
)
221-
args["content"] = updated_content
222-
tool_call["args"] = args
223-
content_index += 1
224-
ai_message.tool_calls = tool_calls
225-
226-
if len(content_list) > content_index:
227-
ai_message.content = content_list[-1]
228-
else:
229-
# Fallback for other message types
230-
if content_list:
231-
last_message.content = content_list[-1]
283+
reviewed_tool_calls_map[reviewed_data["name"]] = (
284+
reviewed_data["args"]
285+
)
286+
287+
# Update tool calls with reviewed args by matching name
288+
if reviewed_tool_calls_map:
289+
for tool_call in tool_calls:
290+
tool_name = (
291+
tool_call.get("name")
292+
if isinstance(tool_call, dict)
293+
else getattr(tool_call, "name", None)
294+
)
295+
if tool_name and tool_name in reviewed_tool_calls_map:
296+
if isinstance(tool_call, dict):
297+
tool_call["args"] = reviewed_tool_calls_map[
298+
tool_name
299+
]
300+
else:
301+
tool_call.args = reviewed_tool_calls_map[tool_name]
302+
303+
ai_message.tool_calls = tool_calls
304+
tool_calls_processed = True
305+
306+
# Fallback: update message content if tool_calls weren't processed
307+
if not tool_calls_processed:
308+
last_message.content = reviewed_outputs_json
232309

233310
return Command(update={"messages": msgs})
234311
except Exception as e:
@@ -326,97 +403,80 @@ def _process_tool_escalation_response(
326403

327404

328405
def _extract_escalation_content(
329-
state: AgentGuardrailsGraphState,
406+
message: BaseMessage,
330407
scope: GuardrailScope,
331408
execution_stage: ExecutionStage,
332409
guarded_node_name: str,
333410
) -> str | list[str | Dict[str, Any]]:
334-
"""Extract escalation content from state based on guardrail scope and execution stage.
411+
"""Extract escalation content from a message based on guardrail scope and execution stage.
335412
336413
Args:
337-
state: The current agent graph state.
414+
message: The message to extract content from.
338415
scope: The guardrail scope (LLM/AGENT/TOOL).
339416
execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
417+
guarded_node_name: Name of the guarded component.
340418
341419
Returns:
342420
str or list[str | Dict[str, Any]]: For LLM scope, returns JSON string or list with message/tool call content.
343421
For AGENT scope, returns empty string. For TOOL scope, returns JSON string or list with tool-specific content.
344-
345-
Raises:
346-
AgentTerminationException: If no messages are found in state.
347422
"""
348-
if not state.messages:
349-
raise AgentTerminationException(
350-
code=UiPathErrorCode.EXECUTION_ERROR,
351-
title="Invalid state message",
352-
detail="No message found into agent state",
353-
)
354-
355423
match scope:
356424
case GuardrailScope.LLM:
357-
return _extract_llm_escalation_content(state, execution_stage)
425+
return _extract_llm_escalation_content(message, execution_stage)
358426
case GuardrailScope.AGENT:
359-
return _extract_agent_escalation_content(state, execution_stage)
427+
return _extract_agent_escalation_content(message, execution_stage)
360428
case GuardrailScope.TOOL:
361429
return _extract_tool_escalation_content(
362-
state, execution_stage, guarded_node_name
430+
message, execution_stage, guarded_node_name
363431
)
364432

365433

366434
def _extract_llm_escalation_content(
367-
state: AgentGuardrailsGraphState, execution_stage: ExecutionStage
435+
message: BaseMessage, execution_stage: ExecutionStage
368436
) -> str | list[str | Dict[str, Any]]:
369437
"""Extract escalation content for LLM scope guardrails.
370438
371439
Args:
372-
state: The current agent graph state.
440+
message: The message to extract content from.
373441
execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
374442
375443
Returns:
376444
str or list[str | Dict[str, Any]]: For PreExecution, returns JSON string with message content or empty string.
377445
For PostExecution, returns JSON string (array) with tool call content and message content.
378446
Returns empty string if no content found.
379447
"""
380-
last_message = state.messages[-1]
381448
if execution_stage == ExecutionStage.PRE_EXECUTION:
382-
if isinstance(last_message, ToolMessage):
383-
return last_message.content
449+
if isinstance(message, ToolMessage):
450+
return message.content
384451

385-
content = get_message_content(last_message)
452+
content = get_message_content(cast(AnyMessage, message))
386453
return json.dumps(content) if content else ""
387454

388455
# For AI messages, process tool calls if present
389-
if isinstance(last_message, AIMessage):
390-
ai_message: AIMessage = last_message
391-
content_list: list[str] = []
456+
if isinstance(message, AIMessage):
457+
ai_message: AIMessage = message
392458

393459
if ai_message.tool_calls:
460+
content_list: list[Dict[str, Any]] = []
394461
for tool_call in ai_message.tool_calls:
395-
args = tool_call["args"]
396-
if (
397-
isinstance(args, dict)
398-
and "content" in args
399-
and args["content"] is not None
400-
):
401-
content_list.append(json.dumps(args["content"]))
402-
403-
message_content = get_message_content(last_message)
404-
if message_content:
405-
content_list.append(message_content)
406-
407-
return json.dumps(content_list)
462+
tool_call_data = {
463+
"name": tool_call.get("name"),
464+
"args": tool_call.get("args"),
465+
}
466+
content_list.append(tool_call_data)
467+
return json.dumps(content_list)
408468

409469
# Fallback for other message types
410-
return get_message_content(last_message)
470+
return get_message_content(cast(AnyMessage, message))
411471

412472

413473
def _extract_agent_escalation_content(
414-
state: AgentGuardrailsGraphState, execution_stage: ExecutionStage
474+
message: BaseMessage, execution_stage: ExecutionStage
415475
) -> str | list[str | Dict[str, Any]]:
416476
"""Extract escalation content for AGENT scope guardrails.
417477
418478
Args:
419-
state: The current agent graph state.
479+
message: The message to extract content from.
420480
execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
421481
422482
Returns:
@@ -426,12 +486,12 @@ def _extract_agent_escalation_content(
426486

427487

428488
def _extract_tool_escalation_content(
429-
state: AgentGuardrailsGraphState, execution_stage: ExecutionStage, tool_name: str
489+
message: BaseMessage, execution_stage: ExecutionStage, tool_name: str
430490
) -> str | list[str | Dict[str, Any]]:
431491
"""Extract escalation content for TOOL scope guardrails.
432492
433493
Args:
434-
state: The current agent graph state.
494+
message: The message to extract content from.
435495
execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
436496
tool_name: Optional tool name to filter tool calls. If provided, only extracts args for matching tool.
437497
@@ -440,16 +500,31 @@ def _extract_tool_escalation_content(
440500
for the specified tool name, or empty string if not found. For PostExecution, returns string with
441501
tool message content, or empty string if message type doesn't match.
442502
"""
443-
last_message = state.messages[-1]
444503
if execution_stage == ExecutionStage.PRE_EXECUTION:
445-
args = _extract_tool_args_from_message(last_message, tool_name)
504+
args = _extract_tool_args_from_message(cast(AnyMessage, message), tool_name)
446505
if args:
447506
return json.dumps(args)
448507
return ""
449508
else:
450-
if not isinstance(last_message, ToolMessage):
509+
if not isinstance(message, ToolMessage):
451510
return ""
452-
return last_message.content
511+
content = message.content
512+
513+
# If content is already dict/list, serialize to JSON
514+
if isinstance(content, (dict, list)):
515+
return json.dumps(content)
516+
517+
# If content is a string that looks like a Python literal, convert to JSON
518+
if isinstance(content, str):
519+
try:
520+
# Try to parse as Python literal and convert to JSON
521+
parsed_content = ast.literal_eval(content)
522+
return json.dumps(parsed_content)
523+
except (ValueError, SyntaxError):
524+
# If parsing fails, return as-is
525+
pass
526+
527+
return content
453528

454529

455530
def _execution_stage_to_escalation_field(

src/uipath_langchain/agent/react/guardrails/guardrails_subgraph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def _create_guardrails_subgraph(
117117
ExecutionStage.POST_EXECUTION,
118118
node_factory,
119119
END,
120-
inner_node,
120+
inner_name,
121121
)
122122
subgraph.add_edge(inner_name, first_post_exec_guardrail_node)
123123
else:

0 commit comments

Comments
 (0)