Skip to content

Commit 8e71afb

Browse files
committed
fix and optimize
1 parent c4b01f6 commit 8e71afb

5 files changed

Lines changed: 33 additions & 31 deletions

File tree

configs/acoustic.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ sampling_algorithm: euler
7070
sampling_steps: 20
7171
diff_accelerator: ddim
7272
diff_speedup: 10
73-
hidden_size: 256
73+
hidden_size: 384
7474
backbone_type: 'lynxnet2'
7575
backbone_args:
7676
num_channels: 1024

configs/templates/config_acoustic.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ T_start: 0.4
7777
T_start_infer: 0.4
7878
K_step: 300
7979
K_step_infer: 300
80+
hidden_size: 384
8081
backbone_type: 'lynxnet2'
8182
backbone_args:
8283
num_channels: 1024

modules/backbones/lynxnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def forward(self, x, conditioner, diffusion_step, front_cond_inject=False):
7474

7575
class LYNXNet(nn.Module):
7676
def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansion_factor=2, kernel_size=31,
77-
activation='PReLU', dropout=0.0, strong_cond=False):
77+
activation='PReLU', dropout_rate=0.0, strong_cond=False):
7878
"""
7979
LYNXNet(Linear Gated Depthwise Separable Convolution Network)
8080
TIPS:You can control the style of the generated results by modifying the 'activation',
@@ -100,7 +100,7 @@ def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansio
100100
expansion_factor=expansion_factor,
101101
kernel_size=kernel_size,
102102
activation=activation,
103-
dropout=dropout
103+
dropout=dropout_rate
104104
)
105105
for i in range(num_layers)
106106
]

modules/backbones/lynxnet2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def forward(self, x):
3939

4040
class LYNXNet2(nn.Module):
4141
def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansion_factor=1, kernel_size=31,
42-
dropout=0.0, use_conditioner_cache=False, glu_type='swiglu'):
42+
dropout_rate=0.0, use_conditioner_cache=False, glu_type='swiglu'):
4343
"""
4444
LYNXNet2(Linear Gated Depthwise Separable Convolution Network Version 2)
4545
"""
@@ -65,7 +65,7 @@ def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansio
6565
dim=num_channels,
6666
expansion_factor=expansion_factor,
6767
kernel_size=kernel_size,
68-
dropout=dropout,
68+
dropout=dropout_rate,
6969
glu_type=glu_type
7070
)
7171
for i in range(num_layers)

modules/optimizer/muon.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -37,25 +37,25 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor
3737
"""
3838
assert G.ndim == 3 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
3939
a, b, c = (3.4445, -4.7750, 2.0315)
40-
if use_bf16:
41-
X = G.bfloat16()
42-
else:
43-
X = G.float()
44-
if G.size(-2) > G.size(-1):
45-
X = X.mT
40+
41+
X = G.to(dtype = torch.bfloat16 if use_bf16 else torch.float32)
4642

4743
# Ensure spectral norm is at most 1
4844
X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7)
4945

5046
# Perform the NS iterations
51-
for _ in range(steps):
52-
A = X @ X.mT
53-
B = torch.baddbmm(A, A, A, beta=b, alpha=c)
54-
X = torch.baddbmm(X, B, X, beta=a, alpha=1)
55-
56-
if G.size(-2) > G.size(-1):
57-
X = X.mT
58-
return X.to(G)
47+
if X.size(-2) < X.size(-1):
48+
for _ in range(steps):
49+
A = torch.bmm(X, X.mT)
50+
A = torch.baddbmm(A, A, A, beta=b, alpha=c)
51+
X = torch.baddbmm(X, A, X, beta=a, alpha=1)
52+
else:
53+
for _ in range(steps):
54+
A = torch.bmm(X.mT, X)
55+
A = torch.baddbmm(A, A, A, beta=b, alpha=c)
56+
X = torch.baddbmm(X, X, A, beta=a, alpha=1)
57+
58+
return X
5959

6060

6161
class Muon(torch.optim.Optimizer):
@@ -85,7 +85,7 @@ def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=Tr
8585
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
8686
super().__init__(params, defaults)
8787
self.bf16_support_map = get_bf16_support_map()
88-
88+
8989
@torch.no_grad()
9090
def step(self, closure=None):
9191
for group in self.param_groups:
@@ -95,28 +95,29 @@ def step(self, closure=None):
9595
state = self.state[p]
9696
if "momentum_buffer" not in state:
9797
state["momentum_buffer"] = torch.zeros_like(g)
98-
buf: Tensor = state["momentum_buffer"]
9998
key = (p.shape, p.device, p.dtype)
10099
if key not in shape_groups:
101100
shape_groups[key] = {"params": [], "grads": [], "buffers": []}
102101
shape_groups[key]["params"].append(p)
103102
shape_groups[key]["grads"].append(g)
104-
shape_groups[key]["buffers"].append(buf)
103+
shape_groups[key]["buffers"].append(state["momentum_buffer"])
105104
for key in shape_groups:
106105
group_data = shape_groups[key]
107-
g = torch.stack(group_data["grads"])
108-
buf = torch.stack(group_data["buffers"])
109-
buf.lerp_(g, 1 - group["momentum"])
110-
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
106+
p, g, buf, m = group_data["params"], group_data["grads"], group_data["buffers"], group["momentum"]
107+
torch._foreach_lerp_(buf, g, 1-m)
108+
if group["nesterov"]:
109+
torch._foreach_lerp_(g, buf, m)
110+
g = torch.stack(g)
111+
else:
112+
g = torch.stack(buf)
113+
original_shape = g.shape
111114
if g.ndim >= 4: # for the case of conv filters
112115
g = g.view(g.size(0), g.size(1), -1)
113116
use_bf16 = self.bf16_support_map.get(g.device, False)
114117
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16)
115-
for i, p in enumerate(group_data["params"]):
116-
if group["weight_decay"] > 0:
117-
p.data.mul_(1 - group["lr"] * group["weight_decay"])
118-
p.data.add_(g[i].view_as(p), alpha=-group["lr"] * max(g[i].size()) ** 0.5)
119-
self.state[p]["momentum_buffer"] = buf[i].clone()
118+
if group["weight_decay"] > 0:
119+
torch._foreach_mul_(p, 1 - group["lr"] * group["weight_decay"])
120+
torch._foreach_add_(p, g.view(original_shape).unbind(0), alpha=-group["lr"] * max(g[0].size()) ** 0.5)
120121

121122

122123
def get_params_for_muon(model) -> List[Parameter]:

0 commit comments

Comments (0)