Skip to content

Commit c7ef969

Browse files
committed
Simplify Muon optimizer and param selection
Remove unused coefficient tables and the mud_whiten path, and streamline orthogonalization to always use gram_newton_schulz. Add a collections import and switch get_params_for_muon to a BFS that excludes Embedding modules and collects only trainable params with ndim >= 2. Cast the intermediate X to float16 in zeropower_via_newtonschulz5 for faster half-precision ops, and drop the unused itertools.repeat import and the redundant method dispatch in the Muon step. These changes reduce complexity and unify the orthogonalization flow.
1 parent 32811e2 commit c7ef969

1 file changed

Lines changed: 19 additions & 110 deletions

File tree

modules/optimizer/muon.py

Lines changed: 19 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,12 @@
1+
import collections
12
import torch
23
import torch.nn as nn
34
import torch.nn.functional as F
45
from torch import Tensor
56
from torch.nn import Module, Parameter, Embedding
67
from typing import List
7-
from itertools import repeat
88
from .chained_optimizer import ChainedOptimizer, OptimizerSpec
99

10-
coeffs_list = [
11-
(8.28721201814563, -23.595886519098837, 17.300387312530933),
12-
(4.107059111542203, -2.9478499167379106, 0.5448431082926601),
13-
(3.9486908534822946, -2.908902115962949, 0.5518191394370137),
14-
(3.3184196573706015, -2.488488024314874, 0.51004894012372),
15-
(2.300652019954817, -1.6689039845747493, 0.4188073119525673),
16-
(1.891301407787398, -1.2679958271945868, 0.37680408948524835),
17-
(1.8750014808534479, -1.2500016453999487, 0.3750001645474248),
18-
(1.875, -1.25, 0.375), # subsequent coeffs equal this numerically
19-
]
20-
21-
# safety factor for numerical stability (but exclude the last polynomial)
22-
coeffs_list = [(a / 1.01 , b / 1.01**3 , c / 1.01**5) for (a, b, c) in coeffs_list[: -1]] + [coeffs_list[-1]]
23-
24-
25-
# https://x.com/YouJiacheng/status/1905861218138804534
26-
YOU_COEFFICIENTS = [
27-
[4.0848, -6.8946, 2.9270],
28-
[3.9505, -6.3029, 2.6377],
29-
[3.7418, -5.5913, 2.3037],
30-
[2.8769, -3.1427, 1.2046],
31-
[2.8366, -3.0525, 1.2012]
32-
]
3310

3411
# https://arxiv.org/pdf/2505.16932
3512
_unmodified_polar_express_coefficients = [
@@ -52,7 +29,7 @@
5229
] + [_unmodified_polar_express_coefficients[-1]]
5330

5431

55-
def zeropower_via_newtonschulz5(G: Tensor, steps: int, ns_coefficients: List[tuple]) -> Tensor:
32+
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
5633
"""
5734
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
5835
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
@@ -68,9 +45,10 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, ns_coefficients: List[tup
6845

6946
# Ensure spectral norm is at most 1
7047
X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7)
48+
X = X.to(torch.float16)
7149

7250
# Perform the NS iterations
73-
hs = coeffs_list[: steps] + list(repeat(coeffs_list[-1], steps - len(coeffs_list)))
51+
ns_coefficients = POLAR_EXPRESS_COEFFICIENTS[:steps]
7452
if X.size(-2) < X.size(-1):
7553
for i in range(steps):
7654
a, b, c = ns_coefficients[i]
@@ -87,7 +65,7 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, ns_coefficients: List[tup
8765
return X
8866

8967

90-
def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int], ns_coefficients: List[tuple]) -> Tensor:
68+
def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]=[2]) -> Tensor:
9169
"""
9270
Gram Newton-Schulz iteration to compute the orthogonalization of G.
9371
Mathematically identical to standard Newton-Schulz but computes iterating
@@ -104,6 +82,7 @@ def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int], ns_co
10482
X = X.mT
10583
X = X.to(torch.float16)
10684

85+
ns_coefficients = POLAR_EXPRESS_COEFFICIENTS[:steps]
10786
if X.size(-2) != X.size(-1):
10887
R = torch.bmm(X, X.mT)
10988
Q = None
@@ -131,36 +110,6 @@ def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int], ns_co
131110
return X.to(dtype).view(original_shape)
132111

133112

