|
22 | 22 | from django.db.models import QuerySet, Q |
23 | 23 | from django.http import HttpResponse |
24 | 24 | from django.utils import timezone |
25 | | -from django.utils.translation import gettext_lazy as _ |
| 25 | +from django.utils.translation import gettext_lazy as _, gettext |
26 | 26 | from rest_framework import serializers, status |
27 | 27 | from rest_framework.utils.formatting import lazy_format |
28 | 28 |
|
|
52 | 52 | tool_executor = ToolExecutor() |
53 | 53 |
|
54 | 54 |
|
| 55 | +def is_valid_tool_workflow_circular_dependency(workflow, _id, visited=None, stack=None): |
| 56 | + """ |
| 57 | + workflow: 当前要检查的 workflow 对象 |
| 58 | + visited: 全局已经访问过的 workflow id |
| 59 | + stack: 当前递归栈里的 workflow id |
| 60 | + """ |
| 61 | + if visited is None: |
| 62 | + visited = set() |
| 63 | + if stack is None: |
| 64 | + stack = set() |
| 65 | + |
| 66 | + if _id in stack: |
| 67 | + return False |
| 68 | + |
| 69 | + if _id in visited: |
| 70 | + return True |
| 71 | + |
| 72 | + stack.add(_id) |
| 73 | + |
| 74 | + for node in workflow.get('nodes', []): |
| 75 | + child_tool_ids = [] |
| 76 | + if node.get('type') == 'ai-chat-node': |
| 77 | + node_data = node.get('properties', {}).get('node_data', {}) |
| 78 | + child_tool_ids = node_data.get('tool_ids') or [] |
| 79 | + if node.get('type') == 'tool-workflow-lib-node': |
| 80 | + child_tool_id = node.get('properties', {}).get('node_data', {}).get('tool_lib_id') |
| 81 | + child_tool_ids.append(child_tool_id) |
| 82 | + for child_tool_id in child_tool_ids: |
| 83 | + if child_tool_id: |
| 84 | + child_workflow = QuerySet(ToolWorkflow).filter(tool_id=child_tool_id).first() |
| 85 | + if child_workflow: |
| 86 | + if not is_valid_tool_workflow_circular_dependency(child_workflow.work_flow, str(child_tool_id), |
| 87 | + visited, |
| 88 | + stack): |
| 89 | + return False |
| 90 | + |
| 91 | + stack.remove(_id) |
| 92 | + visited.add(_id) |
| 93 | + return True |
| 94 | + |
| 95 | + |
55 | 96 | def hand_node(node, update_tool_map): |
56 | 97 | if node.get('type') == 'tool-lib-node': |
57 | 98 | tool_lib_id = (node.get('properties', {}).get('node_data', {}).get('tool_lib_id') or '') |
@@ -233,6 +274,10 @@ def edit(self, instance: Dict): |
233 | 274 | tool = QuerySet(Tool).filter(id=self.data.get("tool_id")).first() |
234 | 275 | workflow_id = tool.workspace_id |
235 | 276 | if instance.get("work_flow"): |
| 277 | + dependency = is_valid_tool_workflow_circular_dependency(workflow=instance.get('work_flow'), |
| 278 | + _id=str(tool.id)) |
| 279 | + if not dependency: |
| 280 | + raise Exception(gettext('There is a circular dependency in the tool workflow')) |
236 | 281 | QuerySet(ToolWorkflow).update_or_create(tool_id=self.data.get("tool_id"), |
237 | 282 | create_defaults={'id': uuid.uuid7(), |
238 | 283 | 'tool_id': self.data.get( |
|
0 commit comments