Skip to content

Commit f50cfdd

Browse files
authored
Gram Newton Schulz (#297)
1 parent e59fc69 commit f50cfdd

1 file changed

Lines changed: 76 additions & 35 deletions

File tree

modules/optimizer/muon.py

Lines changed: 76 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import collections
12
import torch
23
import torch.nn as nn
34
import torch.nn.functional as F
@@ -7,7 +8,9 @@
78
from itertools import repeat
89
from .chained_optimizer import ChainedOptimizer, OptimizerSpec
910

10-
coeffs_list = [
11+
12+
# https://arxiv.org/pdf/2505.16932
13+
_unmodified_polar_express_coefficients = [
1114
(8.28721201814563, -23.595886519098837, 17.300387312530933),
1215
(4.107059111542203, -2.9478499167379106, 0.5448431082926601),
1316
(3.9486908534822946, -2.908902115962949, 0.5518191394370137),
@@ -19,28 +22,15 @@
1922
]
2023

2124
# Safety factor for numerical stability, applied to every polynomial except
# the last. The Polar Express reference value is 1.01;
# 'Dao-AILab/gram-newton-schulz' uses 1.05, which is adopted here.
safety_factor = 1.05
POLAR_EXPRESS_COEFFICIENTS = [
    (a / safety_factor, b / safety_factor**3, c / safety_factor**5)
    for (a, b, c) in _unmodified_polar_express_coefficients[:-1]
] + [_unmodified_polar_express_coefficients[-1]]
3031

31-
device_count = torch.cuda.device_count()
32-
if device_count == 0:
33-
return bf16_support_map
3432

35-
for i in range(device_count):
36-
device = torch.device(f'cuda:{i}')
37-
major, minor = torch.cuda.get_device_capability(device)
38-
bf16_support_map[device] = (major >= 8)
39-
40-
return bf16_support_map
41-
42-
43-
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G.

    Uses a quintic iteration with the Polar Express coefficient schedule
    (https://arxiv.org/pdf/2505.16932). The result approximates US'V^T where
    USV^T = G is the SVD and S' is close to (but not exactly) the identity;
    empirically this does not hurt performance relative to the exact UV^T.

    Args:
        G: batched matrices of shape (batch, m, n).
        steps: number of Newton-Schulz iterations; if it exceeds the length of
            the coefficient schedule, the final coefficients are reused.

    Returns:
        The (approximately) orthogonalized batch, in float16.
    """
    assert G.ndim == 3  # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng

    X = G.to(torch.float32)

    # Ensure spectral norm is at most 1 so the iteration converges.
    X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-12)
    # NOTE(review): the iterations run in float16 (not bfloat16); presumably
    # the spectral norm bound keeps intermediates in fp16 range — confirm.
    X = X.to(torch.float16)

    # Extend the schedule by repeating its last entry so it covers `steps`
    # iterations (the slice+repeat always yields exactly `steps` tuples).
    ns_coefficients = POLAR_EXPRESS_COEFFICIENTS[:steps] + list(
        repeat(POLAR_EXPRESS_COEFFICIENTS[-1], steps - len(POLAR_EXPRESS_COEFFICIENTS))
    )
    if X.size(-2) < X.size(-1):
        # Wide matrices: iterate using the smaller m x m product X @ X^T.
        for a, b, c in ns_coefficients:
            A = torch.bmm(X, X.mT)
            A = torch.baddbmm(A, A, A, beta=b, alpha=c)  # A <- b*A + c*A^2
            X = torch.baddbmm(X, A, X, beta=a, alpha=1)  # X <- a*X + A @ X
    else:
        # Tall (or square) matrices: use the smaller n x n product X^T @ X.
        for a, b, c in ns_coefficients:
            A = torch.bmm(X.mT, X)
            A = torch.baddbmm(A, A, A, beta=b, alpha=c)
            X = torch.baddbmm(X, X, A, beta=a, alpha=1)

    return X
7567

7668

69+
def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: "List[int] | None" = None) -> Tensor:
    """
    Gram Newton-Schulz iteration to compute the orthogonalization of G.

    Mathematically identical to standard Newton-Schulz, but for non-square
    inputs it iterates on the smaller N x N Gram matrix R = X @ X^T and
    accumulates the per-step polynomial factors in Q, saving up to 50% FLOPs.

    Args:
        G: batched matrices of shape (batch, m, n). Non-square inputs require
            steps >= 1 (with steps == 0 there is no accumulated Q to apply).
        steps: number of iterations; if it exceeds the length of the
            coefficient schedule, the schedule's last entry is repeated.
        reset_iterations: iteration indices at which Q is folded back into X
            and the Gram matrix is recomputed from scratch — a float16
            round-off safety valve. Defaults to [2]. (The default is None
            rather than a list literal to avoid the shared-mutable-default
            pitfall; it is never mutated, but the sentinel keeps it safe.)

    Returns:
        The orthogonalized batch, cast back to G's dtype and original shape.
    """
    if reset_iterations is None:
        reset_iterations = [2]

    assert G.ndim == 3
    original_shape = G.shape
    dtype = G.dtype

    X = G.to(torch.float32)
    # Ensure spectral norm is at most 1 so the iteration converges.
    X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-12)
    # Work in the wide orientation (rows <= cols) so the Gram matrix is the
    # smaller of the two possible products.
    should_transpose = X.size(-2) > X.size(-1)
    if should_transpose:
        X = X.mT
    X = X.to(torch.float16)

    # Extend the coefficient schedule by repeating its last entry.
    ns_coefficients = POLAR_EXPRESS_COEFFICIENTS[:steps] + list(
        repeat(POLAR_EXPRESS_COEFFICIENTS[-1], steps - len(POLAR_EXPRESS_COEFFICIENTS))
    )
    if X.size(-2) != X.size(-1):
        # Non-square: one NS step is X <- P(R) @ X with P(R) = a*I + b*R + c*R^2
        # and R = X @ X^T. Every P is a polynomial in the initial R, so the
        # factors commute; accumulate their product in Q and update R directly,
        # touching X only at resets and at the very end.
        R = torch.bmm(X, X.mT)
        Q = None
        for i, (a_i, b_i, c_i) in enumerate(ns_coefficients):
            if i in reset_iterations and i != 0:
                # Fold Q into X and recompute R to limit float16 error growth.
                X = torch.bmm(Q, X)
                R = torch.bmm(X, X.mT)
                Q = None
            Z = torch.baddbmm(R, R, R, beta=b_i, alpha=c_i)  # Z = b*R + c*R^2
            if i != 0 and i not in reset_iterations:
                # Q <- Q @ (a*I + Z), i.e. append this step's polynomial factor.
                Q = torch.baddbmm(Q, Q, Z, beta=a_i, alpha=1.0)
            else:
                # First step after a (re)start: Q = a*I + Z.
                Q = Z.clone()
                Q.diagonal(dim1=-2, dim2=-1).add_(a_i)
            if i < steps - 1 and (i + 1) not in reset_iterations:
                # R <- P @ R @ P; skipped when the next step recomputes R anyway.
                RZ = torch.baddbmm(R, R, Z, beta=a_i, alpha=1.0)
                R = torch.baddbmm(RZ, Z, RZ, beta=a_i, alpha=1.0)
        # Q is symmetric, so X^T @ Q == (Q @ X)^T restores the tall orientation.
        X = torch.bmm(Q, X) if not should_transpose else torch.bmm(X.mT, Q)
    else:
        # Square input: the Gram trick saves nothing; run plain Newton-Schulz.
        for i, (a_i, b_i, c_i) in enumerate(ns_coefficients):
            A = torch.bmm(X, X.mT)
            B = torch.baddbmm(A, A, A, beta=b_i, alpha=c_i)
            X = torch.baddbmm(X, B, X, beta=a_i, alpha=1.0)

    return X.to(dtype).view(original_shape)
112+
113+
77114
class Muon(torch.optim.Optimizer):
78115
"""
79116
Muon - MomentUm Orthogonalized by Newton-schulz
@@ -100,7 +137,6 @@ class Muon(torch.optim.Optimizer):
100137
def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=True, ns_steps=5):
101138
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
102139
super().__init__(params, defaults)
103-
self.bf16_support_map = get_bf16_support_map()
104140

105141
@torch.no_grad()
106142
def step(self, closure=None):
@@ -129,8 +165,8 @@ def step(self, closure=None):
129165
original_shape = g.shape
130166
if g.ndim >= 4: # for the case of conv filters
131167
g = g.view(g.size(0), g.size(1), -1)
132-
use_bf16 = self.bf16_support_map.get(g.device, False)
133-
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16)
168+
g = gram_newton_schulz(g, steps=group["ns_steps"])
169+
134170
if group["weight_decay"] > 0:
135171
torch._foreach_mul_(p, 1 - group["lr"] * group["weight_decay"])
136172
torch._foreach_add_(p, g.view(original_shape).unbind(0), alpha=-group["lr"] * max(g[0].size()) ** 0.5)
def get_params_for_muon(model) -> List[Parameter]:
    """
    Collect the parameters of `model` that should be optimized with Muon.

    Muon orthogonalizes matrix-shaped gradients, so only trainable parameters
    with ndim >= 2 are eligible; embedding modules (and anything beneath them)
    are excluded entirely.

    Returns:
        A list of parameters that should be optimized with muon.
    """
    # Trailing comma matters: isinstance() needs a class or a *tuple* of
    # classes, and `(nn.Embedding)` alone is just nn.Embedding.
    excluded_module_classes = (nn.Embedding,)
    muon_params = []
    # BFS through all submodules, skipping excluded module types together with
    # their children. Track visited modules by identity so a module shared
    # between two parents does not contribute its parameters twice (matching
    # the dedup behavior of nn.Module.modules()).
    visited = set()
    queue = collections.deque([model])
    while queue:
        module = queue.popleft()
        if id(module) in visited:
            continue
        visited.add(id(module))
        if isinstance(module, excluded_module_classes):
            continue
        for param in module.parameters(recurse=False):
            if not param.requires_grad:
                continue
            if param.ndim >= 2:
                muon_params.append(param)
        queue.extend(module.children())
    return muon_params
158199

159200

0 commit comments

Comments
 (0)