Skip to content

Commit b3c2638

Browse files
fix: sent input on pre execution
1 parent d4f204a commit b3c2638

3 files changed

Lines changed: 321 additions & 84 deletions

File tree

src/uipath_langchain/agent/guardrails/actions/escalate_action.py

Lines changed: 91 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
import json
44
import re
5-
from typing import Any, Dict, Literal
5+
from typing import Any, Dict, Literal, cast
66

7-
from langchain_core.messages import AIMessage, ToolMessage
7+
from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, ToolMessage
88
from langgraph.types import Command, interrupt
99
from uipath.platform.common import CreateEscalation
1010
from uipath.platform.guardrails import (
@@ -72,20 +72,48 @@ def action_node(
7272
async def _node(
7373
state: AgentGuardrailsGraphState,
7474
) -> Dict[str, Any] | Command[Any]:
75-
input = _extract_escalation_content(
76-
state, scope, execution_stage, guarded_component_name
77-
)
78-
escalation_field = _execution_stage_to_escalation_field(execution_stage)
75+
# Validate message count based on execution stage
76+
_validate_message_count(state, execution_stage)
7977

80-
data = {
78+
# Build base data dictionary with common fields
79+
data: Dict[str, Any] = {
8180
"GuardrailName": guardrail.name,
8281
"GuardrailDescription": guardrail.description,
8382
"Component": scope.name.lower(),
8483
"ExecutionStage": _execution_stage_to_string(execution_stage),
8584
"GuardrailResult": state.guardrail_validation_result,
86-
escalation_field: input,
8785
}
8886

87+
# Add stage-specific fields
88+
if execution_stage == ExecutionStage.PRE_EXECUTION:
89+
# PRE_EXECUTION: Only Inputs field from last message
90+
input_content = _extract_escalation_content(
91+
state.messages[-1],
92+
scope,
93+
execution_stage,
94+
guarded_component_name,
95+
)
96+
data["Inputs"] = input_content
97+
else: # POST_EXECUTION
98+
# Extract Inputs from second-to-last message using PRE_EXECUTION logic
99+
input_content = _extract_escalation_content(
100+
state.messages[-2],
101+
scope,
102+
ExecutionStage.PRE_EXECUTION,
103+
guarded_component_name,
104+
)
105+
106+
# Extract Outputs from last message using POST_EXECUTION logic
107+
output_content = _extract_escalation_content(
108+
state.messages[-1],
109+
scope,
110+
execution_stage,
111+
guarded_component_name,
112+
)
113+
114+
data["Inputs"] = input_content
115+
data["Outputs"] = output_content
116+
89117
escalation_result = interrupt(
90118
CreateEscalation(
91119
app_name=self.app_name,
@@ -114,6 +142,39 @@ async def _node(
114142
return node_name, _node
115143

116144

145+
def _validate_message_count(
146+
state: AgentGuardrailsGraphState,
147+
execution_stage: ExecutionStage,
148+
) -> None:
149+
"""Validate that state has the required number of messages for the execution stage.
150+
151+
Args:
152+
state: The current agent graph state.
153+
execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
154+
155+
Raises:
156+
AgentTerminationException: If the state doesn't have enough messages.
157+
"""
158+
required_messages = 1 if execution_stage == ExecutionStage.PRE_EXECUTION else 2
159+
actual_messages = len(state.messages)
160+
161+
if actual_messages < required_messages:
162+
stage_name = (
163+
"PRE_EXECUTION"
164+
if execution_stage == ExecutionStage.PRE_EXECUTION
165+
else "POST_EXECUTION"
166+
)
167+
detail = f"{stage_name} requires at least {required_messages} message{'s' if required_messages > 1 else ''} in state, but found {actual_messages}."
168+
if execution_stage == ExecutionStage.POST_EXECUTION:
169+
detail += " Cannot extract Inputs from previous message."
170+
171+
raise AgentTerminationException(
172+
code=UiPathErrorCode.EXECUTION_ERROR,
173+
title=f"Invalid state for {stage_name}",
174+
detail=detail,
175+
)
176+
177+
117178
def _get_node_name(
118179
execution_stage: ExecutionStage, guardrail: BaseGuardrail, scope: GuardrailScope
119180
) -> str:
@@ -341,68 +402,58 @@ def _process_tool_escalation_response(
341402

342403

343404
def _extract_escalation_content(
344-
state: AgentGuardrailsGraphState,
405+
message: BaseMessage,
345406
scope: GuardrailScope,
346407
execution_stage: ExecutionStage,
347408
guarded_node_name: str,
348409
) -> str | list[str | Dict[str, Any]]:
349-
"""Extract escalation content from state based on guardrail scope and execution stage.
410+
"""Extract escalation content from a message based on guardrail scope and execution stage.
350411
351412
Args:
352-
state: The current agent graph state.
413+
message: The message to extract content from.
353414
scope: The guardrail scope (LLM/AGENT/TOOL).
354415
execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
416+
guarded_node_name: Name of the guarded component.
355417
356418
Returns:
357419
str or list[str | Dict[str, Any]]: For LLM scope, returns JSON string or list with message/tool call content.
358420
For AGENT scope, returns empty string. For TOOL scope, returns JSON string or list with tool-specific content.
359-
360-
Raises:
361-
AgentTerminationException: If no messages are found in state.
362421
"""
363-
if not state.messages:
364-
raise AgentTerminationException(
365-
code=UiPathErrorCode.EXECUTION_ERROR,
366-
title="Invalid state message",
367-
detail="No message found into agent state",
368-
)
369-
370422
match scope:
371423
case GuardrailScope.LLM:
372-
return _extract_llm_escalation_content(state, execution_stage)
424+
return _extract_llm_escalation_content(message, execution_stage)
373425
case GuardrailScope.AGENT:
374-
return _extract_agent_escalation_content(state, execution_stage)
426+
return _extract_agent_escalation_content(message, execution_stage)
375427
case GuardrailScope.TOOL:
376428
return _extract_tool_escalation_content(
377-
state, execution_stage, guarded_node_name
429+
message, execution_stage, guarded_node_name
378430
)
379431

380432

381433
def _extract_llm_escalation_content(
382-
state: AgentGuardrailsGraphState, execution_stage: ExecutionStage
434+
message: BaseMessage, execution_stage: ExecutionStage
383435
) -> str | list[str | Dict[str, Any]]:
384436
"""Extract escalation content for LLM scope guardrails.
385437
386438
Args:
387-
state: The current agent graph state.
439+
message: The message to extract content from.
388440
execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
389441
390442
Returns:
391443
str or list[str | Dict[str, Any]]: For PreExecution, returns JSON string with message content or empty string.
392444
For PostExecution, returns JSON string (array) with tool call content and message content.
393445
Returns empty string if no content found.
394446
"""
395-
last_message = state.messages[-1]
396447
if execution_stage == ExecutionStage.PRE_EXECUTION:
397-
if isinstance(last_message, ToolMessage):
398-
return last_message.content
448+
if isinstance(message, ToolMessage):
449+
return message.content
399450

400-
content = get_message_content(last_message)
451+
content = get_message_content(cast(AnyMessage, message))
401452
return json.dumps(content) if content else ""
402453

403454
# For AI messages, process tool calls if present
404-
if isinstance(last_message, AIMessage):
405-
ai_message: AIMessage = last_message
455+
if isinstance(message, AIMessage):
456+
ai_message: AIMessage = message
406457

407458
if ai_message.tool_calls:
408459
content_list: list[Dict[str, Any]] = []
@@ -415,16 +466,16 @@ def _extract_llm_escalation_content(
415466
return json.dumps(content_list)
416467

417468
# Fallback for other message types
418-
return get_message_content(last_message)
469+
return get_message_content(cast(AnyMessage, message))
419470

420471

421472
def _extract_agent_escalation_content(
422-
state: AgentGuardrailsGraphState, execution_stage: ExecutionStage
473+
message: BaseMessage, execution_stage: ExecutionStage
423474
) -> str | list[str | Dict[str, Any]]:
424475
"""Extract escalation content for AGENT scope guardrails.
425476
426477
Args:
427-
state: The current agent graph state.
478+
message: The message to extract content from.
428479
execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
429480
430481
Returns:
@@ -434,12 +485,12 @@ def _extract_agent_escalation_content(
434485

435486

436487
def _extract_tool_escalation_content(
437-
state: AgentGuardrailsGraphState, execution_stage: ExecutionStage, tool_name: str
488+
message: BaseMessage, execution_stage: ExecutionStage, tool_name: str
438489
) -> str | list[str | Dict[str, Any]]:
439490
"""Extract escalation content for TOOL scope guardrails.
440491
441492
Args:
442-
state: The current agent graph state.
493+
message: The message to extract content from.
443494
execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
444495
tool_name: Optional tool name to filter tool calls. If provided, only extracts args for matching tool.
445496
@@ -448,16 +499,15 @@ def _extract_tool_escalation_content(
448499
for the specified tool name, or empty string if not found. For PostExecution, returns string with
449500
tool message content, or empty string if message type doesn't match.
450501
"""
451-
last_message = state.messages[-1]
452502
if execution_stage == ExecutionStage.PRE_EXECUTION:
453-
args = _extract_tool_args_from_message(last_message, tool_name)
503+
args = _extract_tool_args_from_message(cast(AnyMessage, message), tool_name)
454504
if args:
455505
return json.dumps(args)
456506
return ""
457507
else:
458-
if not isinstance(last_message, ToolMessage):
508+
if not isinstance(message, ToolMessage):
459509
return ""
460-
return last_message.content
510+
return message.content
461511

462512

463513
def _execution_stage_to_escalation_field(

src/uipath_langchain/agent/react/guardrails/guardrails_subgraph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def _create_guardrails_subgraph(
116116
ExecutionStage.POST_EXECUTION,
117117
node_factory,
118118
END,
119-
inner_node,
119+
inner_name,
120120
)
121121
subgraph.add_edge(inner_name, first_post_exec_guardrail_node)
122122
else:

0 commit comments

Comments
 (0)