-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathConfusion_graph_EMG.py
More file actions
161 lines (144 loc) · 6.62 KB
/
Confusion_graph_EMG.py
File metadata and controls
161 lines (144 loc) · 6.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import os
import sys
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
MGPT_BALL_MONKEYS = {"Jaco", "Theo"}    # training monkeys that performed mgpt / ball
ISO_WM_SPR_MONKEYS = {"Jango", "JacB"}  # training monkeys that performed iso / wm / spr
# ───────────────────────── UPDATED HELPERS ─────────────────────────


def _apply_family_filter(df, tasks):
    """Keep only rows whose ``train_monkey`` belongs to the family of *tasks*.

    * every task in {"mgpt", "ball"}       -> restrict to ``MGPT_BALL_MONKEYS``
    * every task in {"iso", "wm", "spr"}   -> restrict to ``ISO_WM_SPR_MONKEYS``
    * anything else (e.g. a mixed list)    -> return *df* itself, unfiltered

    Note that an empty *tasks* is a subset of both families and therefore
    selects the mgpt / ball monkeys (the first check wins).
    """
    requested = set(tasks)
    if requested.issubset({"mgpt", "ball"}):
        allowed = MGPT_BALL_MONKEYS
    elif requested.issubset({"iso", "wm", "spr"}):
        allowed = ISO_WM_SPR_MONKEYS
    else:
        # Tasks from both families: no single monkey set applies.
        return df
    in_family = df['train_monkey'].isin(allowed)
    return df[in_family]
def _heatmap(piv, title, fname, *, cbar_label="mean_VAF", vmin=0, vmax=1, dpi=700):
    """Draw *piv* as an annotated heatmap and save it to *fname*.

    The figure grows with the table: 1.6 in per column and 0.9 in per row on
    top of a fixed margin. Cells are annotated with two decimals and coloured
    on a fixed ``[vmin, vmax]`` scale so separate PNGs stay comparable.

    Parameters
    ----------
    piv : pandas.DataFrame
        Table to draw. Columns go on the x-axis (labelled "Test Task"); the
        y-axis label is the capitalised index name, or empty if the index is
        unnamed.
    title : str
        Figure title.
    fname : str
        Output path handed to ``plt.savefig``.
    cbar_label : str, optional
        Colour-bar label (default ``"mean_VAF"``, the original fixed label).
    vmin, vmax : float, optional
        Colour-scale limits (default 0 and 1, as before).
    dpi : int, optional
        Output resolution (default 700, as before).
    """
    fig = plt.figure(figsize=(1.5 + 1.6 * piv.shape[1], 1.3 + 0.9 * piv.shape[0]))
    try:
        sns.heatmap(piv, annot=True, fmt=".2f",
                    cmap="viridis", vmin=vmin, vmax=vmax,
                    cbar_kws={"label": cbar_label})
        plt.title(title)
        plt.xlabel("Test Task")
        # Pivots built as ``series.to_frame(name).T`` have an unnamed index;
        # calling ``None.capitalize()`` used to raise AttributeError here.
        index_name = piv.index.name
        plt.ylabel("" if index_name is None else str(index_name).capitalize())
        plt.tight_layout()
        plt.savefig(fname, dpi=dpi)
    finally:
        # Release the figure even when drawing or saving fails, so repeated
        # calls in a loop do not accumulate open figures.
        plt.close(fig)
    print(f"[INFO] saved → {fname}")
def _base(df, tasks, alignment_mode, decoder):
    """Select rows whose ``test_task`` is in *tasks*, then optionally narrow further.

    *alignment_mode* and *decoder* are exact-match filters on the
    ``alignment_mode`` and ``decoder_type`` columns respectively; passing
    ``None`` for either skips that filter.
    """
    selected = df[df['test_task'].isin(tasks)]
    optional_filters = (
        ('alignment_mode', alignment_mode),
        ('decoder_type', decoder),
    )
    for column, wanted in optional_filters:
        if wanted is None:
            continue
        selected = selected[selected[column] == wanted]
    return selected
