Skip to content

Commit 5f940d7

Browse files
authored
Improve evaluation function with noise for assignments
Switched from '-1' sentinel value for silent neurons to randomly assigned labels. Also added noise for random tie-breaking in neuron assignments.
1 parent a2cbace commit 5f940d7

1 file changed

Lines changed: 17 additions & 7 deletions

File tree

bindsnet/evaluation/evaluation.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,27 +37,37 @@ def assign_labels(
3737
spikes = spikes.sum(1)
3838

3939
for i in range(n_labels):
40-
# Create mask (faster and allows future steps to stay on GPU).
40+
# Create mask.
4141
mask = (labels == i)
4242
# Count the number of samples with this label.
4343
n_labeled = mask.sum().float()
4444

4545
if n_labeled > 0:
46-
# Get indices of samples with this label (masking is faster and stays on the GPU).
46+
# Get indices of samples with this label.
4747
label_sum = spikes[mask].sum(0)
4848
# Update rates.
4949
rates[:, i] = alpha * rates[:, i] + (label_sum / n_labeled)
5050

51-
# Compute proportions (and use 'torch.where' to avoid NaN bug).
51+
# Compute proportions of spike activity per class.
5252
total_activity = rates.sum(1, keepdim=True)
5353
proportions = torch.where(total_activity > 0, rates / total_activity, torch.zeros_like(rates))
5454

55-
# Neuron assignments are the labels they fire most for.
56-
max_vals, assignments = torch.max(proportions, 1)
55+
# Noise for random tie breaking.
56+
eps = 1e-6 # Small enough not to distort real decisions
57+
noise = eps * torch.randn_like(proportions)
5758

58-
# Set unassigned (silent) neurons to -1 instead of defaulting to 0.
59-
assignments[max_vals == 0] = -1
59+
# Neuron assignments are the labels they fire most for.
60+
assignments = torch.argmax(proportions + noise, dim=1)
6061

62+
# Uniform assignment for silent neurons
63+
silent_mask = total_activity.squeeze() == 0
64+
n_silent = silent_mask.sum()
65+
66+
if n_silent > 0:
67+
assignments[silent_mask] = torch.randint(
68+
0, n_labels, (n_silent,), device=spikes.device
69+
)
70+
6171
return assignments, proportions, rates
6272

6373

0 commit comments

Comments
 (0)