22
33import json
44import re
5- from typing import Any , Dict , Literal
5+ from typing import Any , Dict , Literal , cast
66
7- from langchain_core .messages import AIMessage , ToolMessage
7+ from langchain_core .messages import AIMessage , AnyMessage , BaseMessage , ToolMessage
88from langgraph .types import Command , interrupt
99from uipath .platform .common import CreateEscalation
1010from uipath .platform .guardrails import (
@@ -72,20 +72,48 @@ def action_node(
7272 async def _node (
7373 state : AgentGuardrailsGraphState ,
7474 ) -> Dict [str , Any ] | Command [Any ]:
75- input = _extract_escalation_content (
76- state , scope , execution_stage , guarded_component_name
77- )
78- escalation_field = _execution_stage_to_escalation_field (execution_stage )
75+ # Validate message count based on execution stage
76+ _validate_message_count (state , execution_stage )
7977
80- data = {
78+ # Build base data dictionary with common fields
79+ data : Dict [str , Any ] = {
8180 "GuardrailName" : guardrail .name ,
8281 "GuardrailDescription" : guardrail .description ,
8382 "Component" : scope .name .lower (),
8483 "ExecutionStage" : _execution_stage_to_string (execution_stage ),
8584 "GuardrailResult" : state .guardrail_validation_result ,
86- escalation_field : input ,
8785 }
8886
87+ # Add stage-specific fields
88+ if execution_stage == ExecutionStage .PRE_EXECUTION :
89+ # PRE_EXECUTION: Only Inputs field from last message
90+ input_content = _extract_escalation_content (
91+ state .messages [- 1 ],
92+ scope ,
93+ execution_stage ,
94+ guarded_component_name ,
95+ )
96+ data ["Inputs" ] = input_content
97+ else : # POST_EXECUTION
98+ # Extract Inputs from second-to-last message using PRE_EXECUTION logic
99+ input_content = _extract_escalation_content (
100+ state .messages [- 2 ],
101+ scope ,
102+ ExecutionStage .PRE_EXECUTION ,
103+ guarded_component_name ,
104+ )
105+
106+ # Extract Outputs from last message using POST_EXECUTION logic
107+ output_content = _extract_escalation_content (
108+ state .messages [- 1 ],
109+ scope ,
110+ execution_stage ,
111+ guarded_component_name ,
112+ )
113+
114+ data ["Inputs" ] = input_content
115+ data ["Outputs" ] = output_content
116+
89117 escalation_result = interrupt (
90118 CreateEscalation (
91119 app_name = self .app_name ,
@@ -114,6 +142,39 @@ async def _node(
114142 return node_name , _node
115143
116144
145+ def _validate_message_count (
146+ state : AgentGuardrailsGraphState ,
147+ execution_stage : ExecutionStage ,
148+ ) -> None :
149+ """Validate that state has the required number of messages for the execution stage.
150+
151+ Args:
152+ state: The current agent graph state.
153+ execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
154+
155+ Raises:
156+ AgentTerminationException: If the state doesn't have enough messages.
157+ """
158+ required_messages = 1 if execution_stage == ExecutionStage .PRE_EXECUTION else 2
159+ actual_messages = len (state .messages )
160+
161+ if actual_messages < required_messages :
162+ stage_name = (
163+ "PRE_EXECUTION"
164+ if execution_stage == ExecutionStage .PRE_EXECUTION
165+ else "POST_EXECUTION"
166+ )
167+ detail = f"{ stage_name } requires at least { required_messages } message{ 's' if required_messages > 1 else '' } in state, but found { actual_messages } ."
168+ if execution_stage == ExecutionStage .POST_EXECUTION :
169+ detail += " Cannot extract Inputs from previous message."
170+
171+ raise AgentTerminationException (
172+ code = UiPathErrorCode .EXECUTION_ERROR ,
173+ title = f"Invalid state for { stage_name } " ,
174+ detail = detail ,
175+ )
176+
177+
117178def _get_node_name (
118179 execution_stage : ExecutionStage , guardrail : BaseGuardrail , scope : GuardrailScope
119180) -> str :
@@ -341,68 +402,58 @@ def _process_tool_escalation_response(
341402
342403
343404def _extract_escalation_content (
344- state : AgentGuardrailsGraphState ,
405+ message : BaseMessage ,
345406 scope : GuardrailScope ,
346407 execution_stage : ExecutionStage ,
347408 guarded_node_name : str ,
348409) -> str | list [str | Dict [str , Any ]]:
349- """Extract escalation content from state based on guardrail scope and execution stage.
410+ """Extract escalation content from a message based on guardrail scope and execution stage.
350411
351412 Args:
352- state : The current agent graph state .
413+ message : The message to extract content from .
353414 scope: The guardrail scope (LLM/AGENT/TOOL).
354415 execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
416+ guarded_node_name: Name of the guarded component.
355417
356418 Returns:
357419 str or list[str | Dict[str, Any]]: For LLM scope, returns JSON string or list with message/tool call content.
358420 For AGENT scope, returns empty string. For TOOL scope, returns JSON string or list with tool-specific content.
359-
360- Raises:
361- AgentTerminationException: If no messages are found in state.
362421 """
363- if not state .messages :
364- raise AgentTerminationException (
365- code = UiPathErrorCode .EXECUTION_ERROR ,
366- title = "Invalid state message" ,
367- detail = "No message found into agent state" ,
368- )
369-
370422 match scope :
371423 case GuardrailScope .LLM :
372- return _extract_llm_escalation_content (state , execution_stage )
424+ return _extract_llm_escalation_content (message , execution_stage )
373425 case GuardrailScope .AGENT :
374- return _extract_agent_escalation_content (state , execution_stage )
426+ return _extract_agent_escalation_content (message , execution_stage )
375427 case GuardrailScope .TOOL :
376428 return _extract_tool_escalation_content (
377- state , execution_stage , guarded_node_name
429+ message , execution_stage , guarded_node_name
378430 )
379431
380432
381433def _extract_llm_escalation_content (
382- state : AgentGuardrailsGraphState , execution_stage : ExecutionStage
434+ message : BaseMessage , execution_stage : ExecutionStage
383435) -> str | list [str | Dict [str , Any ]]:
384436 """Extract escalation content for LLM scope guardrails.
385437
386438 Args:
387- state : The current agent graph state .
439+ message : The message to extract content from .
388440 execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
389441
390442 Returns:
391443 str or list[str | Dict[str, Any]]: For PreExecution, returns JSON string with message content or empty string.
392444 For PostExecution, returns JSON string (array) with tool call content and message content.
393445 Returns empty string if no content found.
394446 """
395- last_message = state .messages [- 1 ]
396447 if execution_stage == ExecutionStage .PRE_EXECUTION :
397- if isinstance (last_message , ToolMessage ):
398- return last_message .content
448+ if isinstance (message , ToolMessage ):
449+ return message .content
399450
400- content = get_message_content (last_message )
451+ content = get_message_content (cast ( AnyMessage , message ) )
401452 return json .dumps (content ) if content else ""
402453
403454 # For AI messages, process tool calls if present
404- if isinstance (last_message , AIMessage ):
405- ai_message : AIMessage = last_message
455+ if isinstance (message , AIMessage ):
456+ ai_message : AIMessage = message
406457
407458 if ai_message .tool_calls :
408459 content_list : list [Dict [str , Any ]] = []
@@ -415,16 +466,16 @@ def _extract_llm_escalation_content(
415466 return json .dumps (content_list )
416467
417468 # Fallback for other message types
418- return get_message_content (last_message )
469+ return get_message_content (cast ( AnyMessage , message ) )
419470
420471
421472def _extract_agent_escalation_content (
422- state : AgentGuardrailsGraphState , execution_stage : ExecutionStage
473+ message : BaseMessage , execution_stage : ExecutionStage
423474) -> str | list [str | Dict [str , Any ]]:
424475 """Extract escalation content for AGENT scope guardrails.
425476
426477 Args:
427- state : The current agent graph state .
478+ message : The message to extract content from .
428479 execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
429480
430481 Returns:
@@ -434,12 +485,12 @@ def _extract_agent_escalation_content(
434485
435486
436487def _extract_tool_escalation_content (
437- state : AgentGuardrailsGraphState , execution_stage : ExecutionStage , tool_name : str
488+ message : BaseMessage , execution_stage : ExecutionStage , tool_name : str
438489) -> str | list [str | Dict [str , Any ]]:
439490 """Extract escalation content for TOOL scope guardrails.
440491
441492 Args:
442- state : The current agent graph state .
493+ message : The message to extract content from .
443494 execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
444495 tool_name: Optional tool name to filter tool calls. If provided, only extracts args for matching tool.
445496
@@ -448,16 +499,15 @@ def _extract_tool_escalation_content(
448499 for the specified tool name, or empty string if not found. For PostExecution, returns string with
449500 tool message content, or empty string if message type doesn't match.
450501 """
451- last_message = state .messages [- 1 ]
452502 if execution_stage == ExecutionStage .PRE_EXECUTION :
453- args = _extract_tool_args_from_message (last_message , tool_name )
503+ args = _extract_tool_args_from_message (cast ( AnyMessage , message ) , tool_name )
454504 if args :
455505 return json .dumps (args )
456506 return ""
457507 else :
458- if not isinstance (last_message , ToolMessage ):
508+ if not isinstance (message , ToolMessage ):
459509 return ""
460- return last_message .content
510+ return message .content
461511
462512
463513def _execution_stage_to_escalation_field (
0 commit comments