Skip to content

Commit 00d9196

Browse files
committed
Format files with black
2 parents dc2264a + d5c372c commit 00d9196

6 files changed

Lines changed: 76 additions & 51 deletions

File tree

jupyddl/data_analyst.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,21 @@
22
import glob
33
import matplotlib as mpl
44
import logging
5+
56
mpl.use("TkAgg")
67
mpl.set_loglevel("WARNING")
78
import matplotlib.pyplot as plt
9+
810
plt.style.use("ggplot")
911
from .automated_planner import AutomatedPlanner
1012
from os import path
1113
import json
1214

15+
1316
class DataAnalyst:
1417
def __init__(self):
1518
logging.info("Instantiating data analyst...")
1619

17-
1820
def __get_all_pddl_from_data(self):
1921
tested_files = []
2022
domains_problems = []
@@ -29,29 +31,29 @@ def __get_all_pddl_from_data(self):
2931
i += 1
3032
return domains_problems
3133

32-
3334
def __plot_data(self, times, total_nodes, plot_title):
3435
plt.plot(total_nodes, times, "b:o")
3536
plt.xlabel("Number of opened nodes")
3637
plt.ylabel("Planning computation time")
3738
plt.title(plot_title)
38-
plt.xscale('symlog')
39-
plt.yscale('log')
39+
plt.xscale("symlog")
40+
plt.yscale("log")
4041
plt.grid(True)
4142
plt.show(block=False)
4243

43-
4444
def __scatter_data(self, times, total_nodes, plot_title):
4545
plt.scatter(total_nodes, times)
4646
plt.xlabel("Number of opened nodes")
4747
plt.ylabel("Planning computation time")
4848
plt.title(plot_title)
49-
plt.xscale('symlog')
50-
plt.yscale('log')
49+
plt.xscale("symlog")
50+
plt.yscale("log")
5151
plt.grid(True)
5252
plt.show(block=False)
5353

