We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent d099e3c commit 0db91f2Copy full SHA for 0db91f2
1 file changed
modules/commons/common_layers.py
@@ -122,10 +122,8 @@ def forward(self, x):
122
max_abs_gate = torch.max(-gate_min, gate_max).float()
123
max_abs_value = max_abs_out * max_abs_gate
124
if max_abs_value > 1000:
125
- ratio = 1000 / max_abs_value
126
- sqrt_ratio = torch.sqrt(ratio)
127
- out = out * sqrt_ratio
128
- gate = gate * sqrt_ratio
+ ratio = (1000 / max_abs_value).half()
+ gate *= ratio
129
return (out * gate).clamp(-1000 * ratio, 1000 * ratio) / ratio
130
return out * gate
131
0 commit comments