Skip to content

Commit f0fa6d9

Browse files
committed
Revert "use bf16"
This reverts commit 5b6ce35.
1 parent 2989560 commit f0fa6d9

1 file changed

Lines changed: 5 additions & 8 deletions

File tree

modules/optimizer/muon.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor
7474
return X
7575

7676

77-
def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations: List[int]) -> Tensor:
77+
def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]) -> Tensor:
7878
"""
7979
Gram Newton-Schulz iteration to compute the orthogonalization of G.
8080
Mathematically identical to standard Newton-Schulz but computes iterating
@@ -87,7 +87,7 @@ def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations:
8787
original_shape = G.shape
8888
dtype = G.dtype
8989

90-
X = G.to(dtype = torch.bfloat16 if use_bf16 else torch.float32)
90+
X = G.to(torch.float32)
9191

9292
should_transpose = X.size(-2) > X.size(-1)
9393
if should_transpose:
@@ -107,19 +107,16 @@ def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations:
107107
Q = None
108108

109109
Z = torch.baddbmm(R, R, R, beta=b_i, alpha=c_i)
110-
111110
if i != 0 and i not in reset_iterations:
112111
Q = torch.baddbmm(Q, Q, Z, beta=a_i, alpha=1.0)
113112
else:
114113
Q = Z.clone()
115114
Q.diagonal(dim1=-2, dim2=-1).add_(a_i)
116-
117115
if i < steps - 1 and (i + 1) not in reset_iterations:
118116
RZ = torch.baddbmm(R, R, Z, beta=a_i, alpha=1.0)
119117
R = torch.baddbmm(RZ, Z, RZ, beta=a_i, alpha=1.0)
120118

121119
X = torch.bmm(Q, X) if not should_transpose else torch.bmm(X.mT, Q)
122-
123120
else:
124121
for i, (a_i, b_i, c_i) in enumerate(ns_coefficients):
125122
A = torch.bmm(X, X.mT)
@@ -160,7 +157,7 @@ def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=Tr
160157
reset_iterations = [3]
161158
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, reset_iterations=reset_iterations)
162159
super().__init__(params, defaults)
163-
self.bf16_support_map = get_bf16_support_map()
160+
# self.bf16_support_map = get_bf16_support_map()
164161

165162
@torch.no_grad()
166163
def step(self, closure=None):
@@ -189,9 +186,9 @@ def step(self, closure=None):
189186
original_shape = g.shape
190187
if g.ndim >= 4: # for the case of conv filters
191188
g = g.view(g.size(0), g.size(1), -1)
192-
use_bf16 = self.bf16_support_map.get(g.device, False)
189+
# use_bf16 = self.bf16_support_map.get(g.device, False)
193190
# g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16)
194-
g = gram_newton_schulz(g, steps=group["ns_steps"], use_bf16=use_bf16, reset_iterations=group["reset_iterations"])
191+
g = gram_newton_schulz(g, steps=group["ns_steps"], reset_iterations=group["reset_iterations"])
195192
if group["weight_decay"] > 0:
196193
torch._foreach_mul_(p, 1 - group["lr"] * group["weight_decay"])
197194
torch._foreach_add_(p, g.view(original_shape).unbind(0), alpha=-group["lr"] * max(g[0].size()) ** 0.5)

Comments (0)