Skip to content
This repository was archived by the owner on Jan 5, 2026. It is now read-only.

Commit bcbeab1

Browse files
authored
auto layout the agent interaction graph (#86)
1 parent ff086e9 commit bcbeab1

1 file changed

Lines changed: 40 additions & 161 deletions

File tree

visualization.py

Lines changed: 40 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -1,198 +1,77 @@
11
import pandas as pd
22
from matplotlib import pyplot as plt
3-
from anytree import Node, RenderTree
4-
from anytree.exporter import DotExporter
53
import json
4+
import networkx as nx
65

7-
# Load the data
6+
# Load data
87
file_path_latest = 'log/multiagent_data_20241030.csv'
9-
task_id = 1
8+
task_id = 3
109
meta_task_id = 'master_agent_task'
1110
data_latest = pd.read_csv(file_path_latest)
1211

13-
# Filter data for task_id = 1 and a specific meta_task_id
14-
task_1_data = data_latest[(data_latest['task_id'] == task_id) &
12+
# Filter data
13+
task_data = data_latest[(data_latest['task_id'] == task_id) &
1514
(data_latest['meta_task_id'] == meta_task_id)]
1615

17-
18-
# Clear previous nodes and dictionary
19-
nodes_dict = {}
20-
21-
# Create the root node (Main Agent at the top level)
22-
root_node = Node(f"Main Agent")
23-
nodes_dict['main_agent'] = root_node
24-
25-
## change agent name
26-
27-
# Iterate through task_1_data to build the hierarchy based on agent interactions
28-
last_agent_node = root_node
29-
30-
import pandas as pd
31-
import json
32-
import networkx as nx
33-
import matplotlib.pyplot as plt
34-
import json
35-
36-
# Create a directed graph
16+
# Initialize graph and add root nodes
3717
G = nx.MultiDiGraph()
38-
# Assuming 'task_1_data' is your DataFrame
39-
formatted_data = []
40-
# root_node = Node("main_agent")
41-
mapping = {}
4218
G.add_node('user_request')
4319
G.add_node('main_agent')
44-
count = 0
45-
# Add initial edge between 'user_request' and 'main_agent'
46-
G.add_edge('user_request', 'main_agent', key=count, label=f'step {count}: user prompting main_agent')
47-
mapping['main_agent'] = 'user_request'
48-
count += 1
20+
G.add_edge('user_request', 'main_agent', key=0, label='step 0: user prompting main_agent')
4921

50-
for _, row in task_1_data.iterrows():
51-
print(count)
22+
# Helper to map agent to its parent node
23+
parent_mapping = {'main_agent': 'user_request'}
24+
25+
# Build graph based on agent interactions
26+
step = 1
27+
for _, row in task_data.iterrows():
5228
agent = row['agent']
5329
depth = row['depth']
5430
response = json.loads(row['response'])
55-
tool_calls = response['tool_calls']
56-
57-
if agent not in G:
58-
print("error")
59-
31+
tool_calls = response.get('tool_calls', None)
6032
if tool_calls:
6133
for tool_call in tool_calls:
6234
tool = tool_call['function']['name']
63-
if 'agent' in tool:
64-
formatted_data.append({
65-
'agent': agent,
66-
'depth': depth,
67-
'action': f'calling sub agent: {tool}'
68-
})
69-
G.add_edge(agent, tool, key=count, label=f'step {count}: agent {agent} calling sub agent {tool}')
70-
mapping[tool] = agent
71-
count += 1
72-
else:
73-
formatted_data.append({
74-
'agent': agent,
75-
'depth': depth,
76-
'action': f'calling tool: {tool}'
77-
})
78-
G.add_edge(agent, tool, key=count, label=f'step {count}: agent {agent} calling tool {tool}')
79-
count += 1
35+
is_agent = 'agent' in tool
36+
action = f'calling {"sub agent" if is_agent else "tool"}: {tool}'
37+
38+
# Add formatted data
39+
if is_agent:
40+
parent_mapping[tool] = agent
41+
G.add_edge(agent, tool, key=step, label=f'step {step}: {action}')
42+
step += 1
8043
else:
81-
formatted_data.append({
82-
'agent': agent,
83-
'depth': depth,
84-
'action': 'go back to parent node'
85-
})
86-
G.add_edge(agent, mapping[agent], key=count, label=f'step {count}: sub agent {agent} going back to parent agent {mapping[agent]}')
87-
count += 1
88-
89-
# import pygraphviz as pgv
90-
# from networkx.drawing.nx_agraph import to_agraph
44+
G.add_edge(agent, parent_mapping[agent], key=step, label=f'step {step}: sub agent {agent} going back to parent agent {parent_mapping[agent]}')
45+
step += 1
9146

9247
print(G.edges)
93-
import networkx as nx
94-
import matplotlib.pyplot as plt
95-
96-
layout_configs = {
97-
1: {
98-
0: [('user_request', 0.5)],
99-
1: [('main_agent', 0.5)],
100-
2: [('io_agent', 0.5)],
101-
3: [('generate_and_download_image', 0.5)],
102-
},
103-
2: {
104-
0: [('user_request', 0.5)],
105-
1: [('main_agent', 0.5)],
106-
2: [('exec_agent', 0.5)],
107-
3: [('run_python_script', 0.3), ('execute_shell_command', 0.7)],
108-
},
109-
3: {
110-
0: [('user_request', 0.5)],
111-
1: [('main_agent', 0.5)],
112-
2: [('retrieval_agent', 0.4), ('io_agent', 0.6)],
113-
3: [('web_retrieval_agent', 0.4), ('write_to_file', 0.6)],
114-
4: [('bing_search', 0.3), ('scrape', 0.5)]
115-
},
116-
}
117-
11848

11949
def plot_hierarchical_multi_edge_graph(G, output_file=f'log/agent_interaction_graph{task_id}.png'):
12050
plt.figure(figsize=(12, 8))
121-
122-
# Define levels with adjusted positions for a more compact layout
123-
# Define levels with adjusted positions for a more compact layout
124-
# levels = {
125-
# 0: [('user_request', 0.5)],
126-
# 1: [('main_agent', 0.5)],
127-
# 2: [('use_io_agent', 0.2), ('use_exec_agent', 0.4), ('use_retrieval_agent', 0.6), ('use_structure_agent', 0.8)],
128-
# 3: [('write_to_file', 0.1), ('read_file', 0.3), ('run_python_script_exec', 0.4),
129-
# ('run_shell_script_exec', 0.5),
130-
# ('use_web_retrieval_agent', 0.6), ('use_db_retrieval_agent', 0.7), ('use_file_structure_agent', 0.8),
131-
# ('use_code_structure_agent', 0.9)],
132-
# 4: [('bing_search', 0.55), ('scrape', 0.65)]
133-
# }
134-
# Define levels with adjusted positions for a more compact layout
135-
# levels = {
136-
# 0: [('user_request', 0.5)],
137-
# 1: [('main_agent', 0.5)],
138-
# 2: [('exec_agent', 0.5)],
139-
# 3: [('run_python_script', 0.3), ('execute_shell_command', 0.7)],
140-
# }
141-
142-
# Define levels with adjusted positions for a more compact layout
143-
# levels = {
144-
# 0: [('user_request', 0.5)],
145-
# 1: [('main_agent', 0.5)],
146-
# 2: [('io_agent', 0.5)],
147-
# 3: [('generate_and_download_image', 0.5)],
148-
# }
149-
150-
levels = layout_configs[task_id]
151-
152-
# Calculate positions
153-
pos = {}
154-
for level, nodes in levels.items():
155-
y = 1 - (level / 4) # Adjust vertical spacing
156-
for node, x in nodes:
157-
pos[node] = (x, y)
158-
51+
pos = nx.nx_agraph.graphviz_layout(G, prog='dot')
52+
print(pos)
15953
# Draw nodes
160-
for level, nodes in levels.items():
161-
nx.draw_networkx_nodes(G, pos, nodelist=[n for n, _ in nodes], node_size=2000,
162-
node_color=['lightgreen' if n == 'user_request'
163-
else 'lightyellow' if n == 'main_agent'
164-
else 'lightblue' for n, _ in nodes])
165-
nx.draw_networkx_labels(G, pos, {n: n for n, _ in nodes}, font_size=8, font_weight='bold')
166-
167-
# Draw directed edges with reduced curvature and all labels
168-
edge_labels = {}
169-
for (u, v, key, data) in G.edges(keys=True, data=True):
170-
if u in pos and v in pos:
171-
rad = 0.15 + (0.05 * key)
172-
if key % 2 == 0:
173-
rad *= -1
174-
175-
# Draw directed edge
176-
nx.draw_networkx_edges(G, pos, edgelist=[(u, v)], connectionstyle=f'arc3,rad={rad}',
177-
edge_color='gray', arrows=True, arrowsize=15,
178-
arrowstyle='->', width=1)
179-
180-
label = data.get('label', '')
181-
edge_labels[(u, v, key)] = label
182-
183-
# Draw edge labels with adjusted positions
54+
colors = ['lightgreen' if n == 'user_request'
55+
else 'lightyellow' if n == 'main_agent'
56+
else 'lightblue' for n in G.nodes()]
57+
nx.draw_networkx_nodes(G, pos, node_color=colors, node_size=3000)
58+
nx.draw_networkx_labels(G, pos, font_size=8, font_weight='bold')
59+
60+
# Draw directed edges with labels
61+
edge_labels = {(u, v, k): d['label'] for u, v, k, d in G.edges(keys=True, data=True)}
18462
for (u, v, key), label in edge_labels.items():
185-
x = (pos[u][0] * 0.6 + pos[v][0] * 0.4) # Adjust label position towards the source
186-
y = (pos[u][1] * 0.6 + pos[v][1] * 0.4)
187-
63+
print(u, v, key)
64+
rad = 0.15 + 0.05 * key * (-1 if key % 2 == 0 else 1)
65+
nx.draw_networkx_edges(G, pos, edgelist=[(u, v)], connectionstyle=f'arc3,rad={rad}',
66+
edge_color='gray', arrows=True, arrowsize=15, arrowstyle='->', width=1)
67+
x, y = pos[u][0] * 0.6 + pos[v][0] * 0.4, pos[u][1] * 0.6 + pos[v][1] * 0.4
18868
num_edges = len([k for (a, b, k) in edge_labels.keys() if a == u and b == v])
18969
if num_edges > 1:
19070
offset = (key - (num_edges - 1) / 2) * 0.05
19171
x += offset * (pos[v][0] - pos[u][0])
19272
y += offset * (pos[v][1] - pos[u][1])
19373

194-
plt.text(x, y, label, fontsize=6, ha='center', va='center',
195-
bbox=dict(facecolor='white', edgecolor='none', alpha=0.7))
74+
plt.text(x, y, label, fontsize=6, ha='center', va='center', bbox=dict(facecolor='white', alpha=0.7))
19675

19776
plt.axis('off')
19877
plt.tight_layout()

0 commit comments

Comments
 (0)