-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplotting_methods_compare.py
More file actions
106 lines (93 loc) · 4.79 KB
/
plotting_methods_compare.py
File metadata and controls
106 lines (93 loc) · 4.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import os
from typing import Dict
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
from collections import defaultdict
import warnings
def get_values(file: str, data_dir: str, key_name: str = 'charts/episodic_return') -> np.ndarray:
    """Read every value of one scalar tag from a TensorBoard run.

    Args:
        file: Name of the run (event file or run directory) inside ``data_dir``.
        data_dir: Directory that contains the runs.
        key_name: Scalar tag to extract.

    Returns:
        1-D array of the scalar values, in the order ``EventAccumulator.Scalars``
        returns them.

    Raises:
        KeyError: if the run has no scalar tag named ``key_name``.
    """
    event_acc = EventAccumulator(os.path.join(data_dir, file))
    event_acc.Reload()
    # Use .get so the informative KeyError below is still raised (instead of a
    # bare KeyError('scalars') from the message itself) when no scalars exist.
    available = event_acc.Tags().get('scalars', [])
    if key_name not in available:
        raise KeyError(f"No {key_name} scalar found in {file}. Must be in {available}. Skipping.")
    return np.array([e.value for e in event_acc.Scalars(key_name)])
def smooth(y, box_pts):
    """Return a moving average of ``y`` over a flat window of ``box_pts`` samples.

    The average is computed with ``np.convolve(..., mode='same')``, so the
    output has the same length as ``y`` and samples outside ``y`` count as
    zero. The final ``box_pts`` entries are replaced with the raw values of
    ``y``; the leading entries are left as convolved. ``y`` itself is not
    modified.
    """
    window = np.ones(box_pts) / box_pts
    averaged = np.convolve(y, window, mode='same')
    # Keep the raw data at the tail instead of the zero-padded average.
    tail = slice(-box_pts, None)
    averaged[tail] = y[tail]
    return averaged
def nested_dict():
    """Create an auto-vivifying tree: any missing key yields a fresh nested_dict."""
    tree = defaultdict(nested_dict)
    return tree
def _parse_run_name(fname):
    """Split a run name into ``(env, method, seed)``.

    env is the first ``'__'`` field, method is the last ``'__'`` field before
    the literal ``'Method'``, and seed is the third ``'__'`` field as ``int``.
    Raises IndexError/ValueError when the name does not follow that layout.
    """
    parts = fname.split('__')
    env = parts[0]
    method = fname.split('Method')[0].split('__')[-1]
    seed = int(parts[2])
    return env, method, seed


def _load_runs(runs_dir, debug=False):
    """Load episodic returns of every state-action-concat method run.

    Returns a nested dict ``runs[env][method][seed] -> np.ndarray``. Runs whose
    name lacks ``'Method'`` or contains ``'StateOnly'`` are ignored; runs with
    an unparsable name or no usable scalar data are skipped.
    """
    runs = nested_dict()
    for fname in os.listdir(runs_dir):
        if 'Method' not in fname or 'StateOnly' in fname:
            continue
        try:
            env, method, seed = _parse_run_name(fname)
        except (IndexError, ValueError) as e:
            if debug:
                print(f'Skipping {fname} due to {e}')
            continue
        try:
            rews = get_values(fname, runs_dir)
        except (KeyError, IndexError) as e:
            if debug:
                print(f'Skipping {fname} due to {e}')
            continue
        runs[env][method][seed] = rews
    return runs


def _aggregate_seeds(seed_runs):
    """Combine one (env, method)'s per-seed curves into mean/std curves.

    Shorter runs are right-padded to the longest run's length first. Returns a
    dict with keys 'max_len', 'mean', 'std' and 'best_seed' (seed with the
    highest unpadded mean return).
    """
    seeds = list(seed_runs.keys())
    max_len = max(seed_runs[s].shape[0] for s in seeds)
    # NOTE(review): mode='maximum' pads each shorter run with its own best
    # return, which optimistically assumes it would have stayed there;
    # 'edge' (repeat the last value) may be what was intended — confirm.
    padded = np.array([
        np.pad(seed_runs[s], (0, max_len - seed_runs[s].shape[0]), mode='maximum')
        for s in seeds
    ])
    return {
        'max_len': max_len,
        'mean': np.mean(padded, axis=0),
        'std': np.std(padded, axis=0),
        'best_seed': seeds[np.argmax([seed_runs[s].mean() for s in seeds])],
    }


def _plot_env(ename, method_stats, figs_dir):
    """Plot mean ± std return of every method in one env and save it as a PNG."""
    colors = mpl.colormaps['tab10'](np.linspace(0, 1, len(method_stats)))
    is_hopper = ename == 'Hopper-v4'
    # Hopper-v4: full curve truncated at index 4000. Other envs: last 30% of
    # each curve. None (not -1) so the final point is not dropped.
    end_cap = 4000 if is_hopper else None
    fig = plt.figure(figsize=(10, 10))
    try:
        for i, method in enumerate(sorted(method_stats)):
            stats = method_stats[method]
            start = 0 if is_hopper else int(stats['max_len'] * .7)
            mean_rews = stats['mean'][start:end_cap]
            std_rews = stats['std'][start:end_cap]
            # x is the index of the logged return value.
            iterations = np.arange(stats['mean'].shape[0])[start:end_cap]
            plt.plot(iterations, mean_rews, label=method, color=colors[i], alpha=0.9)
            plt.fill_between(iterations, np.clip(mean_rews - std_rews, a_min=0, a_max=None),
                             mean_rews + std_rews, color=colors[i], alpha=0.3)
        plt.title(f'Method Compare in {ename} State-Action Concat')
        plt.xlabel('Iterations')
        plt.ylabel('Episodic Return')
        plt.grid()
        plt.legend()
        fig_name = f"methodCompare_{ename}_returns"
        fig.savefig(f'{figs_dir}/{fig_name}.png')
    finally:
        # Close every figure: with warnings escalated to errors, matplotlib's
        # "too many open figures" RuntimeWarning would otherwise abort the run.
        plt.close(fig)


if __name__ == "__main__":
    debug = False
    # debug = True
    # Every warning is raised as an exception (fail loudly).
    warnings.filterwarnings('error')
    script_dir = os.path.dirname(os.path.abspath(__file__))
    runs_dir = f'{script_dir}/runs'
    figs_dir = f'{script_dir}/figs/method_compare'
    os.makedirs(figs_dir, exist_ok=True)

    runs = _load_runs(runs_dir, debug=debug)
    for ename, methods in runs.items():
        method_stats = {}
        for method, seed_runs in methods.items():
            if debug:
                print(f'Processing {ename} {method} with seeds {list(seed_runs.keys())}')
            method_stats[method] = _aggregate_seeds(seed_runs)
        _plot_env(ename, method_stats, figs_dir)