@@ -74,6 +74,61 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor
7474 return X
7575
7676
def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]) -> Tensor:
    """
    Gram Newton-Schulz iteration to compute the orthogonalization of G.

    Mathematically identical to standard Newton-Schulz but computes it by
    iterating on the smaller N x N Gram matrix to save up to 50% FLOPs.

    Args:
        G: Batched matrices of shape (batch, M, N).
        steps: Number of Newton-Schulz iteration steps.
        reset_iterations: Iteration indices at which the accumulated
            polynomial Q is applied to X and the Gram matrix is recomputed
            from scratch (presumably to limit fp16 error growth — the
            visible code only shows the recompute, not the motivation).

    Returns:
        Tensor with the same shape and dtype as G.
    """
    assert G.ndim == 3
    # Quintic Newton-Schulz coefficients; each step applies the polynomial
    # p(A) = a*I + b*A + c*A^2 of the Gram matrix A = X @ X^T.
    a, b, c = (3.4445, -4.7750, 2.0315)
    ns_coefficients = [(a, b, c)] * steps

    original_shape = G.shape
    dtype = G.dtype

    X = G.to(torch.float32)

    # Work in the wide orientation so X @ X^T is the smaller of the two
    # possible Gram matrices.
    should_transpose = X.size(-2) > X.size(-1)
    if should_transpose:
        X = X.mT

    # Frobenius-normalize each matrix (norm over the last two dims), which
    # bounds the spectral norm by 1, then drop to fp16 for the iterations.
    X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7)
    X = X.to(torch.float16)

    if X.size(-2) != X.size(-1):
        # Rectangular case: iterate on the square Gram matrix R = X @ X^T
        # and accumulate the product of step polynomials in Q, so the
        # full-size multiply against X happens only at resets and the end.
        R = torch.bmm(X, X.mT)
        Q = None

        for i, (a_i, b_i, c_i) in enumerate(ns_coefficients):
            if i in reset_iterations and i != 0:
                # Apply the accumulated polynomial to X, then restart the
                # accumulation from a freshly computed Gram matrix.
                X = torch.bmm(Q, X)
                R = torch.bmm(X, X.mT)
                Q = None

            # Z = b_i*R + c_i*R@R, i.e. p_i(R) = a_i*I + Z.
            Z = torch.baddbmm(R, R, R, beta=b_i, alpha=c_i)
            if i != 0 and i not in reset_iterations:
                # Fold this step into the running product:
                # Q <- Q @ p_i(R) = a_i*Q + Q@Z (all factors are polynomials
                # in the same initial R, so they commute).
                Q = torch.baddbmm(Q, Q, Z, beta=a_i, alpha=1.0)
            else:
                # Start a fresh product: Q <- p_i(R) = a_i*I + Z.
                Q = Z.clone()
                Q.diagonal(dim1=-2, dim2=-1).add_(a_i)
            if i < steps - 1 and (i + 1) not in reset_iterations:
                # Advance the Gram matrix: R <- p_i(R) @ R @ p_i(R).
                # Skipped on the last step and when the next iteration will
                # recompute R directly from X anyway.
                RZ = torch.baddbmm(R, R, Z, beta=a_i, alpha=1.0)
                R = torch.baddbmm(RZ, Z, RZ, beta=a_i, alpha=1.0)

        # Apply the remaining accumulated polynomial. For the transposed
        # case, X.mT @ Q == (Q @ X).mT (Q is symmetric), so this also
        # restores the original orientation in a single bmm.
        X = torch.bmm(Q, X) if not should_transpose else torch.bmm(X.mT, Q)
    else:
        # Square case: plain Newton-Schulz; no Gram-side savings available.
        for i, (a_i, b_i, c_i) in enumerate(ns_coefficients):
            A = torch.bmm(X, X.mT)
            B = torch.baddbmm(A, A, A, beta=b_i, alpha=c_i)
            X = torch.baddbmm(X, B, X, beta=a_i, alpha=1.0)

        if should_transpose:
            # NOTE(review): for strictly square inputs should_transpose is
            # always False here; the rectangular branch above already folds
            # the un-transpose into its final bmm — confirm against the
            # original (the diff rendering makes this nesting ambiguous).
            X = X.mT

    return X.to(dtype).view(original_shape)
130+
131+
77132class Muon (torch .optim .Optimizer ):
78133 """
79134 Muon - MomentUm Orthogonalized by Newton-schulz
@@ -97,10 +152,12 @@ class Muon(torch.optim.Optimizer):
97152 ns_steps: The number of Newton-Schulz iteration steps to use.
98153 """
99154
100- def __init__ (self , params , lr = 5e-4 , weight_decay = 0.1 , momentum = 0.95 , nesterov = True , ns_steps = 5 ):
101- defaults = dict (lr = lr , weight_decay = weight_decay , momentum = momentum , nesterov = nesterov , ns_steps = ns_steps )
155+ def __init__ (self , params , lr = 5e-4 , weight_decay = 0.1 , momentum = 0.95 , nesterov = True , ns_steps = 5 , reset_iterations = None ):
156+ if reset_iterations is None :
157+ reset_iterations = [3 ]
158+ defaults = dict (lr = lr , weight_decay = weight_decay , momentum = momentum , nesterov = nesterov , ns_steps = ns_steps , reset_iterations = reset_iterations )
102159 super ().__init__ (params , defaults )
103- self .bf16_support_map = get_bf16_support_map ()
160+ # self.bf16_support_map = get_bf16_support_map()
104161
105162 @torch .no_grad ()
106163 def step (self , closure = None ):
@@ -129,8 +186,9 @@ def step(self, closure=None):
129186 original_shape = g .shape
130187 if g .ndim >= 4 : # for the case of conv filters
131188 g = g .view (g .size (0 ), g .size (1 ), - 1 )
132- use_bf16 = self .bf16_support_map .get (g .device , False )
133- g = zeropower_via_newtonschulz5 (g , steps = group ["ns_steps" ], use_bf16 = use_bf16 )
189+ # use_bf16 = self.bf16_support_map.get(g.device, False)
190+ # g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16)
191+ g = gram_newton_schulz (g , steps = group ["ns_steps" ], reset_iterations = group ["reset_iterations" ])
134192 if group ["weight_decay" ] > 0 :
135193 torch ._foreach_mul_ (p , 1 - group ["lr" ] * group ["weight_decay" ])
136194 torch ._foreach_add_ (p , g .view (original_shape ).unbind (0 ), alpha = - group ["lr" ] * max (g [0 ].size ()) ** 0.5 )
0 commit comments