def plot_confusion_mgpt_ball(df, metric='mean_VAF', *,
                             alignment_mode=None,
                             decoder=None,
                             group_by=None,   # None | 'decoder' | 'monkey'
                             grouped=False):
    """Plot mean *metric* heatmaps for the mgpt / ball task family.

    Rows are first limited to test tasks "mgpt" / "ball", optionally to one
    ``alignment_mode`` and one ``decoder`` (``decoder_type`` column), and then
    to the training monkeys in ``MGPT_BALL_MONKEYS``.

    Layout is chosen by ``group_by``:

    * ``"decoder"`` – one row per decoder, one column per test task; saved as
      ``mgpt_ball_rows=decoder.png``.
    * ``"monkey"`` with ``grouped=True`` – one PNG per training monkey
      (``mgpt_ball_<monkey>.png``), each a single row of per-test-task means.
    * ``"monkey"`` with ``grouped=False`` – one row averaged over all
      remaining rows (``mgpt_ball_ALL.png``).
    * anything else, including ``None`` – train task × test task for the
      given ``decoder`` (``mgpt_ball_<decoder>.png``); prints a warning and
      plots nothing when ``decoder`` is None.

    Prints a warning and returns early when no rows survive the filters.
    Returns None.
    """
    tasks = ["mgpt", "ball"]
    df_f = _base(df, tasks, alignment_mode, decoder)
    df_f = _apply_family_filter(df_f, tasks)
    if df_f.empty:
        print("[WARN] no mgpt/ball rows match current filters")
        return

    # ------------ rows = decoders -----------------------------------
    if group_by == "decoder":
        piv = (df_f.groupby(['decoder_type', 'test_task'])[metric]
               .mean().unstack())
        # Display decoders in a preferred order, but append any decoder not
        # in that list instead of dropping it, and never add all-NaN rows for
        # decoders absent from the data. Both "LIN" and "Linear" appear in
        # this file as the linear-decoder label, so both are listed.
        preferred = ["GRU", "LSTM", "LIN", "Linear", "LiGRU"]
        present = list(piv.index)
        order = ([d for d in preferred if d in present]
                 + [d for d in present if d not in preferred])
        piv = piv.reindex(index=order)
        _heatmap(piv, "mgpt / ball – decoders as rows",
                 "mgpt_ball_rows=decoder.png")
        return

    # ------------ rows = monkeys ------------------------------------
    if group_by == "monkey":
        if grouped:  # one PNG per monkey
            for mky, blk in df_f.groupby('train_monkey'):
                # ``.to_frame(...).T`` leaves the index unnamed; name it so the
                # y-axis gets a meaningful label.
                piv = (blk.groupby('test_task')[metric]
                       .mean().to_frame(mky).T
                       .rename_axis("monkey"))
                _heatmap(piv,
                         f"{mky}: mgpt / ball – mean over decoders",
                         f"mgpt_ball_{mky}.png")
        else:  # all monkeys together
            piv = (df_f.groupby('test_task')[metric]
                   .mean().to_frame("all").T
                   .rename_axis("monkey"))
            _heatmap(piv,
                     "mgpt / ball – mean over decoders & monkeys",
                     "mgpt_ball_ALL.png")
        return

    # ------------ original per-decoder single-row version -----------
    if decoder is None:
        print("[WARN] decoder=None with group_by=None – nothing plotted")
    else:
        piv = (df_f.groupby(['train_task', 'test_task'])[metric]
               .mean().unstack())
        _heatmap(piv,
                 f"{decoder}: mgpt / ball",
                 f"mgpt_ball_{decoder}.png")
