Skip to content

Commit 42d7e5d

Browse files
Merge pull request #347 from UiPath/AL-229-integrate_deterministic_guardrails
feat: integrate deterministic guardrails [AL-229]
2 parents d1e0b46 + 742d229 commit 42d7e5d

8 files changed

Lines changed: 1017 additions & 84 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ description = "Python SDK that enables developers to build and deploy LangGraph
55
readme = { file = "README.md", content-type = "text/markdown" }
66
requires-python = ">=3.11"
77
dependencies = [
8-
"uipath>=2.2.32, <2.3.0",
8+
"uipath>=2.2.35, <2.3.0",
99
"langgraph>=1.0.0, <2.0.0",
1010
"langchain-core>=1.0.0, <2.0.0",
1111
"aiosqlite==0.21.0",

src/uipath_langchain/agent/guardrails/actions/escalate_action.py

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import re
55
from typing import Any, Dict, Literal
66

7-
from langchain_core.messages import AIMessage, ToolCall, ToolMessage
7+
from langchain_core.messages import AIMessage, ToolMessage
88
from langgraph.types import Command, interrupt
99
from uipath.platform.common import CreateEscalation
1010
from uipath.platform.guardrails import (
@@ -16,7 +16,7 @@
1616
from ...exceptions import AgentTerminationException
1717
from ...react.types import AgentGuardrailsGraphState
1818
from ..types import ExecutionStage
19-
from ..utils import get_message_content
19+
from ..utils import _extract_tool_args_from_message, get_message_content
2020
from .base_action import GuardrailAction, GuardrailActionNode
2121

2222

@@ -293,7 +293,11 @@ def _process_tool_escalation_response(
293293
if last_message.tool_calls:
294294
tool_calls = list(last_message.tool_calls)
295295
for tool_call in tool_calls:
296-
call_name = extract_tool_name(tool_call)
296+
call_name = (
297+
tool_call.get("name")
298+
if isinstance(tool_call, dict)
299+
else getattr(tool_call, "name", None)
300+
)
297301
if call_name == tool_name:
298302
# Update args for the matching tool call
299303
if isinstance(reviewed_tool_calls_args, dict):
@@ -438,38 +442,16 @@ def _extract_tool_escalation_content(
438442
"""
439443
last_message = state.messages[-1]
440444
if execution_stage == ExecutionStage.PRE_EXECUTION:
441-
if not isinstance(last_message, AIMessage):
442-
return ""
443-
if not last_message.tool_calls:
444-
return ""
445-
446-
# Find the tool call with matching name
447-
for tool_call in last_message.tool_calls:
448-
call_name = extract_tool_name(tool_call)
449-
if call_name == tool_name:
450-
# Extract args from the matching tool call
451-
args = (
452-
tool_call.get("args")
453-
if isinstance(tool_call, dict)
454-
else getattr(tool_call, "args", None)
455-
)
456-
if args is not None:
457-
return json.dumps(args)
445+
args = _extract_tool_args_from_message(last_message, tool_name)
446+
if args:
447+
return json.dumps(args)
458448
return ""
459449
else:
460450
if not isinstance(last_message, ToolMessage):
461451
return ""
462452
return last_message.content
463453

464454

465-
def extract_tool_name(tool_call: ToolCall) -> Any | None:
466-
return (
467-
tool_call.get("name")
468-
if isinstance(tool_call, dict)
469-
else getattr(tool_call, "name", None)
470-
)
471-
472-
473455
def _execution_stage_to_escalation_field(
474456
execution_stage: ExecutionStage,
475457
) -> str:

src/uipath_langchain/agent/guardrails/guardrail_nodes.py

Lines changed: 139 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,120 @@
33
import re
44
from typing import Any, Callable
55

6-
from langchain_core.messages import AIMessage
76
from langgraph.types import Command
7+
from uipath.core.guardrails import (
8+
DeterministicGuardrail,
9+
DeterministicGuardrailsService,
10+
)
811
from uipath.platform import UiPath
912
from uipath.platform.guardrails import (
1013
BaseGuardrail,
14+
BuiltInValidatorGuardrail,
1115
GuardrailScope,
1216
)
17+
from uipath.runtime.errors import UiPathErrorCode
1318

1419
from uipath_langchain.agent.guardrails.types import ExecutionStage
15-
from uipath_langchain.agent.guardrails.utils import get_message_content
20+
from uipath_langchain.agent.guardrails.utils import (
21+
_extract_tool_input_data,
22+
_extract_tool_output_data,
23+
get_message_content,
24+
)
1625
from uipath_langchain.agent.react.types import AgentGuardrailsGraphState
1726

27+
from ..exceptions import AgentTerminationException
28+
1829
logger = logging.getLogger(__name__)
1930

2031

32+
def _evaluate_deterministic_guardrail(
33+
state: AgentGuardrailsGraphState,
34+
guardrail: DeterministicGuardrail,
35+
execution_stage: ExecutionStage,
36+
input_data_extractor: Callable[[AgentGuardrailsGraphState], dict[str, Any]],
37+
output_data_extractor: Callable[[AgentGuardrailsGraphState], dict[str, Any]] | None,
38+
):
39+
"""Evaluate deterministic guardrail.
40+
41+
Args:
42+
state: The current agent graph state.
43+
guardrail: The deterministic guardrail to evaluate.
44+
execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
45+
input_data_extractor: Function to extract input data from state.
46+
output_data_extractor: Function to extract output data from state (optional).
47+
48+
Returns:
49+
The guardrail evaluation result.
50+
"""
51+
service = DeterministicGuardrailsService()
52+
input_data = input_data_extractor(state)
53+
54+
if execution_stage == ExecutionStage.PRE_EXECUTION:
55+
return service.evaluate_pre_deterministic_guardrail(
56+
input_data=input_data, guardrail=guardrail
57+
)
58+
else: # POST_EXECUTION
59+
output_data = output_data_extractor(state) if output_data_extractor else {}
60+
return service.evaluate_post_deterministic_guardrail(
61+
input_data=input_data,
62+
output_data=output_data,
63+
guardrail=guardrail,
64+
)
65+
66+
67+
def _evaluate_builtin_guardrail(
68+
state: AgentGuardrailsGraphState,
69+
guardrail: BuiltInValidatorGuardrail,
70+
payload_generator: Callable[[AgentGuardrailsGraphState], str],
71+
):
72+
"""Evaluate built-in validator guardrail.
73+
74+
Args:
75+
state: The current agent graph state.
76+
guardrail: The built-in validator guardrail to evaluate.
77+
payload_generator: Function to generate payload text from state.
78+
79+
Returns:
80+
The guardrail evaluation result.
81+
"""
82+
text = payload_generator(state)
83+
uipath = UiPath()
84+
return uipath.guardrails.evaluate_guardrail(text, guardrail)
85+
86+
87+
def _create_validation_command(
88+
result,
89+
success_node: str,
90+
failure_node: str,
91+
) -> Command[Any]:
92+
"""Create command based on validation result.
93+
94+
Args:
95+
result: The guardrail evaluation result.
96+
success_node: Node to route to on validation pass.
97+
failure_node: Node to route to on validation fail.
98+
99+
Returns:
100+
Command to update state and route to appropriate node.
101+
"""
102+
if not result.validation_passed:
103+
return Command(
104+
goto=failure_node, update={"guardrail_validation_result": result.reason}
105+
)
106+
return Command(goto=success_node, update={"guardrail_validation_result": None})
107+
108+
21109
def _create_guardrail_node(
22110
guardrail: BaseGuardrail,
23111
scope: GuardrailScope,
24112
execution_stage: ExecutionStage,
25113
payload_generator: Callable[[AgentGuardrailsGraphState], str],
26114
success_node: str,
27115
failure_node: str,
116+
input_data_extractor: Callable[[AgentGuardrailsGraphState], dict[str, Any]]
117+
| None = None,
118+
output_data_extractor: Callable[[AgentGuardrailsGraphState], dict[str, Any]]
119+
| None = None,
28120
) -> tuple[str, Callable[[AgentGuardrailsGraphState], Any]]:
29121
"""Private factory for guardrail evaluation nodes.
30122
@@ -38,19 +130,41 @@ def _create_guardrail_node(
38130
async def node(
39131
state: AgentGuardrailsGraphState,
40132
):
41-
text = payload_generator(state)
42133
try:
43-
uipath = UiPath()
44-
result = uipath.guardrails.evaluate_guardrail(text, guardrail)
45-
except Exception as exc:
46-
logger.error("Failed to evaluate guardrail: %s", exc)
47-
raise
134+
# Route to appropriate evaluation service based on guardrail type and scope
135+
if (
136+
isinstance(guardrail, DeterministicGuardrail)
137+
and scope == GuardrailScope.TOOL
138+
and input_data_extractor is not None
139+
):
140+
result = _evaluate_deterministic_guardrail(
141+
state,
142+
guardrail,
143+
execution_stage,
144+
input_data_extractor,
145+
output_data_extractor,
146+
)
147+
elif isinstance(guardrail, BuiltInValidatorGuardrail):
148+
result = _evaluate_builtin_guardrail(
149+
state, guardrail, payload_generator
150+
)
151+
else:
152+
raise AgentTerminationException(
153+
code=UiPathErrorCode.EXECUTION_ERROR,
154+
title="Unsupported guardrail type",
155+
detail=f"Guardrail type '{type(guardrail).__name__}' is not supported. "
156+
f"Expected DeterministicGuardrail or BuiltInValidatorGuardrail.",
157+
)
158+
159+
return _create_validation_command(result, success_node, failure_node)
48160

49-
if not result.validation_passed:
50-
return Command(
51-
goto=failure_node, update={"guardrail_validation_result": result.reason}
161+
except Exception as exc:
162+
logger.error(
163+
"Failed to evaluate guardrail '%s': %s",
164+
guardrail.name,
165+
exc,
52166
)
53-
return Command(goto=success_node, update={"guardrail_validation_result": None})
167+
raise
54168

55169
return node_name, node
56170

@@ -149,37 +263,27 @@ def _payload_generator(state: AgentGuardrailsGraphState) -> str:
149263
return ""
150264

151265
if execution_stage == ExecutionStage.PRE_EXECUTION:
152-
if not isinstance(state.messages[-1], AIMessage):
153-
return ""
154-
message = state.messages[-1]
155-
156-
if not message.tool_calls:
157-
return ""
158-
159-
# Find the first tool call with matching name
160-
for tool_call in message.tool_calls:
161-
call_name = (
162-
tool_call.get("name")
163-
if isinstance(tool_call, dict)
164-
else getattr(tool_call, "name", None)
165-
)
166-
if call_name == tool_name:
167-
# Extract args from the tool call
168-
args = (
169-
tool_call.get("args")
170-
if isinstance(tool_call, dict)
171-
else getattr(tool_call, "args", None)
172-
)
173-
if args is not None:
174-
return json.dumps(args)
266+
# Extract tool args as dict and convert to JSON string
267+
args_dict = _extract_tool_input_data(state, tool_name, execution_stage)
268+
if args_dict:
269+
return json.dumps(args_dict)
175270

176271
return get_message_content(state.messages[-1])
177272

273+
# Create closures for input/output data extraction (for deterministic guardrails)
274+
def _input_data_extractor(state: AgentGuardrailsGraphState) -> dict[str, Any]:
275+
return _extract_tool_input_data(state, tool_name, execution_stage)
276+
277+
def _output_data_extractor(state: AgentGuardrailsGraphState) -> dict[str, Any]:
278+
return _extract_tool_output_data(state)
279+
178280
return _create_guardrail_node(
179281
guardrail,
180282
GuardrailScope.TOOL,
181283
execution_stage,
182284
_payload_generator,
183285
success_node,
184286
failure_node,
287+
_input_data_extractor,
288+
_output_data_extractor,
185289
)

0 commit comments

Comments
 (0)