Skip to content

Commit 3ae76d7

Browse files
committed
avoid precision conversions
1 parent b0ae9ca commit 3ae76d7

1 file changed

Lines changed: 7 additions & 2 deletions

File tree

modules/commons/common_layers.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,13 @@ def forward(self, x):
120120
gate_min, gate_max = torch.aminmax(gate.detach())
121121
max_abs_out = torch.max(-out_min, out_max).float()
122122
max_abs_gate = torch.max(-gate_min, gate_max).float()
123-
if max_abs_out * max_abs_gate > 1000:
124-
return (out.float() * gate.float()).clamp(-1000, 1000).half()
123+
max_abs_value = max_abs_out * max_abs_gate
124+
if max_abs_value > 1000:
125+
ratio = 1000 / max_abs_value
126+
sqrt_ratio = torch.sqrt(ratio)
127+
out = out * sqrt_ratio
128+
gate = gate * sqrt_ratio
129+
return (out * gate).clamp(-1000 * ratio, 1000 * ratio) / ratio
125130
return out * gate
126131

127132

0 commit comments

Comments
 (0)