def plot_confusion_iso_wm_spr(df, metric='mean_VAF', *,
                              alignment_mode=None,
                              decoder=None,
                              group_by=None,   # None | 'decoder' | 'monkey'
                              grouped=False):
    """Plot mean *metric* heatmaps for the iso / wm / spr task family.

    Rows are first limited to test tasks "iso" / "wm" / "spr", optionally to
    one ``alignment_mode`` and one ``decoder`` (``decoder_type`` column), and
    then to the training monkeys in ``ISO_WM_SPR_MONKEYS``.

    Layout is chosen by ``group_by``:

    * ``"decoder"`` – one row per decoder, one column per test task; saved as
      ``iso_wm_spr_rows=decoder.png``.
    * ``"monkey"`` with ``grouped=True`` – one PNG per training monkey
      (``iso_wm_spr_<monkey>.png``), each a single row of per-test-task means.
    * ``"monkey"`` with ``grouped=False`` – one row averaged over all
      remaining rows (``iso_wm_spr_ALL.png``).
    * anything else, including ``None`` – train task × test task for the
      given ``decoder`` (``iso_wm_spr_<decoder>.png``); prints a warning and
      plots nothing when ``decoder`` is None.

    Prints a warning and returns early when no rows survive the filters.
    Returns None.
    """
    tasks = ["iso", "wm", "spr"]
    df_f = _base(df, tasks, alignment_mode, decoder)
    df_f = _apply_family_filter(df_f, tasks)
    if df_f.empty:
        print("[WARN] no iso/wm/spr rows match current filters")
        return

    # ------------ rows = decoders -----------------------------------
    if group_by == "decoder":
        piv = (df_f.groupby(['decoder_type', 'test_task'])[metric]
               .mean().unstack())
        # Display decoders in a preferred order, but append any decoder not
        # in that list instead of dropping it, and never add all-NaN rows for
        # decoders absent from the data. Both "LIN" and "Linear" appear in
        # this file as the linear-decoder label, so both are listed.
        preferred = ["GRU", "LSTM", "LIN", "Linear", "LiGRU"]
        present = list(piv.index)
        order = ([d for d in preferred if d in present]
                 + [d for d in present if d not in preferred])
        piv = piv.reindex(index=order)
        _heatmap(piv, "iso / wm / spr – decoders as rows",
                 "iso_wm_spr_rows=decoder.png")
        return

    # ------------ rows = monkeys ------------------------------------
    if group_by == "monkey":
        if grouped:  # one PNG per monkey
            for mky, blk in df_f.groupby('train_monkey'):
                # ``.to_frame(...).T`` leaves the index unnamed; name it so the
                # y-axis gets a meaningful label.
                piv = (blk.groupby('test_task')[metric]
                       .mean().to_frame(mky).T
                       .rename_axis("monkey"))
                _heatmap(piv,
                         f"{mky}: iso / wm / spr – mean over decoders",
                         f"iso_wm_spr_{mky}.png")
        else:  # all monkeys together
            piv = (df_f.groupby('test_task')[metric]
                   .mean().to_frame("all").T
                   .rename_axis("monkey"))
            _heatmap(piv,
                     "iso / wm / spr – mean over decoders & monkeys",
                     "iso_wm_spr_ALL.png")
        return

    # ------------ per-decoder train × test version ------------------
    if decoder is None:
        print("[WARN] decoder=None with group_by=None – nothing plotted")
    else:
        piv = (df_f.groupby(['train_task', 'test_task'])[metric]
               .mean().unstack())
        _heatmap(piv,
                 f"{decoder}: iso / wm / spr",
                 f"iso_wm_spr_{decoder}.png")
# ───────────────────────── DRIVER ─────────────────────────
# pd.read_pickle can execute arbitrary code while loading; only open result
# files you produced yourself.
df = pd.read_pickle("df_results_emg_validation_whithin.pkl")
align_type = "recalculated"

# 1) Old behaviour: one PNG per decoder, for each task family in turn.
# NOTE(review): the linear decoder is spelled "Linear" here but "LIN"
# elsewhere in this file; confirm which label the `decoder_type` column holds.
for dec in ("GRU", "LSTM", "Linear", "LiGRU"):
    for _plot_family in (plot_confusion_mgpt_ball, plot_confusion_iso_wm_spr):
        _plot_family(df, alignment_mode=align_type, decoder=dec)

# Alternative layouts: uncomment the pair you need.
#
# 2) Decoders as **rows** (4×N):
# plot_confusion_mgpt_ball(df, alignment_mode=align_type,
#                          decoder=None, group_by="decoder")
# plot_confusion_iso_wm_spr(df, alignment_mode=align_type,
#                           decoder=None, group_by="decoder")
#
# 3) Mean over all decoders, **one PNG per monkey**:
# plot_confusion_mgpt_ball(df, alignment_mode=align_type,
#                          decoder=None, group_by="monkey", grouped=True)
# plot_confusion_iso_wm_spr(df, alignment_mode=align_type,
#                           decoder=None, group_by="monkey", grouped=True)
#
# 4) Mean over all decoders, all monkeys together:
# plot_confusion_mgpt_ball(df, alignment_mode=align_type,
#                          decoder=None, group_by="monkey", grouped=False)
# plot_confusion_iso_wm_spr(df, alignment_mode=align_type,
#                           decoder=None, group_by="monkey", grouped=False)