Skip to content

Commit 0db91f2

Browse files
committed
save memory
1 parent d099e3c commit 0db91f2

1 file changed

Lines changed: 2 additions & 4 deletions

File tree

modules/commons/common_layers.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,8 @@ def forward(self, x):
122122
max_abs_gate = torch.max(-gate_min, gate_max).float()
123123
max_abs_value = max_abs_out * max_abs_gate
124124
if max_abs_value > 1000:
125-
ratio = 1000 / max_abs_value
126-
sqrt_ratio = torch.sqrt(ratio)
127-
out = out * sqrt_ratio
128-
gate = gate * sqrt_ratio
125+
ratio = (1000 / max_abs_value).half()
126+
gate *= ratio
129127
return (out * gate).clamp(-1000 * ratio, 1000 * ratio) / ratio
130128
return out * gate
131129

0 commit comments

Comments
 (0)