1+ import collections
12import torch
23import torch .nn as nn
34import torch .nn .functional as F
45from torch import Tensor
56from torch .nn import Module , Parameter , Embedding
67from typing import List
7- from itertools import repeat
88from .chained_optimizer import ChainedOptimizer , OptimizerSpec
99
10- coeffs_list = [
11- (8.28721201814563 , - 23.595886519098837 , 17.300387312530933 ),
12- (4.107059111542203 , - 2.9478499167379106 , 0.5448431082926601 ),
13- (3.9486908534822946 , - 2.908902115962949 , 0.5518191394370137 ),
14- (3.3184196573706015 , - 2.488488024314874 , 0.51004894012372 ),
15- (2.300652019954817 , - 1.6689039845747493 , 0.4188073119525673 ),
16- (1.891301407787398 , - 1.2679958271945868 , 0.37680408948524835 ),
17- (1.8750014808534479 , - 1.2500016453999487 , 0.3750001645474248 ),
18- (1.875 , - 1.25 , 0.375 ), # subsequent coeffs equal this numerically
19- ]
20-
21- # Safety factor for numerical stability (not applied to the last polynomial)
22- coeffs_list = [(a / 1.01 , b / 1.01 ** 3 , c / 1.01 ** 5 ) for (a , b , c ) in coeffs_list [: - 1 ]] + [coeffs_list [- 1 ]]
23-
24-
25- # https://x.com/YouJiacheng/status/1905861218138804534
26- YOU_COEFFICIENTS = [
27- [4.0848 , - 6.8946 , 2.9270 ],
28- [3.9505 , - 6.3029 , 2.6377 ],
29- [3.7418 , - 5.5913 , 2.3037 ],
30- [2.8769 , - 3.1427 , 1.2046 ],
31- [2.8366 , - 3.0525 , 1.2012 ]
32- ]
3310
3411# https://arxiv.org/pdf/2505.16932
3512_unmodified_polar_express_coefficients = [
5229] + [_unmodified_polar_express_coefficients [- 1 ]]
5330
5431
55- def zeropower_via_newtonschulz5 (G : Tensor , steps : int , ns_coefficients : List [ tuple ] ) -> Tensor :
32+ def zeropower_via_newtonschulz5 (G : Tensor , steps : int ) -> Tensor :
5633 """
5734 Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
5835 quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
@@ -68,9 +45,10 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, ns_coefficients: List[tup
6845
6946 # Ensure spectral norm is at most 1
7047 X = F .normalize (X , p = 2.0 , dim = (- 2 , - 1 ), eps = 1e-7 )
48+ X = X .to (torch .float16 )
7149
7250 # Perform the NS iterations
73- hs = coeffs_list [: steps ] + list ( repeat ( coeffs_list [ - 1 ], steps - len ( coeffs_list )))
51+ ns_coefficients = POLAR_EXPRESS_COEFFICIENTS [: steps ]
7452 if X .size (- 2 ) < X .size (- 1 ):
7553 for i in range (steps ):
7654 a , b , c = ns_coefficients [i ]
@@ -87,7 +65,7 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, ns_coefficients: List[tup
8765 return X
8866
8967
90- def gram_newton_schulz (G : Tensor , steps : int , reset_iterations : List [int ], ns_coefficients : List [ tuple ]) -> Tensor :
68+ def gram_newton_schulz (G : Tensor , steps : int , reset_iterations : List [int ]= [ 2 ]) -> Tensor :
9169 """
9270 Gram Newton-Schulz iteration to compute the orthogonalization of G.
9371 Mathematically identical to standard Newton-Schulz but computes iterating
@@ -104,6 +82,7 @@ def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int], ns_co
10482 X = X .mT
10583 X = X .to (torch .float16 )
10684
85+ ns_coefficients = POLAR_EXPRESS_COEFFICIENTS [:steps ]
10786 if X .size (- 2 ) != X .size (- 1 ):
10887 R = torch .bmm (X , X .mT )
10988 Q = None
@@ -131,36 +110,6 @@ def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int], ns_co
131110 return X .to (dtype ).view (original_shape )
132111
133112
134- def mud_whiten (G : Tensor , passes : int = 1 ) -> Tensor :
135- """
136- MomentUm Decorrelation (MUD) iteration to compute the orthogonalization of G.
137- A lightweight PyTorch implementation based on "Beyond Muon: MUD for Faster Transformer Training".
138- Constructs a lower-triangular approximation to the Gram matrix and applies forward triangular solve.
139- """
140- assert G .ndim == 3
141- dtype = G .dtype
142-
143- # "triangular_solve_cuda" not implemented for 'BFloat16'
144- X = G .to (torch .float32 )
145-
146- should_transpose = X .size (- 2 ) > X .size (- 1 )
147- if should_transpose :
148- X = X .mT .contiguous ()
149-
150- for _ in range (passes ):
151- X = F .normalize (X , p = 2.0 , dim = - 1 , eps = 1e-8 ) # Row normalization
152- G_mat = torch .bmm (X , X .mT ) # Row Gram (k,k)
153- T = torch .tril (G_mat ) # Lower-triangular of Gram
154- T .diagonal (dim1 = - 2 , dim2 = - 1 ).clamp_min_ (1e-5 ) # avoid T all zero
155- X = torch .linalg .solve_triangular (T , X , upper = False ) # Forward solve: T X = Q
156- X = F .normalize (X , p = 2.0 , dim = - 1 , eps = 1e-8 ) # Renormalize rows
157-
158- if should_transpose :
159- X = X .mT
160-
161- return X .to (dtype ).contiguous ()
162-
163-
164113class Muon (torch .optim .Optimizer ):
165114 """
166115 Muon - MomentUm Orthogonalized by Newton-schulz
@@ -182,35 +131,10 @@ class Muon(torch.optim.Optimizer):
182131 momentum: The momentum used by the internal SGD.
183132 nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
184133 ns_steps: The number of Newton-Schulz iteration steps to use.
185- ns_coefficients: List of tuples (a, b, c) for each NS step. Defaults to standard 5-step NS coeffs.
186- method: String selector for the orthogonalization strategy ('ns5', 'gram_ns', or 'mud').
187134 """
188135
189- def __init__ (self , params , lr = 5e-4 , weight_decay = 0.1 , momentum = 0.95 , nesterov = True ,
190- ns_steps = 5 , reset_iterations = [2 ], ns_coefficients = POLAR_EXPRESS_COEFFICIENTS ,
191- method = 'gram_ns' ):
192- if reset_iterations is None :
193- reset_iterations = [3 ]
194- # Fall back to [3] by default; use [2] with POLAR_EXPRESS_COEFFICIENTS or YOU_COEFFICIENTS
195-
196- if ns_coefficients is None :
197- parsed_coefficients = [(3.4445 , - 4.7750 , 2.0315 )] * ns_steps
198- else :
199- parsed_coefficients = list (ns_coefficients )
200- if len (parsed_coefficients ) < ns_steps :
201- parsed_coefficients += [parsed_coefficients [- 1 ]] * (ns_steps - len (parsed_coefficients ))
202- parsed_coefficients = parsed_coefficients [:ns_steps ]
203-
204- defaults = dict (
205- lr = lr ,
206- weight_decay = weight_decay ,
207- momentum = momentum ,
208- nesterov = nesterov ,
209- ns_steps = ns_steps ,
210- reset_iterations = reset_iterations ,
211- ns_coefficients = parsed_coefficients ,
212- method = method
213- )
136+ def __init__ (self , params , lr = 5e-4 , weight_decay = 0.1 , momentum = 0.95 , nesterov = True , ns_steps = 5 ):
137+ defaults = dict (lr = lr , weight_decay = weight_decay , momentum = momentum , nesterov = nesterov , ns_steps = ns_steps )
214138 super ().__init__ (params , defaults )
215139
216140 @torch .no_grad ()
@@ -240,29 +164,7 @@ def step(self, closure=None):
240164 original_shape = g .shape
241165 if g .ndim >= 4 : # for the case of conv filters
242166 g = g .view (g .size (0 ), g .size (1 ), - 1 )
243-
244- # Dynamic orthogonalization method invocation
245- method = group .get ("method" , "gram_ns" )
246- if method == 'gram_ns' :
247- g = gram_newton_schulz (
248- g ,
249- steps = group ["ns_steps" ],
250- reset_iterations = group ["reset_iterations" ],
251- ns_coefficients = group ["ns_coefficients" ]
252- )
253- elif method == 'mud' :
254- g = mud_whiten (
255- g ,
256- passes = 1
257- )
258- elif method == 'ns5' :
259- g = zeropower_via_newtonschulz5 (
260- g ,
261- steps = group ["ns_steps" ],
262- ns_coefficients = group ["ns_coefficients" ]
263- )
264- else :
265- raise ValueError (f"Unknown orthogonalization method: { method } " )
167+ g = gram_newton_schulz (g , steps = group ["ns_steps" ])
266168
267169 if group ["weight_decay" ] > 0 :
268170 torch ._foreach_mul_ (p , 1 - group ["lr" ] * group ["weight_decay" ])
@@ -278,15 +180,20 @@ def get_params_for_muon(model) -> List[Parameter]:
278180 Returns:
279181 A list of parameters that should be optimized with muon.
280182 """
183+ excluded_module_classes = (nn .Embedding )
281184 muon_params = []
282- for module in model .modules ():
283- for name , param in module .named_parameters (recurse = False ):
185+ # BFS over all submodules; a module of an excluded type is skipped and its children are never enqueued
186+ queue = collections .deque ([model ])
187+ while queue :
188+ module = queue .popleft ()
189+ if isinstance (module , excluded_module_classes ):
190+ continue
191+ for param in module .parameters (recurse = False ):
284192 if not param .requires_grad :
285193 continue
286- if name == 'weight_g' :
287- continue
288- if not isinstance (module , nn .Embedding ) and param .ndim >= 2 :
194+ if param .ndim >= 2 :
289195 muon_params .append (param )
196+ queue .extend (list (module .children ()))
290197 return muon_params
291198
292199
0 commit comments