@@ -12,51 +12,36 @@ def assign_labels(
1212 rates : Optional [torch .Tensor ] = None ,
1313 alpha : float = 1.0 ,
1414) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
15- # language=rst
16- """
17- Assign labels to the neurons based on highest average spiking activity.
1815
19- :param spikes: Binary tensor of shape ``(n_samples, time, n_neurons)`` of a single
20- layer's spiking activity.
21- :param labels: Vector of shape ``(n_samples,)`` with data labels corresponding to
22- spiking activity.
23- :param n_labels: The number of target labels in the data.
24- :param rates: If passed, these represent spike rates from a previous
25- ``assign_labels()`` call.
26- :param alpha: Rate of decay of label assignments.
27- :return: Tuple of class assignments, per-class spike proportions, and per-class
28- firing rates.
29- """
3016 n_neurons = spikes .size (2 )
17+ device = spikes .device # Keep everything on the same device.
3118
3219 if rates is None :
33- rates = torch .zeros ((n_neurons , n_labels ), device = spikes . device )
20+ rates = torch .zeros ((n_neurons , n_labels ), device = device )
3421
3522 # Sum over time dimension (spike ordering doesn't matter).
36- spikes = spikes .sum (1 )
37-
23+ summed_spikes = spikes .sum (1 )
24+
3825 for i in range (n_labels ):
3926 # Count the number of samples with this label.
40- n_labeled = torch .sum (labels == i ).float ()
27+ mask = (labels == i )
28+ n_labeled = mask .sum ().float ()
4129
4230 if n_labeled > 0 :
43- # Get indices of samples with this label.
44- indices = torch .nonzero (labels == i ).view (- 1 )
45-
46- # Compute average firing rates for this label.
47- selected_spikes = torch .index_select (
48- spikes , dim = 0 , index = torch .tensor (indices )
49- )
50- rates [:, i ] = alpha * rates [:, i ] + (
51- torch .sum (selected_spikes , 0 ) / n_labeled
52- )
31+ # Get indices of samples with this label (masking is faster and stays on the GPU).
32+ label_sum = summed_spikes [mask ].sum (0 )
33+ # Update rates.
34+ rates [:, i ] = alpha * rates [:, i ] + (label_sum / n_labeled )
5335
54- # Compute proportions of spike activity per class .
55- proportions = rates / rates .sum (1 , keepdim = True )
56- proportions [ proportions != proportions ] = 0 # Set NaNs to 0
36+ # 3. Compute proportions (and use 'torch.where' to avoid NaN bug) .
37+ total_activity = rates .sum (1 , keepdim = True )
38+ proportions = torch . where ( total_activity > 0 , rates / total_activity , torch . zeros_like ( rates ))
5739
5840 # Neuron assignments are the labels they fire most for.
59- assignments = torch .max (proportions , 1 )[1 ]
41+ max_vals , assignments = torch .max (proportions , 1 )
42+
43+ # Set unassigned (silent) neurons to -1 instead of defaulting to 0.
44+ assignments [max_vals == 0 ] = - 1
6045
6146 return assignments , proportions , rates
6247
0 commit comments