Skip to content

Commit 9febfff

Browse files
authored
Reverted variable names to original
Reverted variable names to original.
1 parent 21c13d3 commit 9febfff

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

bindsnet/evaluation/evaluation.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,26 +14,26 @@ def assign_labels(
1414
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1515

1616
n_neurons = spikes.size(2)
17-
device = spikes.device # Keep everything on the same device.
1817

1918
if rates is None:
20-
rates = torch.zeros((n_neurons, n_labels), device=device)
19+
rates = torch.zeros((n_neurons, n_labels), device=spikes.device)
2120

2221
# Sum over time dimension (spike ordering doesn't matter).
23-
summed_spikes = spikes.sum(1)
22+
spikes = spikes.sum(1)
2423

2524
for i in range(n_labels):
26-
# Count the number of samples with this label.
25+
# Create mask (faster and allows future steps to stay on GPU).
2726
mask = (labels == i)
27+
# Count the number of samples with this label.
2828
n_labeled = mask.sum().float()
2929

3030
if n_labeled > 0:
3131
# Get indices of samples with this label (masking is faster and stays on the GPU).
32-
label_sum = summed_spikes[mask].sum(0)
32+
label_sum = spikes[mask].sum(0)
3333
# Update rates.
3434
rates[:, i] = alpha * rates[:, i] + (label_sum / n_labeled)
3535

36-
# 3. Compute proportions (and use 'torch.where' to avoid NaN bug).
36+
# Compute proportions (and use 'torch.where' to avoid NaN bug).
3737
total_activity = rates.sum(1, keepdim=True)
3838
proportions = torch.where(total_activity > 0, rates / total_activity, torch.zeros_like(rates))
3939

0 commit comments

Comments
 (0)