Skip to content

Commit 86a6342

Browse files
committed
post-transpose
1 parent bbdc0e6 commit 86a6342

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

modules/optimizer/muon.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,13 @@ def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]) -> Te
8989

9090
X = G.to(torch.float32)
9191

92+
X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7)
93+
X = X.to(torch.float16)
94+
9295
should_transpose = X.size(-2) > X.size(-1)
9396
if should_transpose:
9497
X = X.mT
9598

96-
X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7)
97-
X = X.to(torch.float16)
98-
9999
if X.size(-2) != X.size(-1):
100100
R = torch.bmm(X, X.mT)
101101
Q = None

0 commit comments

Comments
 (0)