File tree Expand file tree Collapse file tree
mambular/arch_utils/layer_utils Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -86,7 +86,7 @@ def backward(ctx, grad_output): # type: ignore
8686 return grad_input , None
8787
8888 @staticmethod
89- def _threshold_and_support (x , dim = - 1 ):
89+ def _threshold_and_support (input , dim = - 1 ):
9090 """
9191 Computes the threshold and support for sparsemax.
9292
@@ -103,14 +103,14 @@ def _threshold_and_support(x, dim=-1):
103103 - torch.Tensor : The threshold value for sparsemax.
104104 - torch.Tensor : The support size tensor.
105105 """
106- input_srt , _ = torch .sort (x , descending = True , dim = dim )
106+ input_srt , _ = torch .sort (input , descending = True , dim = dim )
107107 input_cumsum = input_srt .cumsum (dim ) - 1
108108 rhos = _make_ix_like (input , dim )
109109 support = rhos * input_srt > input_cumsum
110110
111111 support_size = support .sum (dim = dim ).unsqueeze (dim )
112112 tau = input_cumsum .gather (dim , support_size - 1 )
113- tau /= support_size .to (x .dtype )
113+ tau /= support_size .to (input .dtype )
114114 return tau , support_size
115115
116116
You can’t perform that action at this time.
0 commit comments