134-
def mud_whiten(G: Tensor, passes: int = 1) -> Tensor:
135-
"""
136-
MomentUm Decorrelation (MUD) iteration to compute the orthogonalization of G.
137-
A lightweight PyTorch implementation based on "Beyond Muon: MUD for Faster Transformer Training".
138-
Constructs a lower-triangular approximation to the Gram matrix and applies forward triangular solve.
139-
"""
140-
assert G.ndim == 3
141-
dtype = G.dtype
142-
143-
# "triangular_solve_cuda" not implemented for 'BFloat16'
144-
X = G.to(torch.float32)
145-
146-
should_transpose = X.size(-2) > X.size(-1)
147-
if should_transpose:
148-
X = X.mT.contiguous()
149-
150-
for _ in range(passes):
151-
X = F.normalize(X, p=2.0, dim=-1, eps=1e-8) # Row normalization
152-
G_mat = torch.bmm(X, X.mT) # Row Gram (k,k)
153-
T = torch.tril(G_mat) # Lower-triangular of Gram
154-
T.diagonal(dim1=-2, dim2=-1).clamp_min_(1e-5) # avoid T all zero
155-
X = torch.linalg.solve_triangular(T, X, upper=False) # Forward solve: T X = Q
156-
X = F.normalize(X, p=2.0, dim=-1, eps=1e-8) # Renormalize rows
157-
158-
if should_transpose:
159-
X = X.mT
160-
161-
return X.to(dtype).contiguous()
162-
163-
164113
class Muon(torch.optim.Optimizer):
165114
"""
166115
Muon - MomentUm Orthogonalized by Newton-schulz
@@ -186,31 +135,8 @@ class Muon(torch.optim.Optimizer):
186135
method: String selector for the orthogonalization strategy ('ns5', 'gram_ns', or 'mud').
187136
"""
188137

189-
def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=True,
190-
ns_steps=5, reset_iterations=[2], ns_coefficients=POLAR_EXPRESS_COEFFICIENTS,
191-
method='gram_ns'):
192-
if reset_iterations is None:
193-
reset_iterations = [3]
194-
# set [3] on default and set [2] on POLAR_EXPRESS_COEFFICIENTS or YOU_COEFFICIENTS
195-
196-
if ns_coefficients is None:
197-
parsed_coefficients = [(3.4445, -4.7750, 2.0315)] * ns_steps
198-
else:
199-
parsed_coefficients = list(ns_coefficients)
200-
if len(parsed_coefficients) < ns_steps:
201-
parsed_coefficients += [parsed_coefficients[-1]] * (ns_steps - len(parsed_coefficients))
202-
parsed_coefficients = parsed_coefficients[:ns_steps]
203-
204-
defaults = dict(
205-
lr=lr,
206-
weight_decay=weight_decay,
207-
momentum=momentum,
208-
nesterov=nesterov,
209-
ns_steps=ns_steps,
210-
reset_iterations=reset_iterations,
211-
ns_coefficients=parsed_coefficients,
212-
method=method
213-
)
138+
def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=True, ns_steps=5):
139+
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
214140
super().__init__(params, defaults)
215141

216142
@torch.no_grad()
@@ -240,29 +166,7 @@ def step(self, closure=None):
240166
original_shape = g.shape
241167
if g.ndim >= 4: # for the case of conv filters
242168
g = g.view(g.size(0), g.size(1), -1)
243-
244-
# Dynamic orthogonalization method invocation
245-
method = group.get("method", "gram_ns")
246-
if method == 'gram_ns':
247-
g = gram_newton_schulz(
248-
g,
249-
steps=group["ns_steps"],
250-
reset_iterations=group["reset_iterations"],
251-
ns_coefficients=group["ns_coefficients"]
252-
)
253-
elif method == 'mud':
254-
g = mud_whiten(
255-
g,
256-
passes=1
257-
)
258-
elif method == 'ns5':
259-
g = zeropower_via_newtonschulz5(
260-
g,
261-
steps=group["ns_steps"],
262-
ns_coefficients=group["ns_coefficients"]
263-
)
264-
else:
265-
raise ValueError(f"Unknown orthogonalization method: {method}")
169+
g = gram_newton_schulz(g, steps=group["ns_steps"])
266170

267171
if group["weight_decay"] > 0:
268172
torch._foreach_mul_(p, 1 - group["lr"] * group["weight_decay"])
@@ -278,15 +182,20 @@ def get_params_for_muon(model) -> List[Parameter]:
278182
Returns:
279183
A list of parameters that should be optimized with muon.
280184
"""
185+
excluded_module_classes = (nn.Embedding)
281186
muon_params = []
282-
for module in model.modules():
283-
for name, param in module.named_parameters(recurse=False):
187+
# BFS through all submodules and exclude parameters from certain module types
188+
queue = collections.deque([model])
189+
while queue:
190+
module = queue.popleft()
191+
if isinstance(module, excluded_module_classes):
192+
continue
193+
for param in module.parameters(recurse=False):
284194
if not param.requires_grad:
285195
continue
286-
if name == 'weight_g':
287-
continue
288-
if not isinstance(module, nn.Embedding) and param.ndim >= 2:
196+
if param.ndim >= 2:
289197
muon_params.append(param)
198+
queue.extend(list(module.children()))
290199
return muon_params
291200

292201

0 commit comments

Comments
 (0)