Skip to content

Commit 97a356e

Browse files
authored
Merge pull request #71 from APLA-Toolbox/compare-astar-h
Add Heuristics comparing plot
2 parents fdbc60a + 60802b4 commit 97a356e

3 files changed

Lines changed: 55 additions & 17 deletions

File tree

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,11 @@ Make sure you have a pddl-examples folder where you run your environment that co
100100
from jupyddl import DataAnalyst
101101

102102
da = DataAnalyst()
103-
da.plot_astar_data() # plots complexity statistics for all the problem.pddl/domain.pddl couples in the pddl-examples/ folder
103+
da.plot_astar() # plots complexity statistics for all the problem.pddl/domain.pddl couples in the pddl-examples/ folder
104104

105-
da.plot_astar_data(problem="pddl-examples/flip/problem.pddl", domain="pddl-examples/flip/domain.pddl") # scatter complexity statistics for the provided pddl
105+
da.plot_astar(problem="pddl-examples/flip/problem.pddl", domain="pddl-examples/flip/domain.pddl") # scatter complexity statistics for the provided pddl
106106

107-
da.plot_astar_data(heuristic_key="zero") # use h=0 instead of goal_count for your computation
107+
da.plot_astar(heuristic_key="zero") # use h=0 instead of goal_count for your computation
108108

109109
da.plot_dfs() # same as astar
110110

jupyddl/data_analyst.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
class DataAnalyst:
2020
def __init__(self):
2121
logging.info("Instantiating data analyst...")
22+
self.available_heuristics = ["goal_count", "zero"]
2223

2324
def __get_all_pddl_from_data(self):
2425
tested_files = []
@@ -49,7 +50,7 @@ def __plot_data(self, times, total_nodes, plot_title):
4950
times_y.append(data[node_opened])
5051
plt.plot(nodes_sorted, times_y, "r:o")
5152
plt.xlabel("Number of opened nodes")
52-
plt.ylabel("Planning computation time")
53+
plt.ylabel("Planning computation time (s)")
5354
plt.xscale("symlog")
5455
plt.title(plot_title)
5556
plt.grid(True)
@@ -58,7 +59,7 @@ def __plot_data(self, times, total_nodes, plot_title):
5859
def __scatter_data(self, times, total_nodes, plot_title):
5960
plt.scatter(total_nodes, times)
6061
plt.xlabel("Number of opened nodes")
61-
plt.ylabel("Planning computation time")
62+
plt.ylabel("Planning computation time (s)")
6263
plt.xscale("symlog")
6364
plt.title(plot_title)
6465
plt.grid(True)
@@ -110,7 +111,7 @@ def __gather_data_astar(
110111
return [total_time], [opened_nodes], has_multiple_files_tested
111112
return [0], [0], has_multiple_files_tested
112113

113-
def plot_astar_data(self, heuristic_key="goal_count", domain="", problem=""):
114+
def plot_astar(self, heuristic_key="goal_count", domain="", problem=""):
114115
if bool(not problem) != bool(not domain):
115116
logging.warning(
116117
"Either problem or domain wasn't provided, testing all files in data folder"
@@ -293,6 +294,34 @@ def __gather_data(
293294
xdata[name] = nodes
294295
return xdata, ydata
295296

297+
def comparative_astar_heuristic_plot(self, domain="", problem=""):
298+
_, ax = plt.subplots()
299+
plt.xlabel("Number of opened nodes")
300+
plt.ylabel("Planning computation time (s)")
301+
302+
for h in self.available_heuristics:
303+
times, nodes, _ = self.__gather_data_astar(domain_path=domain, problem_path=problem, heuristic_key=h)
304+
data = dict()
305+
for i, val in enumerate(nodes):
306+
data[val] = times[i]
307+
nodes_sorted = sorted(list(data.keys()))
308+
times_y = []
309+
for node_opened in nodes_sorted:
310+
times_y.append(data[node_opened])
311+
312+
ax.plot(
313+
nodes_sorted,
314+
times_y,
315+
"-o",
316+
label=h,
317+
)
318+
319+
plt.title("A* heuristics complexity comparison")
320+
plt.legend(loc="upper left")
321+
plt.xscale("symlog")
322+
plt.grid(True)
323+
plt.show(block=False)
324+
296325
def comparative_data_plot(
297326
self,
298327
astar=True,
@@ -341,15 +370,20 @@ def comparative_data_plot(
341370
with open("data.json") as fp:
342371
json_dict = json.load(fp)
343372

344-
fig, ax = plt.subplots()
345-
fig.set_figwidth(12)
346-
fig.set_figheight(6)
373+
_, ax = plt.subplots()
347374
plt.xlabel("Number of opened nodes")
348375
plt.ylabel("Planning computation time (s)")
349376
for planner in json_dict["xdata"].keys():
377+
data = dict()
378+
for i, val in enumerate(json_dict["xdata"][planner]):
379+
data[val] = json_dict["ydata"][planner][i]
380+
nodes_sorted = sorted(list(data.keys()))
381+
times_y = []
382+
for node_opened in nodes_sorted:
383+
times_y.append(data[node_opened])
350384
ax.plot(
351-
sorted(json_dict["xdata"][planner]),
352-
sorted(json_dict["ydata"][planner]),
385+
nodes_sorted,
386+
times_y,
353387
"-o",
354388
label=planner,
355389
)

tests/test_data_analyst.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@ def test_data_analyst_constructor():
1111
_ = DataAnalyst()
1212
assert True
1313

14+
def test_heuristics_comparer():
15+
da = DataAnalyst()
16+
da.comparative_astar_heuristic_plot()
17+
18+
def test_heuristics_comparer_single():
19+
da = DataAnalyst()
20+
da.comparative_astar_heuristic_plot(domain="pddl-examples/flip/domain.pddl", problem="pddl-examples/flip/problem.pddl")
1421

1522
def test_data_analyst_plot_dfs_one_pddl():
1623
da = DataAnalyst()
@@ -41,10 +48,7 @@ def test_data_analyst_plot_dijkstra_one_pddl():
4148

4249
def test_data_analyst_plot_astar_h_goal_count_one_pddl():
4350
da = DataAnalyst()
44-
da.plot_astar_data(
45-
domain="pddl-examples/flip/domain.pddl",
46-
problem="pddl-examples/flip/problem.pddl",
47-
)
51+
da.plot_astar(domain="pddl-examples/flip/domain.pddl", problem="pddl-examples/flip/problem.pddl")
4852
assert True
4953

5054

@@ -68,13 +72,13 @@ def test_data_analyst_plot_dijkstra():
6872

6973
def test_data_analyst_plot_astar_h_goal_count():
7074
da = DataAnalyst()
71-
da.plot_astar_data()
75+
da.plot_astar()
7276
assert True
7377

7478

7579
def test_data_analyst_plot_astar_h_zero():
7680
da = DataAnalyst()
77-
da.plot_astar_data(heuristic_key="zero")
81+
da.plot_astar(heuristic_key="zero")
7882
assert True
7983

8084

0 commit comments

Comments
 (0)