Skip to content

Commit 21c13d3

Browse files
authored
Improve assign_labels accuracy and performance
Resolved a bug where silent neurons defaulted to the first class label. Optimized indexing to prevent unnecessary host-to-device transfers and added robust NaN handling for firing rate proportions.
1 parent 027689d commit 21c13d3

1 file changed

Lines changed: 17 additions & 32 deletions

File tree

bindsnet/evaluation/evaluation.py

Lines changed: 17 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,51 +12,36 @@ def assign_labels(
1212
rates: Optional[torch.Tensor] = None,
1313
alpha: float = 1.0,
1414
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
15-
# language=rst
16-
"""
17-
Assign labels to the neurons based on highest average spiking activity.
1815

19-
:param spikes: Binary tensor of shape ``(n_samples, time, n_neurons)`` of a single
20-
layer's spiking activity.
21-
:param labels: Vector of shape ``(n_samples,)`` with data labels corresponding to
22-
spiking activity.
23-
:param n_labels: The number of target labels in the data.
24-
:param rates: If passed, these represent spike rates from a previous
25-
``assign_labels()`` call.
26-
:param alpha: Rate of decay of label assignments.
27-
:return: Tuple of class assignments, per-class spike proportions, and per-class
28-
firing rates.
29-
"""
3016
n_neurons = spikes.size(2)
17+
device = spikes.device # Keep everything on the same device.
3118

3219
if rates is None:
33-
rates = torch.zeros((n_neurons, n_labels), device=spikes.device)
20+
rates = torch.zeros((n_neurons, n_labels), device=device)
3421

3522
# Sum over time dimension (spike ordering doesn't matter).
36-
spikes = spikes.sum(1)
37-
23+
summed_spikes = spikes.sum(1)
24+
3825
for i in range(n_labels):
3926
# Count the number of samples with this label.
40-
n_labeled = torch.sum(labels == i).float()
27+
mask = (labels == i)
28+
n_labeled = mask.sum().float()
4129

4230
if n_labeled > 0:
43-
# Get indices of samples with this label.
44-
indices = torch.nonzero(labels == i).view(-1)
45-
46-
# Compute average firing rates for this label.
47-
selected_spikes = torch.index_select(
48-
spikes, dim=0, index=torch.tensor(indices)
49-
)
50-
rates[:, i] = alpha * rates[:, i] + (
51-
torch.sum(selected_spikes, 0) / n_labeled
52-
)
31+
# Get indices of samples with this label (masking is faster and stays on the GPU).
32+
label_sum = summed_spikes[mask].sum(0)
33+
# Update rates.
34+
rates[:, i] = alpha * rates[:, i] + (label_sum / n_labeled)
5335

54-
# Compute proportions of spike activity per class.
55-
proportions = rates / rates.sum(1, keepdim=True)
56-
proportions[proportions != proportions] = 0 # Set NaNs to 0
36+
# 3. Compute proportions (and use 'torch.where' to avoid NaN bug).
37+
total_activity = rates.sum(1, keepdim=True)
38+
proportions = torch.where(total_activity > 0, rates / total_activity, torch.zeros_like(rates))
5739

5840
# Neuron assignments are the labels they fire most for.
59-
assignments = torch.max(proportions, 1)[1]
41+
max_vals, assignments = torch.max(proportions, 1)
42+
43+
# Set unassigned (silent) neurons to -1 instead of defaulting to 0.
44+
assignments[max_vals == 0] = -1
6045

6146
return assignments, proportions, rates
6247

0 commit comments

Comments
 (0)