We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent b0ae9ca commit 3ae76d7Copy full SHA for 3ae76d7
1 file changed
modules/commons/common_layers.py
@@ -120,8 +120,13 @@ def forward(self, x):
120
gate_min, gate_max = torch.aminmax(gate.detach())
121
max_abs_out = torch.max(-out_min, out_max).float()
122
max_abs_gate = torch.max(-gate_min, gate_max).float()
123
- if max_abs_out * max_abs_gate > 1000:
124
- return (out.float() * gate.float()).clamp(-1000, 1000).half()
+ max_abs_value = max_abs_out * max_abs_gate
+ if max_abs_value > 1000:
125
+ ratio = 1000 / max_abs_value
126
+ sqrt_ratio = torch.sqrt(ratio)
127
+ out = out * sqrt_ratio
128
+ gate = gate * sqrt_ratio
129
+ return (out * gate).clamp(-1000 * ratio, 1000 * ratio) / ratio
130
return out * gate
131
132
0 commit comments