Commit 16e1989

Stabilize mud_whiten by clamping diagonal
Clamp the diagonal entries of the lower-triangular Gram matrix T in mud_whiten to a minimum of 1e-5 before solving the triangular system. This keeps the diagonal of T from containing zeros, which would make the solve singular or ill-conditioned, and improves the numerical stability of the forward solve.
1 parent 4cd0748 commit 16e1989

1 file changed

Lines changed: 1 addition & 0 deletions

modules/optimizer/muon.py

@@ -179,6 +179,7 @@ def mud_whiten(G: Tensor, passes: int = 1, use_bf16: bool = False) -> Tensor:
     X = F.normalize(X, p=2.0, dim=-1, eps=1e-8)  # Row normalization
     G_mat = torch.bmm(X, X.mT)  # Row Gram (k,k)
     T = torch.tril(G_mat)  # Lower-triangular of Gram
+    T.diagonal(dim1=-2, dim2=-1).clamp_min_(1e-5)  # avoid T all zero
     X = torch.linalg.solve_triangular(T, X, upper=False)  # Forward solve: T X = Q
     X = F.normalize(X, p=2.0, dim=-1, eps=1e-8)  # Renormalize rows
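
The effect of the clamp can be illustrated with a small dependency-free sketch. This is not the code in muon.py: it uses plain Python lists instead of torch tensors, and the helper names `clamp_diagonal` and `forward_solve` are hypothetical, introduced only for this illustration. It assumes a toy input whose second row is all zeros, so the corresponding Gram diagonal entry is 0. In plain Python the zero pivot raises ZeroDivisionError; with floating-point tensors the same division would typically produce inf or NaN instead.

```python
def clamp_diagonal(T, min_value=1e-5):
    """Return a copy of T with every diagonal entry raised to at least min_value."""
    out = [row[:] for row in T]
    for i in range(len(out)):
        out[i][i] = max(out[i][i], min_value)
    return out


def forward_solve(T, b):
    """Solve T y = b for lower-triangular T by forward substitution."""
    y = []
    for i, row in enumerate(T):
        acc = b[i] - sum(row[j] * y[j] for j in range(i))
        y.append(acc / row[i])  # division by the diagonal (pivot) entry
    return y


# Toy input (an assumption for illustration): one unit row and one all-zero row.
X = [[1.0, 0.0], [0.0, 0.0]]

# Row Gram matrix X X^T, then its lower-triangular part.
gram = [[sum(a * b for a, b in zip(r1, r2)) for r2 in X] for r1 in X]
T = [[gram[i][j] if j <= i else 0.0 for j in range(2)] for i in range(2)]

# Right-hand side: the first column of X.
rhs = [X[0][0], X[1][0]]

# With the clamp, the zero pivot becomes 1e-5 and the solve stays finite.
y_clamped = forward_solve(clamp_diagonal(T), rhs)
```

Calling `forward_solve(T, rhs)` on the unclamped matrix fails at the second row, because the pivot `T[1][1]` is exactly zero. The floor of 1e-5 matches the value used in the commit; whether that threshold is the best choice for every dtype is not something this sketch addresses.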
