1+ import collections
12import torch
23import torch .nn as nn
34import torch .nn .functional as F
78from itertools import repeat
89from .chained_optimizer import ChainedOptimizer , OptimizerSpec
910
10- coeffs_list = [
11+
12+ # https://arxiv.org/pdf/2505.16932
13+ _unmodified_polar_express_coefficients = [
1114 (8.28721201814563 , - 23.595886519098837 , 17.300387312530933 ),
1215 (4.107059111542203 , - 2.9478499167379106 , 0.5448431082926601 ),
1316 (3.9486908534822946 , - 2.908902115962949 , 0.5518191394370137 ),
1922]
2023
# Safety factor for numerical stability; the final polynomial is left unscaled.
# safety_factor = 1.01  # 'Dao-AILab/gram-newton-schulz' set 1.05
safety_factor = 1.05
POLAR_EXPRESS_COEFFICIENTS = [
    # Dividing (a, b, c) by (s, s^3, s^5) shrinks the quintic's action uniformly.
    *(
        (a / safety_factor, b / safety_factor**3, c / safety_factor**5)
        for (a, b, c) in _unmodified_polar_express_coefficients[:-1]
    ),
    _unmodified_polar_express_coefficients[-1],
]
3031
31- device_count = torch .cuda .device_count ()
32- if device_count == 0 :
33- return bf16_support_map
3432
35- for i in range (device_count ):
36- device = torch .device (f'cuda:{ i } ' )
37- major , minor = torch .cuda .get_device_capability (device )
38- bf16_support_map [device ] = (major >= 8 )
39-
40- return bf16_support_map
41-
42-
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G,
    using a quintic iteration with the (safety-scaled) Polar-Express coefficients.

    Args:
        G: batched input of shape (batch, m, n).
        steps: number of Newton-Schulz iterations to run.

    Returns:
        A float16 tensor of the same shape as G, approximately orthogonalized.
    """
    assert G.ndim == 3  # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng

    X = G.to(torch.float32)

    # Ensure spectral norm is at most 1 (the Frobenius norm upper-bounds it).
    X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-12)
    X = X.to(torch.float16)

    # One coefficient triple per step; reuse the last triple once the schedule
    # runs out. The result always has exactly `steps` entries, so we can
    # iterate it directly instead of indexing with range(steps).
    ns_coefficients = POLAR_EXPRESS_COEFFICIENTS[:steps] + list(
        repeat(POLAR_EXPRESS_COEFFICIENTS[-1], steps - len(POLAR_EXPRESS_COEFFICIENTS))
    )
    # Form the Gram matrix on the smaller side: X X^T if wide, X^T X otherwise.
    if X.size(-2) < X.size(-1):
        for a, b, c in ns_coefficients:
            A = torch.bmm(X, X.mT)
            A = torch.baddbmm(A, A, A, beta=b, alpha=c)
            X = torch.baddbmm(X, A, X, beta=a, alpha=1)
    else:
        for a, b, c in ns_coefficients:
            A = torch.bmm(X.mT, X)
            A = torch.baddbmm(A, A, A, beta=b, alpha=c)
            X = torch.baddbmm(X, X, A, beta=a, alpha=1)

    return X
7567
7668
def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int] = None) -> Tensor:
    """
    Gram Newton-Schulz iteration to compute the orthogonalization of G.
    Mathematically identical to standard Newton-Schulz but iterates on the
    smaller NxN Gram matrix to save up to 50% FLOPs.

    Args:
        G: batched input of shape (batch, m, n).
        steps: number of Newton-Schulz iterations to run.
        reset_iterations: step indices at which the accumulated polynomial Q is
            folded back into X and the Gram matrix recomputed (for numerical
            stability in float16). Defaults to [2].

    Returns:
        Tensor with G's original dtype and shape, approximately orthogonalized.
    """
    # Avoid the mutable-default-argument pitfall; [2] matches prior behavior.
    if reset_iterations is None:
        reset_iterations = [2]

    assert G.ndim == 3
    original_shape = G.shape
    dtype = G.dtype

    X = G.to(torch.float32)
    X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-12)
    # Put the short side first so the Gram matrix is as small as possible.
    should_transpose = X.size(-2) > X.size(-1)
    if should_transpose:
        X = X.mT
    X = X.to(torch.float16)

    ns_coefficients = POLAR_EXPRESS_COEFFICIENTS[:steps] + list(
        repeat(POLAR_EXPRESS_COEFFICIENTS[-1], steps - len(POLAR_EXPRESS_COEFFICIENTS))
    )
    if X.size(-2) != X.size(-1):
        # Rectangular case: iterate on the Gram matrix R = X X^T, accumulating
        # the polynomial transform in Q, and apply Q to X only at the end.
        R = torch.bmm(X, X.mT)
        Q = None
        for i, (a_i, b_i, c_i) in enumerate(ns_coefficients):
            if i in reset_iterations and i != 0:
                # Fold Q into X and restart the Gram accumulation afresh.
                X = torch.bmm(Q, X)
                R = torch.bmm(X, X.mT)
                Q = None
            # Z = b*R + c*R^2, so a*I + Z is this step's quintic factor in R.
            Z = torch.baddbmm(R, R, R, beta=b_i, alpha=c_i)
            if i != 0 and i not in reset_iterations:
                Q = torch.baddbmm(Q, Q, Z, beta=a_i, alpha=1.0)
            else:
                Q = Z.clone()
                Q.diagonal(dim1=-2, dim2=-1).add_(a_i)
            if i < steps - 1 and (i + 1) not in reset_iterations:
                # Advance R -> (a*I + Z) R (a*I + Z) so the next step sees the
                # updated Gram matrix without ever touching X.
                RZ = torch.baddbmm(R, R, Z, beta=a_i, alpha=1.0)
                R = torch.baddbmm(RZ, Z, RZ, beta=a_i, alpha=1.0)
        X = torch.bmm(Q, X) if not should_transpose else torch.bmm(X.mT, Q)
    else:
        # Square case: fall back to the standard quintic iteration.
        for a_i, b_i, c_i in ns_coefficients:
            A = torch.bmm(X, X.mT)
            B = torch.baddbmm(A, A, A, beta=b_i, alpha=c_i)
            X = torch.baddbmm(X, B, X, beta=a_i, alpha=1.0)

    return X.to(dtype).view(original_shape)
113+
77114class Muon (torch .optim .Optimizer ):
78115 """
79116 Muon - MomentUm Orthogonalized by Newton-schulz
@@ -100,7 +137,6 @@ class Muon(torch.optim.Optimizer):
100137 def __init__ (self , params , lr = 5e-4 , weight_decay = 0.1 , momentum = 0.95 , nesterov = True , ns_steps = 5 ):
101138 defaults = dict (lr = lr , weight_decay = weight_decay , momentum = momentum , nesterov = nesterov , ns_steps = ns_steps )
102139 super ().__init__ (params , defaults )
103- self .bf16_support_map = get_bf16_support_map ()
104140
105141 @torch .no_grad ()
106142 def step (self , closure = None ):
@@ -129,8 +165,8 @@ def step(self, closure=None):
129165 original_shape = g .shape
130166 if g .ndim >= 4 : # for the case of conv filters
131167 g = g .view (g .size (0 ), g .size (1 ), - 1 )
132- use_bf16 = self . bf16_support_map . get ( g . device , False )
133- g = zeropower_via_newtonschulz5 ( g , steps = group [ "ns_steps" ], use_bf16 = use_bf16 )
168+ g = gram_newton_schulz ( g , steps = group [ "ns_steps" ] )
169+
134170 if group ["weight_decay" ] > 0 :
135171 torch ._foreach_mul_ (p , 1 - group ["lr" ] * group ["weight_decay" ])
136172 torch ._foreach_add_ (p , g .view (original_shape ).unbind (0 ), alpha = - group ["lr" ] * max (g [0 ].size ()) ** 0.5 )
def get_params_for_muon(model) -> List[Parameter]:
    """
    Collect the parameters of ``model`` that should be optimized with Muon.

    A parameter qualifies when it requires grad, has ndim >= 2 (matrix-shaped,
    so Newton-Schulz orthogonalization applies), and is not owned by an
    excluded module type (currently nn.Embedding).

    Returns:
        A list of parameters that should be optimized with muon.
    """
    # The trailing comma matters: without it this is just the class itself, not
    # a tuple, and adding a second exclusion later would break isinstance().
    excluded_module_classes = (nn.Embedding,)
    muon_params = []
    # BFS through all submodules; an excluded module's whole subtree is skipped
    # because we `continue` before enqueueing its children.
    queue = collections.deque([model])
    while queue:
        module = queue.popleft()
        if isinstance(module, excluded_module_classes):
            continue
        for param in module.parameters(recurse=False):
            if not param.requires_grad:
                continue
            if param.ndim >= 2:
                muon_params.append(param)
        queue.extend(module.children())
    return muon_params
158199
159200
0 commit comments