
Commit 2989560: use bf16
Parent: ff3b72c
1 file changed: 8 additions, 5 deletions

modules/optimizer/muon.py
@@ -74,7 +74,7 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor:
     return X


-def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]) -> Tensor:
+def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations: List[int]) -> Tensor:
     """
     Gram Newton-Schulz iteration to compute the orthogonalization of G.
     Mathematically identical to standard Newton-Schulz but computes iterating
@@ -87,7 +87,7 @@ def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]) -> Tensor:
     original_shape = G.shape
     dtype = G.dtype

-    X = G.to(torch.float32)
+    X = G.to(dtype=torch.bfloat16 if use_bf16 else torch.float32)

     should_transpose = X.size(-2) > X.size(-1)
     if should_transpose:
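Only the working precision changes here: rather than always upcasting to float32, the iteration now runs in bfloat16 when the caller says the device supports it (the saved `dtype` suggests the result is cast back afterwards). A minimal sketch of the cast, with `to_ns_dtype` as a hypothetical name for the expression the commit inlines:

```python
import torch

# Hypothetical helper mirroring the inlined cast above. bf16 roughly halves
# memory traffic for the batched matmuls, and Newton-Schulz orthogonalization
# is commonly run in bfloat16 in Muon implementations.
def to_ns_dtype(G: torch.Tensor, use_bf16: bool) -> torch.Tensor:
    return G.to(dtype=torch.bfloat16 if use_bf16 else torch.float32)

g = torch.randn(4, 64, 32)
assert to_ns_dtype(g, use_bf16=True).dtype == torch.bfloat16
assert to_ns_dtype(g, use_bf16=False).dtype == torch.float32
```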
@@ -107,16 +107,19 @@ def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]) -> Tensor:
             Q = None

             Z = torch.baddbmm(R, R, R, beta=b_i, alpha=c_i)
+
             if i != 0 and i not in reset_iterations:
                 Q = torch.baddbmm(Q, Q, Z, beta=a_i, alpha=1.0)
             else:
                 Q = Z.clone()
             Q.diagonal(dim1=-2, dim2=-1).add_(a_i)
+
             if i < steps - 1 and (i + 1) not in reset_iterations:
                 RZ = torch.baddbmm(R, R, Z, beta=a_i, alpha=1.0)
                 R = torch.baddbmm(RZ, Z, RZ, beta=a_i, alpha=1.0)

             X = torch.bmm(Q, X) if not should_transpose else torch.bmm(X.mT, Q)
+
     else:
         for i, (a_i, b_i, c_i) in enumerate(ns_coefficients):
             A = torch.bmm(X, X.mT)
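The three `+` lines in this hunk are blank lines for readability; the iteration is numerically unchanged. It leans on `torch.baddbmm(input, batch1, batch2, beta=..., alpha=...)`, which computes `beta * input + alpha * (batch1 @ batch2)` per batch element, so `Z` is the polynomial term `b_i * R + c_i * (R @ R)`. A quick check of that reading, with placeholder coefficient values (not the ones defined in this file):

```python
import torch

# baddbmm computes beta * input + alpha * bmm(batch1, batch2).
b_i, c_i = -4.7750, 2.0315  # placeholder coefficients for illustration
R = torch.randn(2, 8, 8, dtype=torch.float64)
Z = torch.baddbmm(R, R, R, beta=b_i, alpha=c_i)
Z_ref = b_i * R + c_i * torch.bmm(R, R)
assert torch.allclose(Z, Z_ref)
```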
@@ -157,7 +160,7 @@ def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=True,
         reset_iterations = [3]
         defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, reset_iterations=reset_iterations)
         super().__init__(params, defaults)
-        # self.bf16_support_map = get_bf16_support_map()
+        self.bf16_support_map = get_bf16_support_map()

     @torch.no_grad()
     def step(self, closure=None):
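`get_bf16_support_map` already existed (the call was previously commented out); the commit simply enables it. Its implementation is not shown in this diff, so the following is only an assumption about what such a helper might look like, mapping each device the optimizer could see to a bf16 capability flag:

```python
import torch

# Assumed shape of the helper: {torch.device: bool}. Keys must match g.device
# exactly for the .get(g.device, False) lookup in step() to hit.
def get_bf16_support_map() -> dict[torch.device, bool]:
    support = {torch.device("cpu"): False}  # conservative default on CPU
    for idx in range(torch.cuda.device_count()):
        # is_bf16_supported() checks the current CUDA environment
        support[torch.device(f"cuda:{idx}")] = torch.cuda.is_bf16_supported()
    return support
```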
@@ -186,9 +189,9 @@ def step(self, closure=None):
                 original_shape = g.shape
                 if g.ndim >= 4:  # for the case of conv filters
                     g = g.view(g.size(0), g.size(1), -1)
-                # use_bf16 = self.bf16_support_map.get(g.device, False)
+                use_bf16 = self.bf16_support_map.get(g.device, False)
                 # g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16)
-                g = gram_newton_schulz(g, steps=group["ns_steps"], reset_iterations=group["reset_iterations"])
+                g = gram_newton_schulz(g, steps=group["ns_steps"], use_bf16=use_bf16, reset_iterations=group["reset_iterations"])
                 if group["weight_decay"] > 0:
                     torch._foreach_mul_(p, 1 - group["lr"] * group["weight_decay"])
                     torch._foreach_add_(p, g.view(original_shape).unbind(0), alpha=-group["lr"] * max(g[0].size()) ** 0.5)
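End to end, the new signature can be exercised as below. This is a hedged usage sketch: the import path assumes the repository root is on `sys.path`, and `steps=5` is a placeholder, while `reset_iterations=[3]` matches the default visible in `__init__` above.

```python
import torch
from modules.optimizer.muon import gram_newton_schulz

# Orthogonalize a batch of tall update matrices. With use_bf16=True the
# result is only approximately orthogonal, so any check should be loose.
g = torch.randn(4, 256, 128)
x = gram_newton_schulz(g, steps=5, use_bf16=False, reset_iterations=[3])

# For a tall matrix, Newton-Schulz drives x.mT @ x toward the identity.
gram = torch.bmm(x.mT, x)
eye = torch.eye(128).expand_as(gram)
print((gram - eye).abs().max())  # small, but not exactly zero
```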
