Skip to content

Commit 4cd0748

Browse files
committed
Fix muon mud dtype handling and normalization
Cast input to float32 for the triangular solve (triangular_solve_cuda not implemented for BFloat16), while preserving the original dtype and casting the result back before returning. Renamed `mud` to `mud_whiten` (call site updated), tightened the row-normalization eps from 1e-7 to 1e-8, and made the transposed view contiguous before iterating. Added explanatory comment and small cleanup.
1 parent 6d3786c commit 4cd0748

1 file changed

Lines changed: 11 additions & 8 deletions

File tree

modules/optimizer/muon.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -158,31 +158,34 @@ def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations:
158158
return X.to(dtype).view(original_shape)
159159

160160

161-
def mud(G: Tensor, passes: int = 1, use_bf16: bool = False) -> Tensor:
161+
def mud_whiten(G: Tensor, passes: int = 1, use_bf16: bool = False) -> Tensor:
162162
"""
163163
MomentUm Decorrelation (MUD) iteration to compute the orthogonalization of G.
164164
A lightweight PyTorch implementation based on "Beyond Muon: MUD for Faster Transformer Training".
165165
Constructs a lower-triangular approximation to the Gram matrix and applies forward triangular solve.
166166
"""
167167
assert G.ndim == 3
168-
169-
X = G.to(dtype=torch.bfloat16 if use_bf16 else torch.float32)
168+
dtype = G.dtype
169+
170+
# X = X.to(dtype = torch.bfloat16 if use_bf16 else torch.float32)
171+
# "triangular_solve_cuda" not implemented for 'BFloat16'
172+
X = G.to(torch.float32)
170173

171174
should_transpose = X.size(-2) > X.size(-1)
172175
if should_transpose:
173-
X = X.mT
176+
X = X.mT.contiguous()
174177

175178
for _ in range(passes):
176-
X = F.normalize(X, p=2.0, dim=-1, eps=1e-7) # Row normalization
179+
X = F.normalize(X, p=2.0, dim=-1, eps=1e-8) # Row normalization
177180
G_mat = torch.bmm(X, X.mT) # Row Gram (k,k)
178181
T = torch.tril(G_mat) # Lower-triangular of Gram
179182
X = torch.linalg.solve_triangular(T, X, upper=False) # Forward solve: T X = Q
180-
X = F.normalize(X, p=2.0, dim=-1, eps=1e-7) # Renormalize rows
183+
X = F.normalize(X, p=2.0, dim=-1, eps=1e-8) # Renormalize rows
181184

182185
if should_transpose:
183186
X = X.mT
184187

185-
return X.contiguous()
188+
return X.to(dtype).contiguous()
186189

187190

188191
class Muon(torch.optim.Optimizer):
@@ -279,7 +282,7 @@ def step(self, closure=None):
279282
ns_coefficients=group["ns_coefficients"]
280283
)
281284
elif method == 'mud':
282-
g = mud(
285+
g = mud_whiten(
283286
g,
284287
passes=1,
285288
use_bf16=use_bf16

0 commit comments

Comments
 (0)