Skip to content

Commit 32811e2

Browse files
committed
Cleanup whitespace in muon optimizer
Remove extraneous blank lines and trailing spaces in modules/optimizer/muon.py, and tidy the formatting around the normalization step, the transpose step, and the Newton–Schulz iteration loops in gram_newton_schulz. No functional logic was changed.
1 parent 59b348e commit 32811e2

1 file changed

Lines changed: 1 addition & 10 deletions

File tree

modules/optimizer/muon.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -94,43 +94,34 @@ def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int], ns_co
9494
on the smaller NxN Gram matrix to save up to 50% FLOPs.
9595
"""
9696
assert G.ndim == 3
97-
9897
original_shape = G.shape
9998
dtype = G.dtype
10099

101100
X = G.to(torch.float32)
102-
103101
X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7)
104-
105102
should_transpose = X.size(-2) > X.size(-1)
106103
if should_transpose:
107104
X = X.mT
108-
109105
X = X.to(torch.float16)
106+
110107
if X.size(-2) != X.size(-1):
111108
R = torch.bmm(X, X.mT)
112109
Q = None
113-
114110
for i, (a_i, b_i, c_i) in enumerate(ns_coefficients):
115111
if i in reset_iterations and i != 0:
116112
X = torch.bmm(Q, X)
117113
R = torch.bmm(X, X.mT)
118114
Q = None
119-
120115
Z = torch.baddbmm(R, R, R, beta=b_i, alpha=c_i)
121-
122116
if i != 0 and i not in reset_iterations:
123117
Q = torch.baddbmm(Q, Q, Z, beta=a_i, alpha=1.0)
124118
else:
125119
Q = Z.clone()
126120
Q.diagonal(dim1=-2, dim2=-1).add_(a_i)
127-
128121
if i < steps - 1 and (i + 1) not in reset_iterations:
129122
RZ = torch.baddbmm(R, R, Z, beta=a_i, alpha=1.0)
130123
R = torch.baddbmm(RZ, Z, RZ, beta=a_i, alpha=1.0)
131-
132124
X = torch.bmm(Q, X) if not should_transpose else torch.bmm(X.mT, Q)
133-
134125
else:
135126
for i, (a_i, b_i, c_i) in enumerate(ns_coefficients):
136127
A = torch.bmm(X, X.mT)

0 commit comments

Comments
 (0)