11from __future__ import annotations
22
3+ import ast
34import json
45import re
5- from typing import Any , Dict , Literal
6+ from typing import Any , Dict , Literal , cast
67
7- from langchain_core .messages import AIMessage , ToolMessage
8+ from langchain_core .messages import AIMessage , AnyMessage , BaseMessage , ToolMessage
89from langgraph .types import Command , interrupt
910from uipath .platform .common import CreateEscalation
1011from uipath .platform .guardrails import (
@@ -72,20 +73,48 @@ def action_node(
7273 async def _node (
7374 state : AgentGuardrailsGraphState ,
7475 ) -> Dict [str , Any ] | Command [Any ]:
75- input = _extract_escalation_content (
76- state , scope , execution_stage , guarded_component_name
77- )
78- escalation_field = _execution_stage_to_escalation_field (execution_stage )
76+ # Validate message count based on execution stage
77+ _validate_message_count (state , execution_stage )
7978
80- data = {
79+ # Build base data dictionary with common fields
80+ data : Dict [str , Any ] = {
8181 "GuardrailName" : guardrail .name ,
8282 "GuardrailDescription" : guardrail .description ,
8383 "Component" : scope .name .lower (),
8484 "ExecutionStage" : _execution_stage_to_string (execution_stage ),
8585 "GuardrailResult" : state .guardrail_validation_result ,
86- escalation_field : input ,
8786 }
8887
88+ # Add stage-specific fields
89+ if execution_stage == ExecutionStage .PRE_EXECUTION :
90+ # PRE_EXECUTION: Only Inputs field from last message
91+ input_content = _extract_escalation_content (
92+ state .messages [- 1 ],
93+ scope ,
94+ execution_stage ,
95+ guarded_component_name ,
96+ )
97+ data ["Inputs" ] = input_content
98+ else : # POST_EXECUTION
99+ # Extract Inputs from second-to-last message using PRE_EXECUTION logic
100+ input_content = _extract_escalation_content (
101+ state .messages [- 2 ],
102+ scope ,
103+ ExecutionStage .PRE_EXECUTION ,
104+ guarded_component_name ,
105+ )
106+
107+ # Extract Outputs from last message using POST_EXECUTION logic
108+ output_content = _extract_escalation_content (
109+ state .messages [- 1 ],
110+ scope ,
111+ execution_stage ,
112+ guarded_component_name ,
113+ )
114+
115+ data ["Inputs" ] = input_content
116+ data ["Outputs" ] = output_content
117+
89118 escalation_result = interrupt (
90119 CreateEscalation (
91120 app_name = self .app_name ,
@@ -114,6 +143,39 @@ async def _node(
114143 return node_name , _node
115144
116145
146+ def _validate_message_count (
147+ state : AgentGuardrailsGraphState ,
148+ execution_stage : ExecutionStage ,
149+ ) -> None :
150+ """Validate that state has the required number of messages for the execution stage.
151+
152+ Args:
153+ state: The current agent graph state.
154+ execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
155+
156+ Raises:
157+ AgentTerminationException: If the state doesn't have enough messages.
158+ """
159+ required_messages = 1 if execution_stage == ExecutionStage .PRE_EXECUTION else 2
160+ actual_messages = len (state .messages )
161+
162+ if actual_messages < required_messages :
163+ stage_name = (
164+ "PRE_EXECUTION"
165+ if execution_stage == ExecutionStage .PRE_EXECUTION
166+ else "POST_EXECUTION"
167+ )
168+ detail = f"{ stage_name } requires at least { required_messages } message{ 's' if required_messages > 1 else '' } in state, but found { actual_messages } ."
169+ if execution_stage == ExecutionStage .POST_EXECUTION :
170+ detail += " Cannot extract Inputs from previous message."
171+
172+ raise AgentTerminationException (
173+ code = UiPathErrorCode .EXECUTION_ERROR ,
174+ title = f"Invalid state for { stage_name } " ,
175+ detail = detail ,
176+ )
177+
178+
117179def _get_node_name (
118180 execution_stage : ExecutionStage , guardrail : BaseGuardrail , scope : GuardrailScope
119181) -> str :
@@ -196,39 +258,54 @@ def _process_llm_escalation_response(
196258 if not reviewed_outputs_json :
197259 return {}
198260
199- content_list = json .loads (reviewed_outputs_json )
200- if not content_list :
261+ reviewed_tool_calls_list = json .loads (reviewed_outputs_json )
262+ if not reviewed_tool_calls_list :
201263 return {}
202264
265+ # Track if tool calls were successfully processed
266+ tool_calls_processed = False
267+
203268 # For AI messages, process tool calls if present
204269 if isinstance (last_message , AIMessage ):
205270 ai_message : AIMessage = last_message
206- content_index = 0
207271
208- if ai_message .tool_calls :
272+ if ai_message .tool_calls and isinstance ( reviewed_tool_calls_list , list ) :
209273 tool_calls = list (ai_message .tool_calls )
210- for tool_call in tool_calls :
211- args = tool_call ["args" ]
274+
275+ # Create a name-to-args mapping from reviewed tool call data
276+ reviewed_tool_calls_map = {}
277+ for reviewed_data in reviewed_tool_calls_list :
212278 if (
213- isinstance (args , dict )
214- and "content " in args
215- and args [ "content" ] is not None
279+ isinstance (reviewed_data , dict )
280+ and "name " in reviewed_data
281+ and " args" in reviewed_data
216282 ):
217- if content_index < len (content_list ):
218- updated_content = json .loads (
219- content_list [content_index ]
220- )
221- args ["content" ] = updated_content
222- tool_call ["args" ] = args
223- content_index += 1
224- ai_message .tool_calls = tool_calls
225-
226- if len (content_list ) > content_index :
227- ai_message .content = content_list [- 1 ]
228- else :
229- # Fallback for other message types
230- if content_list :
231- last_message .content = content_list [- 1 ]
283+ reviewed_tool_calls_map [reviewed_data ["name" ]] = (
284+ reviewed_data ["args" ]
285+ )
286+
287+ # Update tool calls with reviewed args by matching name
288+ if reviewed_tool_calls_map :
289+ for tool_call in tool_calls :
290+ tool_name = (
291+ tool_call .get ("name" )
292+ if isinstance (tool_call , dict )
293+ else getattr (tool_call , "name" , None )
294+ )
295+ if tool_name and tool_name in reviewed_tool_calls_map :
296+ if isinstance (tool_call , dict ):
297+ tool_call ["args" ] = reviewed_tool_calls_map [
298+ tool_name
299+ ]
300+ else :
301+ tool_call .args = reviewed_tool_calls_map [tool_name ]
302+
303+ ai_message .tool_calls = tool_calls
304+ tool_calls_processed = True
305+
306+ # Fallback: update message content if tool_calls weren't processed
307+ if not tool_calls_processed :
308+ last_message .content = reviewed_outputs_json
232309
233310 return Command (update = {"messages" : msgs })
234311 except Exception as e :
@@ -326,97 +403,80 @@ def _process_tool_escalation_response(
326403
327404
328405def _extract_escalation_content (
329- state : AgentGuardrailsGraphState ,
406+ message : BaseMessage ,
330407 scope : GuardrailScope ,
331408 execution_stage : ExecutionStage ,
332409 guarded_node_name : str ,
333410) -> str | list [str | Dict [str , Any ]]:
334- """Extract escalation content from state based on guardrail scope and execution stage.
411+ """Extract escalation content from a message based on guardrail scope and execution stage.
335412
336413 Args:
337- state : The current agent graph state .
414+ message : The message to extract content from .
338415 scope: The guardrail scope (LLM/AGENT/TOOL).
339416 execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
417+ guarded_node_name: Name of the guarded component.
340418
341419 Returns:
342420 str or list[str | Dict[str, Any]]: For LLM scope, returns JSON string or list with message/tool call content.
343421 For AGENT scope, returns empty string. For TOOL scope, returns JSON string or list with tool-specific content.
344-
345- Raises:
346- AgentTerminationException: If no messages are found in state.
347422 """
348- if not state .messages :
349- raise AgentTerminationException (
350- code = UiPathErrorCode .EXECUTION_ERROR ,
351- title = "Invalid state message" ,
352- detail = "No message found into agent state" ,
353- )
354-
355423 match scope :
356424 case GuardrailScope .LLM :
357- return _extract_llm_escalation_content (state , execution_stage )
425+ return _extract_llm_escalation_content (message , execution_stage )
358426 case GuardrailScope .AGENT :
359- return _extract_agent_escalation_content (state , execution_stage )
427+ return _extract_agent_escalation_content (message , execution_stage )
360428 case GuardrailScope .TOOL :
361429 return _extract_tool_escalation_content (
362- state , execution_stage , guarded_node_name
430+ message , execution_stage , guarded_node_name
363431 )
364432
365433
366434def _extract_llm_escalation_content (
367- state : AgentGuardrailsGraphState , execution_stage : ExecutionStage
435+ message : BaseMessage , execution_stage : ExecutionStage
368436) -> str | list [str | Dict [str , Any ]]:
369437 """Extract escalation content for LLM scope guardrails.
370438
371439 Args:
372- state : The current agent graph state .
440+ message : The message to extract content from .
373441 execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
374442
375443 Returns:
376444 str or list[str | Dict[str, Any]]: For PreExecution, returns JSON string with message content or empty string.
377445 For PostExecution, returns JSON string (array) with tool call content and message content.
378446 Returns empty string if no content found.
379447 """
380- last_message = state .messages [- 1 ]
381448 if execution_stage == ExecutionStage .PRE_EXECUTION :
382- if isinstance (last_message , ToolMessage ):
383- return last_message .content
449+ if isinstance (message , ToolMessage ):
450+ return message .content
384451
385- content = get_message_content (last_message )
452+ content = get_message_content (cast ( AnyMessage , message ) )
386453 return json .dumps (content ) if content else ""
387454
388455 # For AI messages, process tool calls if present
389- if isinstance (last_message , AIMessage ):
390- ai_message : AIMessage = last_message
391- content_list : list [str ] = []
456+ if isinstance (message , AIMessage ):
457+ ai_message : AIMessage = message
392458
393459 if ai_message .tool_calls :
460+ content_list : list [Dict [str , Any ]] = []
394461 for tool_call in ai_message .tool_calls :
395- args = tool_call ["args" ]
396- if (
397- isinstance (args , dict )
398- and "content" in args
399- and args ["content" ] is not None
400- ):
401- content_list .append (json .dumps (args ["content" ]))
402-
403- message_content = get_message_content (last_message )
404- if message_content :
405- content_list .append (message_content )
406-
407- return json .dumps (content_list )
462+ tool_call_data = {
463+ "name" : tool_call .get ("name" ),
464+ "args" : tool_call .get ("args" ),
465+ }
466+ content_list .append (tool_call_data )
467+ return json .dumps (content_list )
408468
409469 # Fallback for other message types
410- return get_message_content (last_message )
470+ return get_message_content (cast ( AnyMessage , message ) )
411471
412472
413473def _extract_agent_escalation_content (
414- state : AgentGuardrailsGraphState , execution_stage : ExecutionStage
474+ message : BaseMessage , execution_stage : ExecutionStage
415475) -> str | list [str | Dict [str , Any ]]:
416476 """Extract escalation content for AGENT scope guardrails.
417477
418478 Args:
419- state : The current agent graph state .
479+ message : The message to extract content from .
420480 execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
421481
422482 Returns:
@@ -426,12 +486,12 @@ def _extract_agent_escalation_content(
426486
427487
428488def _extract_tool_escalation_content (
429- state : AgentGuardrailsGraphState , execution_stage : ExecutionStage , tool_name : str
489+ message : BaseMessage , execution_stage : ExecutionStage , tool_name : str
430490) -> str | list [str | Dict [str , Any ]]:
431491 """Extract escalation content for TOOL scope guardrails.
432492
433493 Args:
434- state : The current agent graph state .
494+ message : The message to extract content from .
435495 execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
436496 tool_name: Optional tool name to filter tool calls. If provided, only extracts args for matching tool.
437497
@@ -440,16 +500,31 @@ def _extract_tool_escalation_content(
440500 for the specified tool name, or empty string if not found. For PostExecution, returns string with
441501 tool message content, or empty string if message type doesn't match.
442502 """
443- last_message = state .messages [- 1 ]
444503 if execution_stage == ExecutionStage .PRE_EXECUTION :
445- args = _extract_tool_args_from_message (last_message , tool_name )
504+ args = _extract_tool_args_from_message (cast ( AnyMessage , message ) , tool_name )
446505 if args :
447506 return json .dumps (args )
448507 return ""
449508 else :
450- if not isinstance (last_message , ToolMessage ):
509+ if not isinstance (message , ToolMessage ):
451510 return ""
452- return last_message .content
511+ content = message .content
512+
513+ # If content is already dict/list, serialize to JSON
514+ if isinstance (content , (dict , list )):
515+ return json .dumps (content )
516+
517+ # If content is a string that looks like a Python literal, convert to JSON
518+ if isinstance (content , str ):
519+ try :
520+ # Try to parse as Python literal and convert to JSON
521+ parsed_content = ast .literal_eval (content )
522+ return json .dumps (parsed_content )
523+ except (ValueError , SyntaxError ):
524+ # If parsing fails, return as-is
525+ pass
526+
527+ return content
453528
454529
455530def _execution_stage_to_escalation_field (
0 commit comments