Skip to content

Commit eb3b606

Browse files
committed
stabilize fp16 training
1 parent 5f7a1be commit eb3b606

1 file changed

Lines changed: 9 additions & 1 deletion

File tree

modules/commons/common_layers.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,15 @@ def forward(self, x):
114114
# out, gate = x.chunk(2, dim=self.dim)
115115
# Using torch.split instead of chunk for ONNX export compatibility.
116116
out, gate = torch.split(x, x.size(self.dim) // 2, dim=self.dim)
117-
return out * F.silu(gate)
117+
gate = F.silu(gate)
118+
if x.dtype == torch.float16:
119+
out_min, out_max = torch.aminmax(out.detach())
120+
gate_min, gate_max = torch.aminmax(gate.detach())
121+
max_abs_out = torch.max(-out_min, out_max).float()
122+
max_abs_gate = torch.max(-gate_min, gate_max).float()
123+
if max_abs_out * max_abs_gate > 65504:
124+
return (out.float() * gate.float()).clamp(-65504, 65504).half()
125+
return out * gate
118126

119127

120128
class Conv1d(torch.nn.Conv1d):

0 commit comments

Comments
 (0)