Skip to content

Commit 3a6fab7

Browse files
committed
Simplify Muon optimizer and param selection
Remove unused coefficient tables and the mud_whiten path, and streamline orthogonalization to always use gram_newton_schulz. Add a collections import and switch get_params_for_muon to a BFS that excludes Embedding modules and collects only trainable params with ndim >= 2. Cast the intermediate X to float16 in zeropower_via_newtonschulz5 for faster half-precision ops, and drop the unused imports (itertools.repeat) and the redundant method dispatch in the Muon step. These changes reduce complexity and unify the orthogonalization flow.
1 parent 32811e2 commit 3a6fab7

1 file changed

Lines changed: 19 additions & 112 deletions

File tree

modules/optimizer/muon.py

Lines changed: 19 additions & 112 deletions
Original file line number | Diff line number | Diff line change
@@ -1,35 +1,12 @@
1+
import collections
12
import torch
23
import torch.nn as nn
34
import torch.nn.functional as F
45
from torch import Tensor
56
from torch.nn import Module, Parameter, Embedding
67
from typing import List
7-
from itertools import repeat
88
from .chained_optimizer import ChainedOptimizer, OptimizerSpec
99

10-
coeffs_list = [
11-
(8.28721201814563, -23.595886519098837, 17.300387312530933),
12-
(4.107059111542203, -2.9478499167379106, 0.5448431082926601),
13-
(3.9486908534822946, -2.908902115962949, 0.5518191394370137),
14-
(3.3184196573706015, -2.488488024314874, 0.51004894012372),
15-
(2.300652019954817, -1.6689039845747493, 0.4188073119525673),
16-
(1.891301407787398, -1.2679958271945868, 0.37680408948524835),
17-
(1.8750014808534479, -1.2500016453999487, 0.3750001645474248),
18-
(1.875, -1.25, 0.375), # subsequent coeffs equal this numerically
19-
]
20-
21-
# safety factor for numerical stability (but exclude last polynomial )
22-
coeffs_list = [(a / 1.01 , b / 1.01**3 , c / 1.01**5) for (a, b, c) in coeffs_list[: -1]] + [coeffs_list[-1]]
23-
24-
25-
# https://x.com/YouJiacheng/status/1905861218138804534
26-
YOU_COEFFICIENTS = [
27-
[4.0848, -6.8946, 2.9270],
28-
[3.9505, -6.3029, 2.6377],
29-
[3.7418, -5.5913, 2.3037],
30-
[2.8769, -3.1427, 1.2046],
31-
[2.8366, -3.0525, 1.2012]
32-
]
3310

3411
# https://arxiv.org/pdf/2505.16932
3512
_unmodified_polar_express_coefficients = [
@@ -52,7 +29,7 @@
5229
] + [_unmodified_polar_express_coefficients[-1]]
5330

5431

55-
def zeropower_via_newtonschulz5(G: Tensor, steps: int, ns_coefficients: List[tuple]) -> Tensor:
32+
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
5633
"""
5734
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
5835
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
@@ -68,9 +45,10 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, ns_coefficients: List[tup
6845

6946
# Ensure spectral norm is at most 1
7047
X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7)
48+
X = X.to(torch.float16)
7149

7250
# Perform the NS iterations
73-
hs = coeffs_list[: steps] + list(repeat(coeffs_list[-1], steps - len(coeffs_list)))
51+
ns_coefficients = POLAR_EXPRESS_COEFFICIENTS[:steps]
7452
if X.size(-2) < X.size(-1):
7553
for i in range(steps):
7654
a, b, c = ns_coefficients[i]
@@ -87,7 +65,7 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, ns_coefficients: List[tup
8765
return X
8866

8967

90-
def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int], ns_coefficients: List[tuple]) -> Tensor:
68+
def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]=[2]) -> Tensor:
9169
"""
9270
Gram Newton-Schulz iteration to compute the orthogonalization of G.
9371
Mathematically identical to standard Newton-Schulz but computes iterating
@@ -104,6 +82,7 @@ def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int], ns_co
10482
X = X.mT
10583
X = X.to(torch.float16)
10684

85+
ns_coefficients = POLAR_EXPRESS_COEFFICIENTS[:steps]
10786
if X.size(-2) != X.size(-1):
10887
R = torch.bmm(X, X.mT)
10988
Q = None
@@ -131,36 +110,6 @@ def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int], ns_co
131110
return X.to(dtype).view(original_shape)
132111

133112

