Skip to content

Commit bf01626

Browse files
guilyxactions-user
authored andcommitted
Apply formatting changes
1 parent 0419dae commit bf01626

2 files changed

Lines changed: 55 additions & 20 deletions

File tree

jupyddl/data_analyst.py

Lines changed: 54 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __get_all_pddl_from_data(self, max_pddl_instances=-1):
3838
if i % 2 != 0:
3939
domains_problems.append((tested_files[i - 1], tested_files[i]))
4040
i += 1
41-
if max_pddl_instances != -1 and i >= max_pddl_instances*2:
41+
if max_pddl_instances != -1 and i >= max_pddl_instances * 2:
4242
return domains_problems
4343
return domains_problems
4444
return [
@@ -72,12 +72,18 @@ def __scatter_data(self, times, total_nodes, plot_title):
7272
plt.show(block=False)
7373

7474
def __gather_data_astar(
75-
self, domain_path="", problem_path="", heuristic_key="goal_count", max_pddl_instances=-1
75+
self,
76+
domain_path="",
77+
problem_path="",
78+
heuristic_key="goal_count",
79+
max_pddl_instances=-1,
7680
):
7781
has_multiple_files_tested = True
7882
if not domain_path or not problem_path:
7983
metrics = dict()
80-
for problem, domain in self.__get_all_pddl_from_data(max_pddl_instances=max_pddl_instances):
84+
for problem, domain in self.__get_all_pddl_from_data(
85+
max_pddl_instances=max_pddl_instances
86+
):
8187
logging.debug("Loading new PDDL instance planned with A*...")
8288
logging.debug("Domain: " + domain)
8389
logging.debug("Problem: " + problem)
@@ -117,14 +123,19 @@ def __gather_data_astar(
117123
return [total_time], [opened_nodes], has_multiple_files_tested
118124
return [0], [0], has_multiple_files_tested
119125

120-
def plot_astar(self, heuristic_key="goal_count", domain="", problem="", max_pddl_instances=-1):
126+
def plot_astar(
127+
self, heuristic_key="goal_count", domain="", problem="", max_pddl_instances=-1
128+
):
121129
if bool(not problem) != bool(not domain):
122130
logging.warning(
123131
"Either problem or domain wasn't provided, testing all files in data folder"
124132
)
125133
problem = domain = ""
126134
times, total_nodes, has_multiple_files_tested = self.__gather_data_astar(
127-
heuristic_key=heuristic_key, problem_path=problem, domain_path=domain, max_pddl_instances=max_pddl_instances
135+
heuristic_key=heuristic_key,
136+
problem_path=problem,
137+
domain_path=domain,
138+
max_pddl_instances=max_pddl_instances,
128139
)
129140
title = "A* Statistics" + "[Heuristic: " + heuristic_key + "]"
130141
if has_multiple_files_tested:
@@ -136,7 +147,9 @@ def __gather_data_bfs(self, domain_path="", problem_path="", max_pddl_instances=
136147
has_multiple_files_tested = True
137148
if not domain_path or not problem_path:
138149
metrics = dict()
139-
for problem, domain in self.__get_all_pddl_from_data(max_pddl_instances=max_pddl_instances):
150+
for problem, domain in self.__get_all_pddl_from_data(
151+
max_pddl_instances=max_pddl_instances
152+
):
140153
logging.debug("Loading new PDDL instance planned with BFS...")
141154
logging.debug("Domain: " + domain)
142155
logging.debug("Problem: " + problem)
@@ -168,7 +181,9 @@ def plot_bfs(self, domain="", problem="", max_pddl_instances=-1):
168181
)
169182
problem = domain = ""
170183
times, total_nodes, has_multiple_files_tested = self.__gather_data_bfs(
171-
problem_path=problem, domain_path=domain, max_pddl_instances=max_pddl_instances
184+
problem_path=problem,
185+
domain_path=domain,
186+
max_pddl_instances=max_pddl_instances,
172187
)
173188
if has_multiple_files_tested:
174189
self.__plot_data(times, total_nodes, title)
@@ -179,7 +194,9 @@ def __gather_data_dfs(self, domain_path="", problem_path="", max_pddl_instances=
179194
has_multiple_files_tested = True
180195
if not domain_path or not problem_path:
181196
metrics = dict()
182-
for problem, domain in self.__get_all_pddl_from_data(max_pddl_instances=max_pddl_instances):
197+
for problem, domain in self.__get_all_pddl_from_data(
198+
max_pddl_instances=max_pddl_instances
199+
):
183200
logging.debug("Loading new PDDL instance planned with DFS...")
184201
logging.debug("Domain: " + domain)
185202
logging.debug("Problem: " + problem)
@@ -211,18 +228,24 @@ def plot_dfs(self, problem="", domain="", max_pddl_instances=-1):
211228
)
212229
problem = domain = ""
213230
times, total_nodes, has_multiple_files_tested = self.__gather_data_dfs(
214-
problem_path=problem, domain_path=domain, max_pddl_instances=max_pddl_instances
231+
problem_path=problem,
232+
domain_path=domain,
233+
max_pddl_instances=max_pddl_instances,
215234
)
216235
if has_multiple_files_tested:
217236
self.__plot_data(times, total_nodes, title)
218237
else:
219238
self.__scatter_data(times, total_nodes, title)
220239

221-
def __gather_data_dijkstra(self, domain_path="", problem_path="", max_pddl_instances=-1):
240+
def __gather_data_dijkstra(
241+
self, domain_path="", problem_path="", max_pddl_instances=-1
242+
):
222243
has_multiple_files_tested = True
223244
if not domain_path or not problem_path:
224245
metrics = dict()
225-
for problem, domain in self.__get_all_pddl_from_data(max_pddl_instances=max_pddl_instances):
246+
for problem, domain in self.__get_all_pddl_from_data(
247+
max_pddl_instances=max_pddl_instances
248+
):
226249
logging.debug("Loading new PDDL instance planned with Dijkstra...")
227250
logging.debug("Domain: " + domain)
228251
logging.debug("Problem: " + problem)
@@ -254,7 +277,9 @@ def plot_dijkstra(self, problem="", domain="", max_pddl_instances=-1):
254277
)
255278
problem = domain = ""
256279
times, total_nodes, has_multiple_files_tested = self.__gather_data_dijkstra(
257-
problem_path=problem, domain_path=domain, max_pddl_instances=max_pddl_instances
280+
problem_path=problem,
281+
domain_path=domain,
282+
max_pddl_instances=max_pddl_instances,
258283
)
259284
if has_multiple_files_tested:
260285
self.__plot_data(times, total_nodes, title)
@@ -270,7 +295,7 @@ def __gather_data(
270295
dijkstra=True,
271296
domain="",
272297
problem="",
273-
max_pddl_instances=-1
298+
max_pddl_instances=-1,
274299
):
275300
gatherers = []
276301
xdata = dict()
@@ -294,22 +319,31 @@ def __gather_data(
294319
domain_path=domain,
295320
problem_path=problem,
296321
heuristic_key=heuristic_key,
297-
max_pddl_instances=max_pddl_instances
322+
max_pddl_instances=max_pddl_instances,
298323
)
299324
else:
300-
times, nodes, _ = g(domain_path=domain, problem_path=problem, max_pddl_instances=max_pddl_instances)
325+
times, nodes, _ = g(
326+
domain_path=domain,
327+
problem_path=problem,
328+
max_pddl_instances=max_pddl_instances,
329+
)
301330
ydata[name] = times
302331
xdata[name] = nodes
303332
return xdata, ydata
304333

305-
def comparative_astar_heuristic_plot(self, domain="", problem="", max_pddl_instances=-1):
334+
def comparative_astar_heuristic_plot(
335+
self, domain="", problem="", max_pddl_instances=-1
336+
):
306337
_, ax = plt.subplots()
307338
plt.xlabel("Number of opened nodes")
308339
plt.ylabel("Planning computation time (s)")
309340

310341
for h in self.available_heuristics:
311342
times, nodes, _ = self.__gather_data_astar(
312-
domain_path=domain, problem_path=problem, heuristic_key=h, max_pddl_instances=max_pddl_instances
343+
domain_path=domain,
344+
problem_path=problem,
345+
heuristic_key=h,
346+
max_pddl_instances=max_pddl_instances,
313347
)
314348
data = dict()
315349
for i, val in enumerate(nodes):
@@ -342,7 +376,7 @@ def comparative_data_plot(
342376
problem="",
343377
heuristic_key="goal_count",
344378
collect_new_data=True,
345-
max_pddl_instances=-1
379+
max_pddl_instances=-1,
346380
):
347381
json_dict = {}
348382
if collect_new_data:
@@ -354,7 +388,7 @@ def comparative_data_plot(
354388
dijkstra=dijkstra,
355389
domain=domain,
356390
problem=problem,
357-
max_pddl_instances=max_pddl_instances
391+
max_pddl_instances=max_pddl_instances,
358392
)
359393
json_dict["xdata"] = xdata
360394
json_dict["ydata"] = ydata
@@ -373,7 +407,7 @@ def comparative_data_plot(
373407
dijkstra=dijkstra,
374408
domain=domain,
375409
problem=problem,
376-
max_pddl_instances=max_pddl_instances
410+
max_pddl_instances=max_pddl_instances,
377411
)
378412
json_dict["xdata"] = xdata
379413
json_dict["ydata"] = ydata

tests/test_data_analyst.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def test_data_analyst_plot_astar_h_goal_count():
8484
da.plot_astar()
8585
assert True
8686

87+
8788
def test_data_analyst_plot_dfs_restricted():
8889
da = DataAnalyst()
8990
da.plot_dfs(max_pddl_instances=2)

0 commit comments

Comments
 (0)