Skip to content

Commit e59fc69

Browse files
committed
muon optimizer based on polar express
1 parent a39677b commit e59fc69

1 file changed

Lines changed: 23 additions & 5 deletions

File tree

modules/optimizer/muon.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,23 @@
44
from torch import Tensor
55
from torch.nn import Module, Parameter, Embedding
66
from typing import List
7+
from itertools import repeat
78
from .chained_optimizer import ChainedOptimizer, OptimizerSpec
89

10+
coeffs_list = [
11+
(8.28721201814563, -23.595886519098837, 17.300387312530933),
12+
(4.107059111542203, -2.9478499167379106, 0.5448431082926601),
13+
(3.9486908534822946, -2.908902115962949, 0.5518191394370137),
14+
(3.3184196573706015, -2.488488024314874, 0.51004894012372),
15+
(2.300652019954817, -1.6689039845747493, 0.4188073119525673),
16+
(1.891301407787398, -1.2679958271945868, 0.37680408948524835),
17+
(1.8750014808534479, -1.2500016453999487, 0.3750001645474248),
18+
(1.875, -1.25, 0.375), # subsequent coeffs equal this numerically
19+
]
20+
21+
# Apply a safety factor of 1.01 for numerical stability (the last polynomial is left unscaled)
22+
coeffs_list = [(a / 1.01 , b / 1.01**3 , c / 1.01**5) for (a, b, c) in coeffs_list[: -1]] + [coeffs_list[-1]]
23+
924

1025
def get_bf16_support_map():
1126
bf16_support_map = {}
@@ -24,7 +39,7 @@ def get_bf16_support_map():
2439

2540
return bf16_support_map
2641

27-
42+
2843
def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor:
2944
"""
3045
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
@@ -36,21 +51,22 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor
3651
performance at all relative to UV^T, where USV^T = G is the SVD.
3752
"""
3853
assert G.ndim == 3 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
39-
a, b, c = (3.4445, -4.7750, 2.0315)
54+
#a, b, c = (3.4445, -4.7750, 2.0315)
4055

4156
X = G.to(dtype = torch.bfloat16 if use_bf16 else torch.float32)
4257

4358
# Ensure spectral norm is at most 1
4459
X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7)
4560

4661
# Perform the NS iterations
62+
hs = coeffs_list[: steps] + list(repeat(coeffs_list[-1], steps - len(coeffs_list)))
4763
if X.size(-2) < X.size(-1):
48-
for _ in range(steps):
64+
for a, b, c in hs:
4965
A = torch.bmm(X, X.mT)
5066
A = torch.baddbmm(A, A, A, beta=b, alpha=c)
5167
X = torch.baddbmm(X, A, X, beta=a, alpha=1)
5268
else:
53-
for _ in range(steps):
69+
for a, b, c in hs:
5470
A = torch.bmm(X.mT, X)
5571
A = torch.baddbmm(A, A, A, beta=b, alpha=c)
5672
X = torch.baddbmm(X, X, A, beta=a, alpha=1)
@@ -131,9 +147,11 @@ def get_params_for_muon(model) -> List[Parameter]:
131147
"""
132148
muon_params = []
133149
for module in model.modules():
134-
for param in module.parameters(recurse=False):
150+
for name, param in module.named_parameters(recurse=False):
135151
if not param.requires_grad:
136152
continue
153+
if name == 'weight_g':
154+
continue
137155
if not isinstance(module, nn.Embedding) and param.ndim >= 2:
138156
muon_params.append(param)
139157
return muon_params

0 commit comments

Comments
 (0)