Skip to content

Commit a39677b

Browse files
yxlllc and KakaruHayate authored
Optimized glu (#278)
* optimize * Keep AtanGLU behavior unchanged during eval (#275) --------- Co-authored-by: Kakaru <97896816+KakaruHayate@users.noreply.github.com>
1 parent c315d38 commit a39677b

1 file changed

Lines changed: 21 additions & 2 deletions

File tree

modules/commons/common_layers.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,22 @@ def forward(self, x):
128128
return out * gate
129129

130130

131+
class ATanGLUFunction(torch.autograd.Function):
    """Memory-efficient ArcTan gated linear unit: forward computes ``out * atan(gate)``.

    Rather than saving the raw inputs for the backward pass, forward caches
    the two factors the gradients are built from — ``atan(gate)`` and
    ``out / (1 + gate**2)`` — so backward is just two elementwise products.
    """

    @staticmethod
    def forward(ctx, out, gate):
        gate_atan = torch.atan(gate)
        # d(atan(g))/dg = 1 / (1 + g^2); pre-scale `out` by that factor now
        # so backward needs no extra arithmetic beyond the grad products.
        out_scaled = out / (gate * gate + 1.0)
        ctx.save_for_backward(out_scaled, gate_atan)
        return out * gate_atan

    @staticmethod
    def backward(ctx, grad_output):
        out_scaled, gate_atan = ctx.saved_tensors
        # ∂L/∂out  = grad_output * atan(gate)
        # ∂L/∂gate = grad_output * out / (1 + gate^2)
        return grad_output * gate_atan, grad_output * out_scaled
145+
146+
131147
class ATanGLU(nn.Module):
132148
# ArcTan-Applies the gated linear unit function.
133149
def __init__(self, dim=-1):
@@ -136,9 +152,12 @@ def __init__(self, dim=-1):
136152

137153
def forward(self, x):
    """Apply the ArcTan-GLU activation along ``self.dim``."""
    # torch.split instead of chunk for ONNX export compatibility.
    out, gate = torch.split(x, x.size(self.dim) // 2, dim=self.dim)
    if self.training:
        # Training uses the memory-saving custom autograd path.
        return ATanGLUFunction.apply(out, gate)
    # Eval keeps the plain expression so exported/traced graphs are unchanged.
    return out * torch.atan(gate)
142161

143162

144163
class KaimingNormalConv1d(torch.nn.Conv1d):

0 commit comments

Comments
 (0)