Skip to content

Commit a23ecc3

Browse files
fix: update escalation action to extract tool calls
1 parent dd7c1a3 commit a23ecc3

2 files changed

Lines changed: 61 additions & 56 deletions

File tree

src/uipath_langchain/agent/guardrails/actions/escalate_action.py

Lines changed: 46 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -196,39 +196,54 @@ def _process_llm_escalation_response(
196196
if not reviewed_outputs_json:
197197
return {}
198198

199-
content_list = json.loads(reviewed_outputs_json)
200-
if not content_list:
199+
reviewed_tool_calls_list = json.loads(reviewed_outputs_json)
200+
if not reviewed_tool_calls_list:
201201
return {}
202202

203+
# Track if tool calls were successfully processed
204+
tool_calls_processed = False
205+
203206
# For AI messages, process tool calls if present
204207
if isinstance(last_message, AIMessage):
205208
ai_message: AIMessage = last_message
206-
content_index = 0
207209

208-
if ai_message.tool_calls:
210+
if ai_message.tool_calls and isinstance(reviewed_tool_calls_list, list):
209211
tool_calls = list(ai_message.tool_calls)
210-
for tool_call in tool_calls:
211-
args = tool_call["args"]
212+
213+
# Create a name-to-args mapping from reviewed tool call data
214+
reviewed_tool_calls_map = {}
215+
for reviewed_data in reviewed_tool_calls_list:
212216
if (
213-
isinstance(args, dict)
214-
and "content" in args
215-
and args["content"] is not None
217+
isinstance(reviewed_data, dict)
218+
and "name" in reviewed_data
219+
and "args" in reviewed_data
216220
):
217-
if content_index < len(content_list):
218-
updated_content = json.loads(
219-
content_list[content_index]
220-
)
221-
args["content"] = updated_content
222-
tool_call["args"] = args
223-
content_index += 1
224-
ai_message.tool_calls = tool_calls
225-
226-
if len(content_list) > content_index:
227-
ai_message.content = content_list[-1]
228-
else:
229-
# Fallback for other message types
230-
if content_list:
231-
last_message.content = content_list[-1]
221+
reviewed_tool_calls_map[reviewed_data["name"]] = (
222+
reviewed_data["args"]
223+
)
224+
225+
# Update tool calls with reviewed args by matching name
226+
if reviewed_tool_calls_map:
227+
for tool_call in tool_calls:
228+
tool_name = (
229+
tool_call.get("name")
230+
if isinstance(tool_call, dict)
231+
else getattr(tool_call, "name", None)
232+
)
233+
if tool_name and tool_name in reviewed_tool_calls_map:
234+
if isinstance(tool_call, dict):
235+
tool_call["args"] = reviewed_tool_calls_map[
236+
tool_name
237+
]
238+
else:
239+
tool_call.args = reviewed_tool_calls_map[tool_name]
240+
241+
ai_message.tool_calls = tool_calls
242+
tool_calls_processed = True
243+
244+
# Fallback: update message content if tool_calls weren't processed
245+
if not tool_calls_processed:
246+
last_message.content = reviewed_outputs_json
232247

233248
return Command(update={"messages": msgs})
234249
except Exception as e:
@@ -388,23 +403,16 @@ def _extract_llm_escalation_content(
388403
# For AI messages, process tool calls if present
389404
if isinstance(last_message, AIMessage):
390405
ai_message: AIMessage = last_message
391-
content_list: list[str] = []
392406

393407
if ai_message.tool_calls:
408+
content_list: list[Dict[str, Any]] = []
394409
for tool_call in ai_message.tool_calls:
395-
args = tool_call["args"]
396-
if (
397-
isinstance(args, dict)
398-
and "content" in args
399-
and args["content"] is not None
400-
):
401-
content_list.append(json.dumps(args["content"]))
402-
403-
message_content = get_message_content(last_message)
404-
if message_content:
405-
content_list.append(message_content)
406-
407-
return json.dumps(content_list)
410+
tool_call_data = {
411+
"name": tool_call.get("name"),
412+
"args": tool_call.get("args"),
413+
}
414+
content_list.append(tool_call_data)
415+
return json.dumps(content_list)
408416

409417
# Fallback for other message types
410418
return get_message_content(last_message)

tests/agent/guardrails/actions/test_escalate_action.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ async def test_node_post_execution_tool_field(self, mock_interrupt):
242242

243243
# Verify ToolOutputs is used for PostExecution
244244
call_args = mock_interrupt.call_args[0][0]
245-
assert call_args.data["Outputs"] == '["Test response"]'
245+
assert call_args.data["Outputs"] == "Test response"
246246
assert "Inputs" not in call_args.data
247247

248248
@pytest.mark.asyncio
@@ -287,12 +287,13 @@ async def test_post_execution_ai_message_with_tool_calls_extraction(
287287

288288
await node(state)
289289

290-
# Verify interrupt was called with tool calls and content in ToolOutputs
290+
# Verify interrupt was called with tool calls (name and args) in ToolOutputs
291291
call_args = mock_interrupt.call_args[0][0]
292292
tool_outputs = call_args.data["Outputs"]
293293
parsed = json.loads(tool_outputs)
294-
assert len(parsed) == 2 # Tool call content + message content
295-
assert parsed[1] == "AI response"
294+
assert len(parsed) == 1 # Tool call data with name and args
295+
assert parsed[0]["name"] == "test_tool"
296+
assert parsed[0]["args"] == {"content": {"input": "test"}}
296297

297298
@pytest.mark.asyncio
298299
@patch("uipath_langchain.agent.guardrails.actions.escalate_action.interrupt")
@@ -367,10 +368,10 @@ async def test_post_execution_human_message_with_reviewed_outputs(
367368

368369
result = await node(state)
369370

370-
# Verify HumanMessage content was updated (ignores tool calls)
371+
# Verify HumanMessage content was updated with fallback (raw JSON string)
371372
assert isinstance(result, Command)
372373
assert result.update is not None
373-
assert result.update["messages"][0].content == "Updated content"
374+
assert result.update["messages"][0].content == json.dumps(reviewed_content)
374375

375376
@pytest.mark.asyncio
376377
@patch("uipath_langchain.agent.guardrails.actions.escalate_action.interrupt")
@@ -388,12 +389,8 @@ async def test_post_execution_ai_message_with_reviewed_outputs_and_tool_calls(
388389
guardrail.name = "Test Guardrail"
389390
guardrail.description = "Test description"
390391

391-
reviewed_tool_content = {"updated": "tool_content"}
392-
reviewed_message_content = "Updated message"
393-
reviewed_outputs = [
394-
json.dumps(reviewed_tool_content),
395-
reviewed_message_content,
396-
]
392+
reviewed_tool_args = {"updated": "tool_content"}
393+
reviewed_outputs = [{"name": "test_tool", "args": reviewed_tool_args}]
397394
mock_escalation_result = MagicMock()
398395
mock_escalation_result.action = "Approve"
399396
mock_escalation_result.data = {"ReviewedOutputs": json.dumps(reviewed_outputs)}
@@ -420,12 +417,11 @@ async def test_post_execution_ai_message_with_reviewed_outputs_and_tool_calls(
420417

421418
result = await node(state)
422419

423-
# Verify tool calls and message content were updated
420+
# Verify tool calls args were updated by matching name
424421
assert isinstance(result, Command)
425422
assert result.update is not None
426423
updated_message = result.update["messages"][0]
427-
assert updated_message.tool_calls[0]["args"]["content"] == reviewed_tool_content
428-
assert updated_message.content == reviewed_message_content
424+
assert updated_message.tool_calls[0]["args"] == reviewed_tool_args
429425

430426
@pytest.mark.asyncio
431427
@patch("uipath_langchain.agent.guardrails.actions.escalate_action.interrupt")
@@ -1206,7 +1202,7 @@ async def test_extract_llm_content_pre_execution_empty_content(self):
12061202

12071203
@pytest.mark.asyncio
12081204
async def test_extract_llm_content_post_execution_tool_calls_no_content_field(self):
1209-
"""Extract LLM content PostExecution: tool calls without content field are skipped."""
1205+
"""Extract LLM content PostExecution: extracts all tool calls with name and args."""
12101206
from uipath_langchain.agent.guardrails.actions.escalate_action import (
12111207
_extract_llm_escalation_content,
12121208
)
@@ -1227,9 +1223,10 @@ async def test_extract_llm_content_post_execution_tool_calls_no_content_field(se
12271223

12281224
assert isinstance(result, str)
12291225
parsed = json.loads(result)
1230-
# Should only contain message content, not tool call content
1226+
# Should extract tool call data with name and args
12311227
assert len(parsed) == 1
1232-
assert parsed[0] == "Response"
1228+
assert parsed[0]["name"] == "tool_without_content"
1229+
assert parsed[0]["args"] == {"param": "value"}
12331230

12341231
@pytest.mark.asyncio
12351232
async def test_extract_escalation_content_empty_messages_raises_exception(self):

0 commit comments

Comments
 (0)