Skip to content

Commit ff3b72c

Browse files
committed
update gram-newton-schulz
Update muon.py
1 parent e59fc69 commit ff3b72c

1 file changed

Lines changed: 63 additions & 5 deletions

File tree

modules/optimizer/muon.py

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,61 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor
7474
return X
7575

7676

77+
def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]) -> Tensor:
    """
    Orthogonalize each matrix in the batch ``G`` via a Gram-matrix Newton-Schulz iteration.

    Mathematically equivalent to the standard quintic Newton-Schulz iteration,
    but for non-square matrices it iterates on the smaller N x N Gram matrix
    ``R = X @ X.mT`` (N = min of the two matrix dims), saving up to ~50% of
    the FLOPs.

    Args:
        G: Batch of matrices, shape ``(B, M, N)``.
        steps: Number of Newton-Schulz iterations (must be >= 1).
        reset_iterations: Iteration indices at which the accumulated polynomial
            factor ``Q`` is flushed into ``X`` and the Gram matrix is recomputed
            from scratch; this keeps low-precision accumulation from drifting.
            Index 0 is ignored (there is nothing to flush yet).

    Returns:
        Tensor with the same shape and dtype as ``G``, approximately
        semi-orthogonal (singular values driven toward 1).
    """
    assert G.ndim == 3, "expected a batch of matrices of shape (B, M, N)"
    # Quintic Newton-Schulz coefficients (same as Muon's standard iteration).
    a, b, c = (3.4445, -4.7750, 2.0315)
    ns_coefficients = [(a, b, c)] * steps

    original_shape = G.shape
    dtype = G.dtype

    X = G.to(torch.float32)

    # Work on the wide orientation so the Gram matrix is the smaller one.
    should_transpose = X.size(-2) > X.size(-1)
    if should_transpose:
        X = X.mT

    # Frobenius-normalize so the spectral norm is <= 1 (required for NS
    # convergence), then drop to low precision for the matmuls.
    X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7)
    # fp16 matmuls are the speed win on GPU; half-precision bmm support is
    # spotty off-GPU, so fall back to fp32 there (also more accurate).
    compute_dtype = torch.float16 if X.is_cuda else torch.float32
    X = X.to(compute_dtype)

    if X.size(-2) != X.size(-1):
        # Non-square: iterate on the N x N Gram matrix and accumulate the
        # polynomial factor Q so that X_final = Q @ X.
        R = torch.bmm(X, X.mT)
        Q = None  # None <=> "no factor accumulated yet" (start or just reset)

        for i, (a_i, b_i, c_i) in enumerate(ns_coefficients):
            if i in reset_iterations and i != 0:
                # Flush the accumulated factor into X and restart from a
                # freshly computed Gram matrix.
                X = torch.bmm(Q, X)
                R = torch.bmm(X, X.mT)
                Q = None

            # Z = b*R + c*R@R, so this step's factor is Q_i = a*I + Z.
            Z = torch.baddbmm(R, R, R, beta=b_i, alpha=c_i)
            if Q is not None:
                # Q <- Q @ (a*I + Z). All factors are polynomials in the same
                # R, so they commute and the accumulation order is immaterial.
                Q = torch.baddbmm(Q, Q, Z, beta=a_i, alpha=1.0)
            else:
                Q = Z.clone()
                Q.diagonal(dim1=-2, dim2=-1).add_(a_i)
            if i < steps - 1 and (i + 1) not in reset_iterations:
                # R <- Q_i @ R @ Q_i, the Gram matrix of the updated X.
                RZ = torch.baddbmm(R, R, Z, beta=a_i, alpha=1.0)
                R = torch.baddbmm(RZ, Z, RZ, beta=a_i, alpha=1.0)

        # Apply the accumulated factor. In the transposed case X.mT @ Q equals
        # (Q @ X).mT, i.e. it is ALREADY back in the original orientation —
        # the previous trailing `if should_transpose: X = X.mT` transposed a
        # second time and scrambled the result; it must not be applied here.
        X = torch.bmm(X.mT, Q) if should_transpose else torch.bmm(Q, X)
    else:
        # Square input (should_transpose is necessarily False here): plain
        # quintic Newton-Schulz on X itself; no orientation fix-up needed.
        for a_i, b_i, c_i in ns_coefficients:
            A = torch.bmm(X, X.mT)
            B = torch.baddbmm(A, A, A, beta=b_i, alpha=c_i)
            X = torch.baddbmm(X, B, X, beta=a_i, alpha=1.0)

    # bmm/baddbmm outputs are contiguous and already have the original
    # orientation, so view() is safe in every path.
    return X.to(dtype).view(original_shape)
130+
131+
77132
class Muon(torch.optim.Optimizer):
78133
"""
79134
Muon - MomentUm Orthogonalized by Newton-schulz
@@ -97,10 +152,12 @@ class Muon(torch.optim.Optimizer):
97152
ns_steps: The number of Newton-Schulz iteration steps to use.
98153
"""
99154

100-
def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=True, ns_steps=5, reset_iterations=None):
    """
    Args:
        params: Iterable of parameters (or param groups) to optimize.
        lr: Learning rate.
        weight_decay: Decoupled weight-decay coefficient.
        momentum: Momentum coefficient for the internal SGD buffer.
        nesterov: Whether to use Nesterov-style momentum.
        ns_steps: Number of Newton-Schulz iteration steps.
        reset_iterations: Iteration indices at which gram_newton_schulz
            flushes its accumulated factor; defaults to [3]. A None
            sentinel is used so the default list is not shared across
            instances.
    """
    if reset_iterations is None:
        reset_iterations = [3]
    defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
                    nesterov=nesterov, ns_steps=ns_steps,
                    reset_iterations=reset_iterations)
    super().__init__(params, defaults)
104161

105162
@torch.no_grad()
106163
def step(self, closure=None):
@@ -129,8 +186,9 @@ def step(self, closure=None):
129186
original_shape = g.shape
130187
if g.ndim >= 4: # for the case of conv filters
131188
g = g.view(g.size(0), g.size(1), -1)
132-
use_bf16 = self.bf16_support_map.get(g.device, False)
133-
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16)
189+
# use_bf16 = self.bf16_support_map.get(g.device, False)
190+
# g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16)
191+
g = gram_newton_schulz(g, steps=group["ns_steps"], reset_iterations=group["reset_iterations"])
134192
if group["weight_decay"] > 0:
135193
torch._foreach_mul_(p, 1 - group["lr"] * group["weight_decay"])
136194
torch._foreach_add_(p, g.view(original_shape).unbind(0), alpha=-group["lr"] * max(g[0].size()) ** 0.5)

0 commit comments

Comments
 (0)