Skip to content

Commit f0a1c19

Browse files
committed
Update muon.py
1 parent d265842 commit f0a1c19

1 file changed

Lines changed: 80 additions & 12 deletions

File tree

modules/optimizer/muon.py

Lines changed: 80 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,35 @@
2222
coeffs_list = [(a / 1.01 , b / 1.01**3 , c / 1.01**5) for (a, b, c) in coeffs_list[: -1]] + [coeffs_list[-1]]
2323

2424

25+
# https://x.com/YouJiacheng/status/1905861218138804534
26+
YOU_COEFFICIENTS = [
27+
[4.0848, -6.8946, 2.9270],
28+
[3.9505, -6.3029, 2.6377],
29+
[3.7418, -5.5913, 2.3037],
30+
[2.8769, -3.1427, 1.2046],
31+
[2.8366, -3.0525, 1.2012]
32+
]
33+
34+
# https://arxiv.org/pdf/2505.16932
35+
_unmodified_polar_express_coefficients = [
36+
(8.28721201814563, -23.595886519098837, 17.300387312530933),
37+
(4.107059111542203, -2.9478499167379106, 0.5448431082926601),
38+
(3.9486908534822946, -2.908902115962949, 0.5518191394370137),
39+
(3.3184196573706015, -2.488488024314874, 0.51004894012372),
40+
(2.300652019954817, -1.6689039845747493, 0.4188073119525673),
41+
(1.891301407787398, -1.2679958271945868, 0.37680408948524835),
42+
(1.8750014808534479, -1.2500016453999487, 0.3750001645474248),
43+
(1.875, -1.25, 0.375), # subsequent coeffs equal this numerically
44+
]
45+
46+
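# Safety factor for numerical stability.
# NOTE(review): the comprehension below divides EVERY polynomial by the safety factor, including the last one,
# whereas the coeffs_list expression earlier in this file leaves the last polynomial unscaled -- confirm which is intended.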
47+
safety_factor = 1.01 # 'Dao-AILab/gram-newton-schulz' set 1.05
48+
POLAR_EXPRESS_COEFFICIENTS = [
49+
(a / safety_factor , b / safety_factor**3 , c / safety_factor**5)
50+
for (a, b, c) in _unmodified_polar_express_coefficients
51+
]
52+
53+
2554
def get_bf16_support_map():
2655
bf16_support_map = {}
2756

@@ -39,8 +68,8 @@ def get_bf16_support_map():
3968

4069
return bf16_support_map
4170

42-
43-
def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor:
71+
72+
def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool, ns_coefficients: List[tuple]) -> Tensor:
4473
"""
4574
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
4675
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
@@ -51,7 +80,6 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor
5180
performance at all relative to UV^T, where USV^T = G is the SVD.
5281
"""
5382
assert G.ndim == 3 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
54-
#a, b, c = (3.4445, -4.7750, 2.0315)
5583

5684
X = G.to(dtype = torch.bfloat16 if use_bf16 else torch.float32)
5785

@@ -61,28 +89,28 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor
6189
# Perform the NS iterations
6290
hs = coeffs_list[: steps] + list(repeat(coeffs_list[-1], steps - len(coeffs_list)))
6391
if X.size(-2) < X.size(-1):
64-
for a, b, c in hs:
92+
for i in range(steps):
93+
a, b, c = ns_coefficients[i]
6594
A = torch.bmm(X, X.mT)
6695
A = torch.baddbmm(A, A, A, beta=b, alpha=c)
6796
X = torch.baddbmm(X, A, X, beta=a, alpha=1)
6897
else:
69-
for a, b, c in hs:
98+
for i in range(steps):
99+
a, b, c = ns_coefficients[i]
70100
A = torch.bmm(X.mT, X)
71101
A = torch.baddbmm(A, A, A, beta=b, alpha=c)
72102
X = torch.baddbmm(X, X, A, beta=a, alpha=1)
73103

74104
return X
75105

76106

77-
def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations: List[int]) -> Tensor:
107+
def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations: List[int], ns_coefficients: List[tuple]) -> Tensor:
78108
"""
79109
Gram Newton-Schulz iteration to compute the orthogonalization of G.
80110
Mathematically identical to standard Newton-Schulz but computes iterating
81111
on the smaller NxN Gram matrix to save up to 50% FLOPs.
82112
"""
83113
assert G.ndim == 3
84-
a, b, c = (3.4445, -4.7750, 2.0315)
85-
ns_coefficients = [(a, b, c)] * steps
86114

87115
original_shape = G.shape
88116
dtype = G.dtype
@@ -151,12 +179,35 @@ class Muon(torch.optim.Optimizer):
151179
momentum: The momentum used by the internal SGD.
152180
nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
153181
ns_steps: The number of Newton-Schulz iteration steps to use.
182+
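ns_coefficients: List of (a, b, c) tuples, one per NS step. Defaults to POLAR_EXPRESS_COEFFICIENTS; pass None to use (3.4445, -4.7750, 2.0315) for every step. A list shorter than ns_steps is padded with its last entry, and a longer one is truncated to ns_steps.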
183+
use_gram_ns: Whether to use the FLOP-saving Gram-NS implementation instead of standard NS.
154184
"""
155185

156-
def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=True, ns_steps=5, reset_iterations=None):
186+
def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=True,
187+
ns_steps=5, reset_iterations=[2], ns_coefficients=POLAR_EXPRESS_COEFFICIENTS,
188+
use_gram_ns=True):
157189
if reset_iterations is None:
158190
reset_iterations = [3]
159-
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, reset_iterations=reset_iterations)
191+
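# Use [3] with the default coefficients (presumably the single-tuple set used when ns_coefficients is None -- confirm),
# and [2] with POLAR_EXPRESS_COEFFICIENTS or YOU_COEFFICIENTS.
# NOTE(review): the signature default is now [2], so the None -> [3] fallback above only runs when None is passed explicitly.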
192+
193+
if ns_coefficients is None:
194+
parsed_coefficients = [(3.4445, -4.7750, 2.0315)] * ns_steps
195+
else:
196+
parsed_coefficients = list(ns_coefficients)
197+
if len(parsed_coefficients) < ns_steps:
198+
parsed_coefficients += [parsed_coefficients[-1]] * (ns_steps - len(parsed_coefficients))
199+
parsed_coefficients = parsed_coefficients[:ns_steps]
200+
201+
defaults = dict(
202+
lr=lr,
203+
weight_decay=weight_decay,
204+
momentum=momentum,
205+
nesterov=nesterov,
206+
ns_steps=ns_steps,
207+
reset_iterations=reset_iterations,
208+
ns_coefficients=parsed_coefficients,
209+
use_gram_ns=use_gram_ns
210+
)
160211
super().__init__(params, defaults)
161212
self.bf16_support_map = get_bf16_support_map()
162213

@@ -187,9 +238,26 @@ def step(self, closure=None):
187238
original_shape = g.shape
188239
if g.ndim >= 4: # for the case of conv filters
189240
g = g.view(g.size(0), g.size(1), -1)
241+
190242
use_bf16 = self.bf16_support_map.get(g.device, False)
191-
# g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16)
192-
g = gram_newton_schulz(g, steps=group["ns_steps"], use_bf16=use_bf16, reset_iterations=group["reset_iterations"])
243+
244+
# Dynamic NS function invocation
245+
if group["use_gram_ns"]:
246+
g = gram_newton_schulz(
247+
g,
248+
steps=group["ns_steps"],
249+
use_bf16=use_bf16,
250+
reset_iterations=group["reset_iterations"],
251+
ns_coefficients=group["ns_coefficients"]
252+
)
253+
else:
254+
g = zeropower_via_newtonschulz5(
255+
g,
256+
steps=group["ns_steps"],
257+
use_bf16=use_bf16,
258+
ns_coefficients=group["ns_coefficients"]
259+
)
260+
193261
if group["weight_decay"] > 0:
194262
torch._foreach_mul_(p, 1 - group["lr"] * group["weight_decay"])
195263
torch._foreach_add_(p, g.view(original_shape).unbind(0), alpha=-group["lr"] * max(g[0].size()) ** 0.5)

0 commit comments

Comments
 (0)