@@ -37,27 +37,37 @@ def assign_labels(
     spikes = spikes.sum(1)

     for i in range(n_labels):
-        # Create mask (faster and allows future steps to stay on GPU).
+        # Create mask.
         mask = (labels == i)
         # Count the number of samples with this label.
         n_labeled = mask.sum().float()

         if n_labeled > 0:
-            # Get indices of samples with this label (masking is faster and stays on the GPU).
+            # Get indices of samples with this label.
             label_sum = spikes[mask].sum(0)
             # Update rates.
             rates[:, i] = alpha * rates[:, i] + (label_sum / n_labeled)

-    # Compute proportions (and use 'torch.where' to avoid NaN bug).
+    # Compute proportions of spike activity per class.
     total_activity = rates.sum(1, keepdim=True)
     proportions = torch.where(total_activity > 0, rates / total_activity, torch.zeros_like(rates))

-    # Neuron assignments are the labels they fire most for.
-    max_vals, assignments = torch.max(proportions, 1)
+    # Noise for random tie breaking.
+    eps = 1e-6  # Small enough not to distort real decisions
+    noise = eps * torch.randn_like(proportions)

-    # Set unassigned (silent) neurons to -1 instead of defaulting to 0.
-    assignments[max_vals == 0] = -1
+    # Neuron assignments are the labels they fire most for.
+    assignments = torch.argmax(proportions + noise, dim=1)

+    # Uniform assignment for silent neurons
+    silent_mask = total_activity.squeeze() == 0
+    n_silent = silent_mask.sum()
+
+    if n_silent > 0:
+        assignments[silent_mask] = torch.randint(
+            0, n_labels, (n_silent,), device=spikes.device
+        )
+
     return assignments, proportions, rates

0 commit comments