Skip to content

Commit d265842

Browse files
committed
set bf16 when X.size(-2) == X.size(-1)
1 parent 08c929a commit d265842

1 file changed

Lines changed: 3 additions & 2 deletions

File tree

modules/optimizer/muon.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,16 +87,16 @@ def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations:
8787
original_shape = G.shape
8888
dtype = G.dtype
8989

90-
X = G.to(dtype = torch.bfloat16 if use_bf16 else torch.float32)
90+
X = G.to(torch.float32)
9191

9292
X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7)
93-
X = X.to(torch.float16)
9493

9594
should_transpose = X.size(-2) > X.size(-1)
9695
if should_transpose:
9796
X = X.mT
9897

9998
if X.size(-2) != X.size(-1):
99+
X = X.to(torch.float16)
100100
R = torch.bmm(X, X.mT)
101101
Q = None
102102

@@ -121,6 +121,7 @@ def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations:
121121
X = torch.bmm(Q, X) if not should_transpose else torch.bmm(X.mT, Q)
122122

123123
else:
124+
X = X.to(dtype = torch.bfloat16 if use_bf16 else torch.float32)
124125
for i, (a_i, b_i, c_i) in enumerate(ns_coefficients):
125126
A = torch.bmm(X, X.mT)
126127
B = torch.baddbmm(A, A, A, beta=b_i, alpha=c_i)

0 commit comments

Comments (0)