44from torch import Tensor
55from torch .nn import Module , Parameter , Embedding
66from typing import List
7+ from itertools import repeat
78from .chained_optimizer import ChainedOptimizer , OptimizerSpec
89
10+ coeffs_list = [
11+ (8.28721201814563 , - 23.595886519098837 , 17.300387312530933 ),
12+ (4.107059111542203 , - 2.9478499167379106 , 0.5448431082926601 ),
13+ (3.9486908534822946 , - 2.908902115962949 , 0.5518191394370137 ),
14+ (3.3184196573706015 , - 2.488488024314874 , 0.51004894012372 ),
15+ (2.300652019954817 , - 1.6689039845747493 , 0.4188073119525673 ),
16+ (1.891301407787398 , - 1.2679958271945868 , 0.37680408948524835 ),
17+ (1.8750014808534479 , - 1.2500016453999487 , 0.3750001645474248 ),
18+ (1.875 , - 1.25 , 0.375 ), # subsequent coeffs equal this numerically
19+ ]
20+
21+ # safety factor for numerical stability (but exclude last polynomial )
22+ coeffs_list = [(a / 1.01 , b / 1.01 ** 3 , c / 1.01 ** 5 ) for (a , b , c ) in coeffs_list [: - 1 ]] + [coeffs_list [- 1 ]]
23+
924
1025def get_bf16_support_map ():
1126 bf16_support_map = {}
@@ -24,7 +39,7 @@ def get_bf16_support_map():
2439
2540 return bf16_support_map
2641
27-
42+
2843def zeropower_via_newtonschulz5 (G : Tensor , steps : int , use_bf16 : bool ) -> Tensor :
2944 """
3045 Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
@@ -36,21 +51,22 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor
3651 performance at all relative to UV^T, where USV^T = G is the SVD.
3752 """
3853 assert G .ndim == 3 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
39- a , b , c = (3.4445 , - 4.7750 , 2.0315 )
54+ # a, b, c = (3.4445, -4.7750, 2.0315)
4055
4156 X = G .to (dtype = torch .bfloat16 if use_bf16 else torch .float32 )
4257
4358 # Ensure spectral norm is at most 1
4459 X = F .normalize (X , p = 2.0 , dim = (- 2 , - 1 ), eps = 1e-7 )
4560
4661 # Perform the NS iterations
62+ hs = coeffs_list [: steps ] + list (repeat (coeffs_list [- 1 ], steps - len (coeffs_list )))
4763 if X .size (- 2 ) < X .size (- 1 ):
48- for _ in range ( steps ) :
64+ for a , b , c in hs :
4965 A = torch .bmm (X , X .mT )
5066 A = torch .baddbmm (A , A , A , beta = b , alpha = c )
5167 X = torch .baddbmm (X , A , X , beta = a , alpha = 1 )
5268 else :
53- for _ in range ( steps ) :
69+ for a , b , c in hs :
5470 A = torch .bmm (X .mT , X )
5571 A = torch .baddbmm (A , A , A , beta = b , alpha = c )
5672 X = torch .baddbmm (X , X , A , beta = a , alpha = 1 )
@@ -131,9 +147,11 @@ def get_params_for_muon(model) -> List[Parameter]:
131147 """
132148 muon_params = []
133149 for module in model .modules ():
134- for param in module .parameters (recurse = False ):
150+ for name , param in module .named_parameters (recurse = False ):
135151 if not param .requires_grad :
136152 continue
153+ if name == 'weight_g' :
154+ continue
137155 if not isinstance (module , nn .Embedding ) and param .ndim >= 2 :
138156 muon_params .append (param )
139157 return muon_params
0 commit comments