134-
def mud_whiten(G: Tensor, passes: int = 1) -> Tensor:
135-
"""
136-
MomentUm Decorrelation (MUD) iteration to compute the orthogonalization of G.
137-
A lightweight PyTorch implementation based on "Beyond Muon: MUD for Faster Transformer Training".
138-
Constructs a lower-triangular approximation to the Gram matrix and applies forward triangular solve.
139-
"""
140-
assert G.ndim == 3
141-
dtype = G.dtype
142-
143-
# "triangular_solve_cuda" not implemented for 'BFloat16'
144-
X = G.to(torch.float32)
145-
146-
should_transpose = X.size(-2) > X.size(-1)
147-
if should_transpose:
148-
X = X.mT.contiguous()
149-
150-
for _ in range(passes):
151-
X = F.normalize(X, p=2.0, dim=-1, eps=1e-8) # Row normalization
152-
G_mat = torch.bmm(X, X.mT) # Row Gram (k,k)
153-
T = torch.tril(G_mat) # Lower-triangular of Gram
154-
T.diagonal(dim1=-2, dim2=-1).clamp_min_(1e-5) # avoid T all zero
155-
X = torch.linalg.solve_triangular(T, X, upper=False) # Forward solve: T X = Q
156-
X = F.normalize(X, p=2.0, dim=-1, eps=1e-8) # Renormalize rows
157-
158-
if should_transpose:
159-
X = X.mT
160-
161-
return X.to(dtype).contiguous()
162-
163-
164113
class Muon(torch.optim.Optimizer):
165114
"""
166115
Muon - MomentUm Orthogonalized by Newton-schulz
@@ -182,35 +131,10 @@ class Muon(torch.optim.Optimizer):
182131
momentum: The momentum used by the internal SGD.
183132
nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
184133
ns_steps: The number of Newton-Schulz iteration steps to use.
185-
ns_coefficients: List of tuples (a, b, c) for each NS step. Defaults to standard 5-step NS coeffs.
186-
method: String selector for the orthogonalization strategy ('ns5', 'gram_ns', or 'mud').
187134
"""
188135

189-
def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=True,
190-
ns_steps=5, reset_iterations=[2], ns_coefficients=POLAR_EXPRESS_COEFFICIENTS,
191-
method='gram_ns'):
192-
if reset_iterations is None:
193-
reset_iterations = [3]
194-
# set [3] on default and set [2] on POLAR_EXPRESS_COEFFICIENTS or YOU_COEFFICIENTS
195-
196-
if ns_coefficients is None:
197-
parsed_coefficients = [(3.4445, -4.7750, 2.0315)] * ns_steps
198-
else:
199-
parsed_coefficients = list(ns_coefficients)
200-
if len(parsed_coefficients) < ns_steps:
201-
parsed_coefficients += [parsed_coefficients[-1]] * (ns_steps - len(parsed_coefficients))
202-
parsed_coefficients = parsed_coefficients[:ns_steps]
203-
204-
defaults = dict(
205-
lr=lr,
206-
weight_decay=weight_decay,
207-
momentum=momentum,
208-
nesterov=nesterov,
209-
ns_steps=ns_steps,
210-
reset_iterations=reset_iterations,
211-
ns_coefficients=parsed_coefficients,
212-
method=method
213-
)
136+
def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=True, ns_steps=5):
    """Set up the Muon optimizer.

    Args:
        params: Iterable of parameters (or parameter groups) to optimize.
        lr: Learning rate.
        weight_decay: Decoupled weight-decay factor applied in ``step``.
        momentum: Momentum used by the internal SGD.
        nesterov: Whether to use Nesterov-style momentum. (recommended)
        ns_steps: Number of Newton-Schulz iteration steps.
    """
    # Per-group hyperparameter defaults, handed to the base Optimizer.
    defaults = {
        "lr": lr,
        "weight_decay": weight_decay,
        "momentum": momentum,
        "nesterov": nesterov,
        "ns_steps": ns_steps,
    }
    super().__init__(params, defaults)
215139

216140
@torch.no_grad()
@@ -240,29 +164,7 @@ def step(self, closure=None):
240164
original_shape = g.shape
241165
if g.ndim >= 4: # for the case of conv filters
242166
g = g.view(g.size(0), g.size(1), -1)
243-
244-
# Dynamic orthogonalization method invocation
245-
method = group.get("method", "gram_ns")
246-
if method == 'gram_ns':
247-
g = gram_newton_schulz(
248-
g,
249-
steps=group["ns_steps"],
250-
reset_iterations=group["reset_iterations"],
251-
ns_coefficients=group["ns_coefficients"]
252-
)
253-
elif method == 'mud':
254-
g = mud_whiten(
255-
g,
256-
passes=1
257-
)
258-
elif method == 'ns5':
259-
g = zeropower_via_newtonschulz5(
260-
g,
261-
steps=group["ns_steps"],
262-
ns_coefficients=group["ns_coefficients"]
263-
)
264-
else:
265-
raise ValueError(f"Unknown orthogonalization method: {method}")
167+
g = gram_newton_schulz(g, steps=group["ns_steps"])
266168

267169
if group["weight_decay"] > 0:
268170
torch._foreach_mul_(p, 1 - group["lr"] * group["weight_decay"])
@@ -278,15 +180,20 @@ def get_params_for_muon(model) -> List[Parameter]:
278180
Returns:
279181
A list of parameters that should be optimized with muon.
280182
"""
183+
excluded_module_classes = (nn.Embedding)
281184
muon_params = []
282-
for module in model.modules():
283-
for name, param in module.named_parameters(recurse=False):
185+
# BFS through all submodules and exclude parameters from certain module types
186+
queue = collections.deque([model])
187+
while queue:
188+
module = queue.popleft()
189+
if isinstance(module, excluded_module_classes):
190+
continue
191+
for param in module.parameters(recurse=False):
284192
if not param.requires_grad:
285193
continue
286-
if name == 'weight_g':
287-
continue
288-
if not isinstance(module, nn.Embedding) and param.ndim >= 2:
194+
if param.ndim >= 2:
289195
muon_params.append(param)
196+
queue.extend(list(module.children()))
290197
return muon_params
291198

292199

0 commit comments

Comments (0)