Skip to content

Commit 6346d78

Browse files
committed
fix input naming in sparsemax
1 parent 50cb221 commit 6346d78

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

mambular/arch_utils/layer_utils/sparsemax.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def backward(ctx, grad_output): # type: ignore
8686
return grad_input, None
8787

8888
@staticmethod
89-
def _threshold_and_support(x, dim=-1):
89+
def _threshold_and_support(input, dim=-1):
9090
"""
9191
Computes the threshold and support for sparsemax.
9292
@@ -103,14 +103,14 @@ def _threshold_and_support(x, dim=-1):
103103
- torch.Tensor : The threshold value for sparsemax.
104104
- torch.Tensor : The support size tensor.
105105
"""
106-
input_srt, _ = torch.sort(x, descending=True, dim=dim)
106+
input_srt, _ = torch.sort(input, descending=True, dim=dim)
107107
input_cumsum = input_srt.cumsum(dim) - 1
108108
rhos = _make_ix_like(input, dim)
109109
support = rhos * input_srt > input_cumsum
110110

111111
support_size = support.sum(dim=dim).unsqueeze(dim)
112112
tau = input_cumsum.gather(dim, support_size - 1)
113-
tau /= support_size.to(x.dtype)
113+
tau /= support_size.to(input.dtype)
114114
return tau, support_size
115115

116116

0 commit comments

Comments
 (0)