54-
def __gather_data_astar(self, domain_path="", problem_path="", heuristic_key="goal_count"):
54+
def __gather_data_astar(
55+
self, domain_path="", problem_path="", heuristic_key="goal_count"
56+
):
5557
has_multiple_files_tested = True
5658
if not domain_path and not problem_path:
5759
has_multiple_files_tested = False
@@ -90,7 +92,6 @@ def __gather_data_astar(self, domain_path="", problem_path="", heuristic_key="go
9092
return [0], [0], has_multiple_files_tested
9193
return [total_time], [total_nodes], has_multiple_files_tested
9294

93-
9495
def plot_astar_data(self, heuristic_key="goal_count", domain="", problem=""):
9596
if bool(not problem) != bool(not domain):
9697
logging.warning(
@@ -106,7 +107,6 @@ def plot_astar_data(self, heuristic_key="goal_count", domain="", problem=""):
106107
else:
107108
self.__scatter_data(times, total_nodes, title)
108109

109-
110110
def __gather_data_bfs(self, domain_path="", problem_path=""):
111111
has_multiple_files_tested = True
112112
if not domain_path and not problem_path:
@@ -130,7 +130,6 @@ def __gather_data_bfs(self, domain_path="", problem_path=""):
130130
_, total_time, opened_nodes = apla.breadth_first_search()
131131
return [total_time], [total_nodes], has_multiple_files_tested
132132

133-
134133
def plot_bfs(self, domain="", problem=""):
135134
title = "BFS Statistics"
136135
if bool(not problem) != bool(not domain):
@@ -146,7 +145,6 @@ def plot_bfs(self, domain="", problem=""):
146145
else:
147146
self.__scatter_data(times, total_nodes, title)
148147

149-
150148
def __gather_data_dfs(self, domain_path="", problem_path=""):
151149
has_multiple_files_tested = True
152150
if not domain_path and not problem_path:
@@ -170,7 +168,6 @@ def __gather_data_dfs(self, domain_path="", problem_path=""):
170168
_, total_time, opened_nodes = apla.depth_first_search()
171169
return [total_time], [total_nodes], has_multiple_files_tested
172170

173-
174171
def plot_dfs(self, problem="", domain=""):
175172
title = "DFS Statistics"
176173
if bool(not problem) != bool(not domain):
@@ -209,7 +206,6 @@ def __gather_data_dijkstra(self, domain_path="", problem_path=""):
209206
_, total_time, opened_nodes = apla.dijktra_best_first_search()
210207
return [total_time], [total_nodes], has_multiple_files_tested
211208

212-
213209
def plot_dijkstra(self, problem="", domain=""):
214210
title = "Dijkstra Statistics"
215211
if bool(not problem) != bool(not domain):
@@ -225,7 +221,6 @@ def plot_dijkstra(self, problem="", domain=""):
225221
else:
226222
self.__scatter_data(times, total_nodes, title)
227223

228-
229224
def __gather_data(
230225
self,
231226
heuristic_key="goal_count",
@@ -239,7 +234,7 @@ def __gather_data(
239234
gatherers = []
240235
xdata = dict()
241236
ydata = dict()
242-
237+
243238
if bfs:
244239
gatherers.append(("BFS", self.__gather_data_bfs))
245240
if dfs:
@@ -249,19 +244,22 @@ def __gather_data(
249244
if astar:
250245
gatherers.append(("A*", self.__gather_data_astar))
251246

252-
_, _, _ = self.__gather_data_bfs(domain_path=domain, problem_path=problem) # Dummy line to do first parsing and get rid of static loading
247+
_, _, _ = self.__gather_data_bfs(
248+
domain_path=domain, problem_path=problem
249+
) # Dummy line to do first parsing and get rid of static loading
253250
for name, g in gatherers:
254251
if g == self.__gather_data_astar:
255252
times, nodes, _ = self.__gather_data_astar(
256-
domain_path=domain, problem_path=problem, heuristic_key=heuristic_key
253+
domain_path=domain,
254+
problem_path=problem,
255+
heuristic_key=heuristic_key,
257256
)
258257
else:
259258
times, nodes, _ = g(domain_path=domain, problem_path=problem)
260259
ydata[name] = times
261260
xdata[name] = nodes
262261
return xdata, ydata
263262

264-
265263
def comparative_data_plot(
266264
self,
267265
astar=True,
@@ -317,11 +315,14 @@ def comparative_data_plot(
317315
plt.ylabel("Planning computation time (s)")
318316
for planner in json_dict["xdata"].keys():
319317
ax.plot(
320-
sorted(json_dict["xdata"][planner]), sorted(json_dict["ydata"][planner]), '-o', label=planner
318+
sorted(json_dict["xdata"][planner]),
319+
sorted(json_dict["ydata"][planner]),
320+
"-o",
321+
label=planner,
321322
)
322323
plt.title("Planners complexity comparison")
323324
plt.legend(loc="upper left")
324-
plt.xscale('symlog')
325-
plt.yscale('log')
325+
plt.xscale("symlog")
326+
plt.yscale("log")
326327
plt.grid(True)
327328
plt.show(block=False)

scripts/astar_data_plot.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from common import get_all_pddl_from_data, plot_data, scatter_data
22
from os import path
33
import sys
4+
45
sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
56
import logging
67
from jupyddl.automated_planner import AutomatedPlanner

scripts/bfs_data_plot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,4 @@ def plot_bfs(domain="", problem=""):
4848

4949

5050
if __name__ == "__main__":
51-
plot_bfs()
51+
plot_bfs()

scripts/common.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ def plot_data(times, total_nodes, plot_title):
2929
plt.xlabel("Number of opened nodes")
3030
plt.ylabel("Planning computation time")
3131
plt.title(plot_title)
32-
plt.xscale('symlog')
33-
plt.yscale('log')
32+
plt.xscale("symlog")
33+
plt.yscale("log")
3434
plt.grid(True)
3535
plt.show()
3636

@@ -40,8 +40,8 @@ def scatter_data(times, total_nodes, plot_title):
4040
plt.xlabel("Number of opened nodes")
4141
plt.ylabel("Planning computation time")
4242
plt.title(plot_title)
43-
plt.xscale('symlog')
44-
plt.yscale('log')
43+
plt.xscale("symlog")
44+
plt.yscale("log")
4545
plt.grid(True)
4646
plt.show()
4747

scripts/planners_data_plot.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def gather_data(
2020
gatherers = []
2121
xdata = dict()
2222
ydata = dict()
23-
23+
2424
if bfs:
2525
gatherers.append(("BFS", gather_data_bfs))
2626
if dfs:
@@ -30,7 +30,9 @@ def gather_data(
3030
if astar:
3131
gatherers.append(("A*", gather_data_astar))
3232

33-
_, _, _ = gather_data_bfs(domain_path=domain, problem_path=problem) # Dummy line to do first parsing and get rid of static loading
33+
_, _, _ = gather_data_bfs(
34+
domain_path=domain, problem_path=problem
35+
) # Dummy line to do first parsing and get rid of static loading
3436
for name, g in gatherers:
3537
if g == gather_data_astar:
3638
times, nodes, _ = gather_data_astar(
@@ -97,12 +99,15 @@ def comparative_data_plot(
9799
plt.ylabel("Planning computation time (s)")
98100
for planner in json_dict["xdata"].keys():
99101
ax.plot(
100-
sorted(json_dict["xdata"][planner]), sorted(json_dict["ydata"][planner]), '-o', label=planner
102+
sorted(json_dict["xdata"][planner]),
103+
sorted(json_dict["ydata"][planner]),
104+
"-o",
105+
label=planner,
101106
)
102107
plt.title("Planners complexity comparison")
103108
plt.legend(loc="upper left")
104-
plt.xscale('symlog')
105-
plt.yscale('log')
109+
plt.xscale("symlog")
110+
plt.yscale("log")
106111
plt.grid(True)
107112
plt.show()
108113

tests/test_data_analyst.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,73 +6,91 @@
66
sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
77
from jupyddl.data_analyst import DataAnalyst
88

9+
910
def test_data_analyst_constructor():
1011
_ = DataAnalyst()
11-
assert(True)
12+
assert True
13+
1214

1315
def test_data_analyst_plot_dfs():
1416
da = DataAnalyst()
1517
da.plot_dfs()
16-
assert(True)
18+
assert True
19+
1720

1821
def test_data_analyst_plot_bfs():
1922
da = DataAnalyst()
2023
da.plot_bfs()
21-
assert(True)
24+
assert True
25+
2226

2327
def test_data_analyst_plot_dijkstra():
2428
da = DataAnalyst()
2529
da.plot_dijkstra()
26-
assert(True)
30+
assert True
31+
2732

2833
def test_data_analyst_plot_astar_h_goal_count():
2934
da = DataAnalyst()
3035
da.plot_astar_data()
31-
assert(True)
36+
assert True
37+
3238

3339
def test_data_analyst_plot_astar_h_zero():
3440
da = DataAnalyst()
3541
da.plot_astar_data(heuristic_key="zero")
36-
assert(True)
42+
assert True
43+
3744

3845
def test_comparative_no_restrictions():
3946
da = DataAnalyst()
4047
da.comparative_data_plot()
41-
assert(True)
48+
assert True
49+
4250

4351
def test_comparative_no_astar():
4452
da = DataAnalyst()
4553
da.comparative_data_plot(astar=False)
46-
assert(True)
54+
assert True
55+
4756

4857
def test_comparative_no_bfs():
4958
da = DataAnalyst()
5059
da.comparative_data_plot(bfs=False)
51-
assert(True)
60+
assert True
61+
5262

5363
def test_comparative_no_dijkstra():
5464
da = DataAnalyst()
5565
da.comparative_data_plot(dijkstra=False)
56-
assert(True)
66+
assert True
67+
5768

5869
def test_comparative_no_dfs():
5970
da = DataAnalyst()
6071
da.comparative_data_plot(dfs=False)
61-
assert(True)
72+
assert True
73+
6274

6375
def test_comparative_one_pddl():
6476
da = DataAnalyst()
65-
da.comparative_data_plot(dfs=False, bfs=False, domain="data/domain.pddl", problem="data/problem.pddl")
66-
assert(True)
77+
da.comparative_data_plot(
78+
dfs=False, bfs=False, domain="data/domain.pddl", problem="data/problem.pddl"
79+
)
80+
assert True
81+
6782

6883
def test_comparative_use_data_json():
6984
da = DataAnalyst()
70-
da.comparative_data_plot(domain="data/domain.pddl", problem="data/problem.pddl", collect_new_data=False)
71-
assert(True)
85+
da.comparative_data_plot(
86+
domain="data/domain.pddl", problem="data/problem.pddl", collect_new_data=False
87+
)
88+
assert True
89+
7290

7391
def test_comparative_zero_h():
7492
da = DataAnalyst()
75-
da.comparative_data_plot(domain="data/domain.pddl", problem="data/problem.pddl", heuristic_key="zero")
76-
assert(True)
77-
78-
93+
da.comparative_data_plot(
94+
domain="data/domain.pddl", problem="data/problem.pddl", heuristic_key="zero"
95+
)
96+
assert True

0 commit comments

Comments
 (0)