@@ -158,6 +158,33 @@ def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations:
158158 return X .to (dtype ).view (original_shape )
159159
160160
161+ def mud (G : Tensor , passes : int = 1 , use_bf16 : bool = False ) -> Tensor :
162+ """
163+ MomentUm Decorrelation (MUD) iteration to compute the orthogonalization of G.
164+ A lightweight PyTorch implementation based on "Beyond Muon: MUD for Faster Transformer Training".
165+ Constructs a lower-triangular approximation to the Gram matrix and applies forward triangular solve.
166+ """
167+ assert G .ndim == 3
168+
169+ X = G .to (dtype = torch .bfloat16 if use_bf16 else torch .float32 )
170+
171+ should_transpose = X .size (- 2 ) > X .size (- 1 )
172+ if should_transpose :
173+ X = X .mT
174+
175+ for _ in range (passes ):
176+ X = F .normalize (X , p = 2.0 , dim = - 1 , eps = 1e-7 ) # Row normalization
177+ G_mat = torch .bmm (X , X .mT ) # Row Gram (k,k)
178+ T = torch .tril (G_mat ) # Lower-triangular of Gram
179+ X = torch .linalg .solve_triangular (T , X , upper = False ) # Forward solve: T X = Q
180+ X = F .normalize (X , p = 2.0 , dim = - 1 , eps = 1e-7 ) # Renormalize rows
181+
182+ if should_transpose :
183+ X = X .mT
184+
185+ return X .contiguous ()
186+
187+
161188class Muon (torch .optim .Optimizer ):
162189 """
163190 Muon - MomentUm Orthogonalized by Newton-schulz
@@ -180,12 +207,12 @@ class Muon(torch.optim.Optimizer):
180207 nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
181208 ns_steps: The number of Newton-Schulz iteration steps to use.
182209 ns_coefficients: List of tuples (a, b, c) for each NS step. Defaults to standard 5-step NS coeffs.
183- use_gram_ns: Whether to use the FLOP-saving Gram-NS implementation instead of standard NS .
210+ method: String selector for the orthogonalization strategy ('ns5', 'gram_ns', or 'mud') .
184211 """
185212
186213 def __init__ (self , params , lr = 5e-4 , weight_decay = 0.1 , momentum = 0.95 , nesterov = True ,
187214 ns_steps = 5 , reset_iterations = [2 ], ns_coefficients = POLAR_EXPRESS_COEFFICIENTS ,
188- use_gram_ns = True ):
215+ method = 'gram_ns' ):
189216 if reset_iterations is None :
190217 reset_iterations = [3 ]
191218 # set [3] on default and set [2] on POLAR_EXPRESS_COEFFICIENTS or YOU_COEFFICIENTS
@@ -206,7 +233,7 @@ def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=Tr
206233 ns_steps = ns_steps ,
207234 reset_iterations = reset_iterations ,
208235 ns_coefficients = parsed_coefficients ,
209- use_gram_ns = use_gram_ns
236+ method = method
210237 )
211238 super ().__init__ (params , defaults )
212239 self .bf16_support_map = get_bf16_support_map ()
@@ -241,22 +268,31 @@ def step(self, closure=None):
241268
242269 use_bf16 = self .bf16_support_map .get (g .device , False )
243270
244- # Dynamic NS function invocation
245- if group ["use_gram_ns" ]:
271+ # Dynamic orthogonalization method invocation
272+ method = group .get ("method" , "gram_ns" )
273+ if method == 'gram_ns' :
246274 g = gram_newton_schulz (
247275 g ,
248276 steps = group ["ns_steps" ],
249277 use_bf16 = use_bf16 ,
250278 reset_iterations = group ["reset_iterations" ],
251279 ns_coefficients = group ["ns_coefficients" ]
252280 )
253- else :
281+ elif method == 'mud' :
282+ g = mud (
283+ g ,
284+ passes = 1 ,
285+ use_bf16 = use_bf16
286+ )
287+ elif method == 'ns5' :
254288 g = zeropower_via_newtonschulz5 (
255289 g ,
256290 steps = group ["ns_steps" ],
257291 use_bf16 = use_bf16 ,
258292 ns_coefficients = group ["ns_coefficients" ]
259293 )
294+ else :
295+ raise ValueError (f"Unknown orthogonalization method: { method } " )
260296
261297 if group ["weight_decay" ] > 0 :
262298 torch ._foreach_mul_ (p , 1 - group ["lr" ] * group ["weight_decay" ])
0 commit comments