@@ -36,13 +36,13 @@ class SparsemaxFunction(Function):
3636 """
3737
3838 @staticmethod
39- def forward (ctx , x , dim = - 1 ):
39+ def forward (ctx , input , dim = - 1 ):
4040 """
4141 Forward pass of sparsemax: a normalizing, sparse transformation.
4242
4343 Parameters
4444 ----------
45- x : torch.Tensor
45+ input : torch.Tensor
4646 The input tensor on which sparsemax will be applied.
4747 dim : int, optional
4848 Dimension along which to apply sparsemax. Default is -1.
@@ -53,8 +53,8 @@ def forward(ctx, x, dim=-1):
5353 A tensor with the same shape as the input, with sparsemax applied.
5454 """
5555 ctx .dim = dim
56- max_val , _ = x .max (dim = dim , keepdim = True )
57- x -= max_val # Numerical stability trick, as with softmax.
56+ max_val , _ = input .max (dim = dim , keepdim = True )
57+ input -= max_val # Numerical stability trick, as with softmax.
5858 tau , supp_size = SparsemaxFunction ._threshold_and_support (input , dim = dim )
5959 output = torch .clamp (input - tau , min = 0 )
6060 ctx .save_for_backward (supp_size , output )
@@ -86,7 +86,7 @@ def backward(ctx, grad_output): # type: ignore
8686 return grad_input , None
8787
8888 @staticmethod
89- def _threshold_and_support (x , dim = - 1 ):
89+ def _threshold_and_support (input , dim = - 1 ):
9090 """
9191 Computes the threshold and support for sparsemax.
9292
@@ -103,14 +103,14 @@ def _threshold_and_support(x, dim=-1):
103103 - torch.Tensor : The threshold value for sparsemax.
104104 - torch.Tensor : The support size tensor.
105105 """
106- input_srt , _ = torch .sort (x , descending = True , dim = dim )
106+ input_srt , _ = torch .sort (input , descending = True , dim = dim )
107107 input_cumsum = input_srt .cumsum (dim ) - 1
108108 rhos = _make_ix_like (input , dim )
109109 support = rhos * input_srt > input_cumsum
110110
111111 support_size = support .sum (dim = dim ).unsqueeze (dim )
112112 tau = input_cumsum .gather (dim , support_size - 1 )
113- tau /= support_size .to (x .dtype )
113+ tau /= support_size .to (input .dtype )
114114 return tau , support_size
115115
116116
0 commit comments