Skip to content

Commit e8e3332

Browse files
committed
Add MUD orthogonalization and method switch
Introduce a mud() implementation (MomentUm Decorrelation) that performs lightweight orthogonalization via row-normalization, row-gram construction, lower-triangular extraction and forward triangular solve. Update Muon optimizer to replace the boolean use_gram_ns with a string method selector (defaults to 'gram_ns') and dispatch dynamically between 'gram_ns', 'mud', and 'ns5' implementations, raising on unknown methods. Also preserve bfloat16 handling and tensor transpose logic; mud() returns a contiguous tensor.
1 parent 4e81fbb commit e8e3332

1 file changed

Lines changed: 42 additions & 6 deletions

File tree

modules/optimizer/muon.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,33 @@ def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations:
158158
return X.to(dtype).view(original_shape)
159159

160160

161+
def mud(G: Tensor, passes: int = 1, use_bf16: bool = False) -> Tensor:
162+
"""
163+
MomentUm Decorrelation (MUD) iteration to compute the orthogonalization of G.
164+
A lightweight PyTorch implementation based on "Beyond Muon: MUD for Faster Transformer Training".
165+
Constructs a lower-triangular approximation to the Gram matrix and applies forward triangular solve.
166+
"""
167+
assert G.ndim == 3
168+
169+
X = G.to(dtype=torch.bfloat16 if use_bf16 else torch.float32)
170+
171+
should_transpose = X.size(-2) > X.size(-1)
172+
if should_transpose:
173+
X = X.mT
174+
175+
for _ in range(passes):
176+
X = F.normalize(X, p=2.0, dim=-1, eps=1e-7) # Row normalization
177+
G_mat = torch.bmm(X, X.mT) # Row Gram (k,k)
178+
T = torch.tril(G_mat) # Lower-triangular of Gram
179+
X = torch.linalg.solve_triangular(T, X, upper=False) # Forward solve: T X = Q
180+
X = F.normalize(X, p=2.0, dim=-1, eps=1e-7) # Renormalize rows
181+
182+
if should_transpose:
183+
X = X.mT
184+
185+
return X.contiguous()
186+
187+
161188
class Muon(torch.optim.Optimizer):
162189
"""
163190
Muon - MomentUm Orthogonalized by Newton-schulz
@@ -180,12 +207,12 @@ class Muon(torch.optim.Optimizer):
180207
nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
181208
ns_steps: The number of Newton-Schulz iteration steps to use.
182209
ns_coefficients: List of tuples (a, b, c) for each NS step. Defaults to standard 5-step NS coeffs.
183-
use_gram_ns: Whether to use the FLOP-saving Gram-NS implementation instead of standard NS.
210+
method: String selector for the orthogonalization strategy ('ns5', 'gram_ns', or 'mud').
184211
"""
185212

186213
def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=True,
187214
ns_steps=5, reset_iterations=[2], ns_coefficients=POLAR_EXPRESS_COEFFICIENTS,
188-
use_gram_ns=True):
215+
method='gram_ns'):
189216
if reset_iterations is None:
190217
reset_iterations = [3]
191218
# set [3] on default and set [2] on POLAR_EXPRESS_COEFFICIENTS or YOU_COEFFICIENTS
@@ -206,7 +233,7 @@ def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=Tr
206233
ns_steps=ns_steps,
207234
reset_iterations=reset_iterations,
208235
ns_coefficients=parsed_coefficients,
209-
use_gram_ns=use_gram_ns
236+
method=method
210237
)
211238
super().__init__(params, defaults)
212239
self.bf16_support_map = get_bf16_support_map()
@@ -241,22 +268,31 @@ def step(self, closure=None):
241268

242269
use_bf16 = self.bf16_support_map.get(g.device, False)
243270

244-
# Dynamic NS function invocation
245-
if group["use_gram_ns"]:
271+
# Dynamic orthogonalization method invocation
272+
method = group.get("method", "gram_ns")
273+
if method == 'gram_ns':
246274
g = gram_newton_schulz(
247275
g,
248276
steps=group["ns_steps"],
249277
use_bf16=use_bf16,
250278
reset_iterations=group["reset_iterations"],
251279
ns_coefficients=group["ns_coefficients"]
252280
)
253-
else:
281+
elif method == 'mud':
282+
g = mud(
283+
g,
284+
passes=1,
285+
use_bf16=use_bf16
286+
)
287+
elif method == 'ns5':
254288
g = zeropower_via_newtonschulz5(
255289
g,
256290
steps=group["ns_steps"],
257291
use_bf16=use_bf16,
258292
ns_coefficients=group["ns_coefficients"]
259293
)
294+
else:
295+
raise ValueError(f"Unknown orthogonalization method: {method}")
260296

261297
if group["weight_decay"] > 0:
262298
torch._foreach_mul_(p, 1 - group["lr"] * group["weight_decay"])

0 commit comments

Comments
 (0)