@@ -128,6 +128,22 @@ def forward(self, x):
128128 return out * gate
129129
130130
class ATanGLUFunction(torch.autograd.Function):
    """Memory-efficient autograd kernel for ``out * atan(gate)``.

    Rather than stashing the raw ``out`` and ``gate`` inputs, the forward
    pass precomputes and saves exactly the two tensors backward needs, so
    no trigonometric work has to be redone at backward time.
    """

    @staticmethod
    def forward(ctx, out, gate):
        # atan(gate) serves double duty: the forward gating factor and
        # d(output)/d(out) for the backward pass.
        gating = torch.atan(gate)
        # out / (1 + gate^2): d(output)/d(gate), i.e. out times atan'(gate).
        gate_grad_factor = out / gate.square().add(1.0)
        ctx.save_for_backward(gate_grad_factor, gating)
        return out * gating

    @staticmethod
    def backward(ctx, grad_output):
        gate_grad_factor, gating = ctx.saved_tensors
        # Chain rule: elementwise products with the saved factors,
        # returned in (out, gate) argument order.
        return grad_output * gating, grad_output * gate_grad_factor
146+
131147class ATanGLU (nn .Module ):
132148 # ArcTan-Applies the gated linear unit function.
133149 def __init__ (self , dim = - 1 ):
@@ -136,9 +152,12 @@ def __init__(self, dim=-1):
136152
137153 def forward (self , x ):
138154 # out, gate = x.chunk(2, dim=self.dim)
139- # Using torch.split instead of chunk for ONNX export compatibility.
155+ # Using torch.split instead of chunk for ONNX export compatibility.
140156 out , gate = torch .split (x , x .size (self .dim ) // 2 , dim = self .dim )
141- return out * torch .atan (gate )
157+ if self .training :
158+ return ATanGLUFunction .apply (out , gate )
159+ else :
160+ return out * torch .atan (gate )
142161
143162
144163class KaimingNormalConv1d (torch .nn .Conv1d ):
0 commit comments