@@ -14,26 +14,26 @@ def assign_labels(
1414) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
1515
1616 n_neurons = spikes .size (2 )
17- device = spikes .device # Keep everything on the same device.
1817
1918 if rates is None :
20- rates = torch .zeros ((n_neurons , n_labels ), device = device )
19+ rates = torch .zeros ((n_neurons , n_labels ), device = spikes . device )
2120
2221 # Sum over time dimension (spike ordering doesn't matter).
23- summed_spikes = spikes .sum (1 )
22+ spikes = spikes .sum (1 )
2423
2524 for i in range (n_labels ):
26- # Count the number of samples with this label .
25+ # Create mask (faster and allows future steps to stay on GPU) .
2726 mask = (labels == i )
27+ # Count the number of samples with this label.
2828 n_labeled = mask .sum ().float ()
2929
3030 if n_labeled > 0 :
3131 # Get indices of samples with this label (masking is faster and stays on the GPU).
32- label_sum = summed_spikes [mask ].sum (0 )
32+ label_sum = spikes [mask ].sum (0 )
3333 # Update rates.
3434 rates [:, i ] = alpha * rates [:, i ] + (label_sum / n_labeled )
3535
36- # 3. Compute proportions (and use 'torch.where' to avoid NaN bug).
36+ # Compute proportions (and use 'torch.where' to avoid NaN bug).
3737 total_activity = rates .sum (1 , keepdim = True )
3838 proportions = torch .where (total_activity > 0 , rates / total_activity , torch .zeros_like (rates ))
3939
0 commit comments