2222coeffs_list = [(a / 1.01 , b / 1.01 ** 3 , c / 1.01 ** 5 ) for (a , b , c ) in coeffs_list [: - 1 ]] + [coeffs_list [- 1 ]]
2323
2424
# Per-iteration quintic NS coefficients (a, b, c) by @YouJiacheng:
# https://x.com/YouJiacheng/status/1905861218138804534
# Stored as tuples (not lists) for consistency with the Polar Express
# coefficient tables in this file; consumers only unpack a, b, c.
YOU_COEFFICIENTS = [
    (4.0848, -6.8946, 2.9270),
    (3.9505, -6.3029, 2.6377),
    (3.7418, -5.5913, 2.3037),
    (2.8769, -3.1427, 1.2046),
    (2.8366, -3.0525, 1.2012),
]
33+
34+ # https://arxiv.org/pdf/2505.16932
35+ _unmodified_polar_express_coefficients = [
36+ (8.28721201814563 , - 23.595886519098837 , 17.300387312530933 ),
37+ (4.107059111542203 , - 2.9478499167379106 , 0.5448431082926601 ),
38+ (3.9486908534822946 , - 2.908902115962949 , 0.5518191394370137 ),
39+ (3.3184196573706015 , - 2.488488024314874 , 0.51004894012372 ),
40+ (2.300652019954817 , - 1.6689039845747493 , 0.4188073119525673 ),
41+ (1.891301407787398 , - 1.2679958271945868 , 0.37680408948524835 ),
42+ (1.8750014808534479 , - 1.2500016453999487 , 0.3750001645474248 ),
43+ (1.875 , - 1.25 , 0.375 ), # subsequent coeffs equal this numerically
44+ ]
45+
46+ # safety factor for numerical stability (but exclude last polynomial )
47+ safety_factor = 1.01 # 'Dao-AILab/gram-newton-schulz' set 1.05
48+ POLAR_EXPRESS_COEFFICIENTS = [
49+ (a / safety_factor , b / safety_factor ** 3 , c / safety_factor ** 5 )
50+ for (a , b , c ) in _unmodified_polar_express_coefficients
51+ ]
52+
53+
2554def get_bf16_support_map ():
2655 bf16_support_map = {}
2756
@@ -39,8 +68,8 @@ def get_bf16_support_map():
3968
4069 return bf16_support_map
4170
42-
43- def zeropower_via_newtonschulz5 (G : Tensor , steps : int , use_bf16 : bool ) -> Tensor :
71+
72+ def zeropower_via_newtonschulz5 (G : Tensor , steps : int , use_bf16 : bool , ns_coefficients : List [ tuple ] ) -> Tensor :
4473 """
4574 Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
4675 quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
@@ -51,7 +80,6 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor
5180 performance at all relative to UV^T, where USV^T = G is the SVD.
5281 """
5382 assert G .ndim == 3 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
54- #a, b, c = (3.4445, -4.7750, 2.0315)
5583
5684 X = G .to (dtype = torch .bfloat16 if use_bf16 else torch .float32 )
5785
@@ -61,28 +89,28 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor
6189 # Perform the NS iterations
6290 hs = coeffs_list [: steps ] + list (repeat (coeffs_list [- 1 ], steps - len (coeffs_list )))
6391 if X .size (- 2 ) < X .size (- 1 ):
64- for a , b , c in hs :
92+ for i in range (steps ):
93+ a , b , c = ns_coefficients [i ]
6594 A = torch .bmm (X , X .mT )
6695 A = torch .baddbmm (A , A , A , beta = b , alpha = c )
6796 X = torch .baddbmm (X , A , X , beta = a , alpha = 1 )
6897 else :
69- for a , b , c in hs :
98+ for i in range (steps ):
99+ a , b , c = ns_coefficients [i ]
70100 A = torch .bmm (X .mT , X )
71101 A = torch .baddbmm (A , A , A , beta = b , alpha = c )
72102 X = torch .baddbmm (X , X , A , beta = a , alpha = 1 )
73103
74104 return X
75105
76106
77- def gram_newton_schulz (G : Tensor , steps : int , use_bf16 : bool , reset_iterations : List [int ]) -> Tensor :
107+ def gram_newton_schulz (G : Tensor , steps : int , use_bf16 : bool , reset_iterations : List [int ], ns_coefficients : List [ tuple ] ) -> Tensor :
78108 """
79109 Gram Newton-Schulz iteration to compute the orthogonalization of G.
80110 Mathematically identical to standard Newton-Schulz but computes iterating
81111 on the smaller NxN Gram matrix to save up to 50% FLOPs.
82112 """
83113 assert G .ndim == 3
84- a , b , c = (3.4445 , - 4.7750 , 2.0315 )
85- ns_coefficients = [(a , b , c )] * steps
86114
87115 original_shape = G .shape
88116 dtype = G .dtype
@@ -151,12 +179,35 @@ class Muon(torch.optim.Optimizer):
151179 momentum: The momentum used by the internal SGD.
152180 nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
153181 ns_steps: The number of Newton-Schulz iteration steps to use.
182+ ns_coefficients: List of (a, b, c) tuples, one per NS step. Defaults to POLAR_EXPRESS_COEFFICIENTS; pass None to use the classic (3.4445, -4.7750, 2.0315) coefficients for every step.
183+ use_gram_ns: Whether to use the FLOP-saving Gram-NS implementation instead of standard NS.
154184 """
155185
def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=True,
             ns_steps=5, reset_iterations=(2,), ns_coefficients=POLAR_EXPRESS_COEFFICIENTS,
             use_gram_ns=True):
    """Configure Muon parameter groups.

    Args:
        params: Iterable of parameters (or param-group dicts) to optimize.
        lr: Learning rate for the internal SGD update.
        weight_decay: Decoupled weight-decay factor.
        momentum: Momentum used by the internal SGD.
        nesterov: Whether to use Nesterov-style momentum.
        ns_steps: Number of Newton-Schulz iteration steps.
        reset_iterations: Iteration indices at which Gram-NS resets. The
            immutable tuple default (2,) replaces the previous mutable-list
            default [2] (shared-mutable-default pitfall) and is normalized to
            the list [2], matching the POLAR_EXPRESS/YOU coefficient
            schedules. Passing None explicitly keeps the legacy [3] used with
            the classic single-tuple coefficients.
        ns_coefficients: Sequence of (a, b, c) NS coefficients, one per step;
            padded by repeating its last entry and truncated so exactly
            ns_steps triples are stored. None selects the classic
            (3.4445, -4.7750, 2.0315) for all steps.
        use_gram_ns: Use the FLOP-saving Gram-NS implementation instead of
            the standard NS iteration.

    Raises:
        ValueError: If ns_coefficients is an empty sequence.
    """
    if reset_iterations is None:
        # Legacy behavior for the classic default coefficients.
        reset_iterations = [3]
    else:
        # Copy so later mutation by the caller cannot leak into the stored
        # defaults (and so the tuple default becomes a list as before).
        reset_iterations = list(reset_iterations)

    # Resolve exactly one (a, b, c) triple per NS step.
    if ns_coefficients is None:
        parsed_coefficients = [(3.4445, -4.7750, 2.0315)] * ns_steps
    else:
        parsed_coefficients = list(ns_coefficients)
        if not parsed_coefficients:
            # Previously this fell through to a cryptic IndexError below.
            raise ValueError("ns_coefficients must contain at least one (a, b, c) tuple")
        if len(parsed_coefficients) < ns_steps:
            # Repeat the terminal polynomial for any remaining steps.
            parsed_coefficients += [parsed_coefficients[-1]] * (ns_steps - len(parsed_coefficients))
        parsed_coefficients = parsed_coefficients[:ns_steps]

    defaults = dict(
        lr=lr,
        weight_decay=weight_decay,
        momentum=momentum,
        nesterov=nesterov,
        ns_steps=ns_steps,
        reset_iterations=reset_iterations,
        ns_coefficients=parsed_coefficients,
        use_gram_ns=use_gram_ns,
    )
    super().__init__(params, defaults)
    self.bf16_support_map = get_bf16_support_map()
162213
@@ -187,9 +238,26 @@ def step(self, closure=None):
187238 original_shape = g .shape
188239 if g .ndim >= 4 : # for the case of conv filters
189240 g = g .view (g .size (0 ), g .size (1 ), - 1 )
241+
190242 use_bf16 = self .bf16_support_map .get (g .device , False )
191- # g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16)
192- g = gram_newton_schulz (g , steps = group ["ns_steps" ], use_bf16 = use_bf16 , reset_iterations = group ["reset_iterations" ])
243+
244+ # Dynamic NS function invocation
245+ if group ["use_gram_ns" ]:
246+ g = gram_newton_schulz (
247+ g ,
248+ steps = group ["ns_steps" ],
249+ use_bf16 = use_bf16 ,
250+ reset_iterations = group ["reset_iterations" ],
251+ ns_coefficients = group ["ns_coefficients" ]
252+ )
253+ else :
254+ g = zeropower_via_newtonschulz5 (
255+ g ,
256+ steps = group ["ns_steps" ],
257+ use_bf16 = use_bf16 ,
258+ ns_coefficients = group ["ns_coefficients" ]
259+ )
260+
193261 if group ["weight_decay" ] > 0 :
194262 torch ._foreach_mul_ (p , 1 - group ["lr" ] * group ["weight_decay" ])
195263 torch ._foreach_add_ (p , g .view (original_shape ).unbind (0 ), alpha = - group ["lr" ] * max (g [0 ].size ()) ** 0.5 )
0 commit comments