Skip to content

Commit 08c929a

Browse files
committed
Reapply "use bf16"
This reverts commit f097a1e.
1 parent 86a6342 commit 08c929a

1 file changed

Lines changed: 8 additions & 5 deletions

File tree

modules/optimizer/muon.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor
7474
return X
7575

7676

77-
def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]) -> Tensor:
77+
def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations: List[int]) -> Tensor:
7878
"""
7979
Gram Newton-Schulz iteration to compute the orthogonalization of G.
8080
Mathematically identical to standard Newton-Schulz but computes iterating
@@ -87,7 +87,7 @@ def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]) -> Te
8787
original_shape = G.shape
8888
dtype = G.dtype
8989

90-
X = G.to(torch.float32)
90+
X = G.to(dtype = torch.bfloat16 if use_bf16 else torch.float32)
9191

9292
X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7)
9393
X = X.to(torch.float16)
@@ -107,16 +107,19 @@ def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]) -> Te
107107
Q = None
108108

109109
Z = torch.baddbmm(R, R, R, beta=b_i, alpha=c_i)
110+
110111
if i != 0 and i not in reset_iterations:
111112
Q = torch.baddbmm(Q, Q, Z, beta=a_i, alpha=1.0)
112113
else:
113114
Q = Z.clone()
114115
Q.diagonal(dim1=-2, dim2=-1).add_(a_i)
116+
115117
if i < steps - 1 and (i + 1) not in reset_iterations:
116118
RZ = torch.baddbmm(R, R, Z, beta=a_i, alpha=1.0)
117119
R = torch.baddbmm(RZ, Z, RZ, beta=a_i, alpha=1.0)
118120

119121
X = torch.bmm(Q, X) if not should_transpose else torch.bmm(X.mT, Q)
122+
120123
else:
121124
for i, (a_i, b_i, c_i) in enumerate(ns_coefficients):
122125
A = torch.bmm(X, X.mT)
@@ -154,7 +157,7 @@ def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=Tr
154157
reset_iterations = [3]
155158
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, reset_iterations=reset_iterations)
156159
super().__init__(params, defaults)
157-
# self.bf16_support_map = get_bf16_support_map()
160+
self.bf16_support_map = get_bf16_support_map()
158161

159162
@torch.no_grad()
160163
def step(self, closure=None):
@@ -183,9 +186,9 @@ def step(self, closure=None):
183186
original_shape = g.shape
184187
if g.ndim >= 4: # for the case of conv filters
185188
g = g.view(g.size(0), g.size(1), -1)
186-
# use_bf16 = self.bf16_support_map.get(g.device, False)
189+
use_bf16 = self.bf16_support_map.get(g.device, False)
187190
# g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16)
188-
g = gram_newton_schulz(g, steps=group["ns_steps"], reset_iterations=group["reset_iterations"])
191+
g = gram_newton_schulz(g, steps=group["ns_steps"], use_bf16=use_bf16, reset_iterations=group["reset_iterations"])
189192
if group["weight_decay"] > 0:
190193
torch._foreach_mul_(p, 1 - group["lr"] * group["weight_decay"])
191194
torch._foreach_add_(p, g.view(original_shape).unbind(0), alpha=-group["lr"] * max(g[0].size()) ** 0.5)

0 commit comments

Comments (0)