Skip to content

Commit 88922e6

Browse files
authored
Merge pull request #193 from basf/sparse_fix
fix input names in sparsemax
2 parents 42cc9fd + cf185e6 commit 88922e6

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

mambular/arch_utils/layer_utils/sparsemax.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,13 @@ class SparsemaxFunction(Function):
3636
"""
3737

3838
@staticmethod
39-
def forward(ctx, x, dim=-1):
39+
def forward(ctx, input, dim=-1):
4040
"""
4141
Forward pass of sparsemax: a normalizing, sparse transformation.
4242
4343
Parameters
4444
----------
45-
x : torch.Tensor
45+
input : torch.Tensor
4646
The input tensor on which sparsemax will be applied.
4747
dim : int, optional
4848
Dimension along which to apply sparsemax. Default is -1.
@@ -53,8 +53,8 @@ def forward(ctx, x, dim=-1):
5353
A tensor with the same shape as the input, with sparsemax applied.
5454
"""
5555
ctx.dim = dim
56-
max_val, _ = x.max(dim=dim, keepdim=True)
57-
x -= max_val # Numerical stability trick, as with softmax.
56+
max_val, _ = input.max(dim=dim, keepdim=True)
57+
input -= max_val # Numerical stability trick, as with softmax.
5858
tau, supp_size = SparsemaxFunction._threshold_and_support(input, dim=dim)
5959
output = torch.clamp(input - tau, min=0)
6060
ctx.save_for_backward(supp_size, output)

0 commit comments

Comments
 (0)