Skip to content

Commit d099e3c

Browse files
committed
support bf16 calculation
1 parent 3ae76d7 commit d099e3c

1 file changed

Lines changed: 27 additions & 4 deletions

File tree

modules/optimizer/muon.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,25 @@
77
from .chained_optimizer import ChainedOptimizer, OptimizerSpec
88

99

10-
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
10+
def get_bf16_support_map() -> dict:
    """Return a mapping from each visible CUDA device to a bf16-capability flag.

    Keys are ``torch.device('cuda:<i>')`` objects, one for every index below
    ``torch.cuda.device_count()``. A value is ``True`` when that device's
    compute-capability major version is at least 8, which is the threshold
    this module uses to decide whether bf16 arithmetic may be used.

    Returns:
        An empty dict when CUDA is unavailable or no device is visible,
        otherwise one ``{device: bool}`` entry per CUDA device.
    """
    if not torch.cuda.is_available():
        return {}

    # A device count of zero makes the range empty, so the comprehension
    # yields {} on its own and no separate zero-count guard is needed.
    devices = (torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count()))
    # get_device_capability returns (major, minor); only the major version
    # takes part in the decision.
    return {
        device: torch.cuda.get_device_capability(device)[0] >= 8
        for device in devices
    }
26+
27+
28+
def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor:
1129
"""
1230
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
1331
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
@@ -19,7 +37,10 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
1937
"""
2038
assert G.ndim == 3 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
2139
a, b, c = (3.4445, -4.7750, 2.0315)
22-
X = G.float()
40+
if use_bf16:
41+
X = G.bfloat16()
42+
else:
43+
X = G.float()
2344
if G.size(-2) > G.size(-1):
2445
X = X.mT
2546

@@ -63,7 +84,8 @@ class Muon(torch.optim.Optimizer):
6384
def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=True, ns_steps=5):
    """Register the parameter groups and record per-device bf16 capability.

    Args:
        params: Parameters (or parameter groups) forwarded to the base
            optimizer constructor.
        lr: Default ``lr`` entry for each parameter group.
        weight_decay: Default ``weight_decay`` entry for each group.
        momentum: Default ``momentum`` entry for each group.
        nesterov: Default ``nesterov`` entry for each group.
        ns_steps: Default ``ns_steps`` entry for each group.
    """
    hyperparams = {
        "lr": lr,
        "weight_decay": weight_decay,
        "momentum": momentum,
        "nesterov": nesterov,
        "ns_steps": ns_steps,
    }
    super().__init__(params, hyperparams)
    # Built once at construction; later code looks devices up in this map
    # (missing devices are treated as not bf16-capable by the caller).
    self.bf16_support_map = get_bf16_support_map()
88+
6789
@torch.no_grad()
6890
def step(self, closure=None):
6991
for group in self.param_groups:
@@ -88,7 +110,8 @@ def step(self, closure=None):
88110
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
89111
if g.ndim >= 4: # for the case of conv filters
90112
g = g.view(g.size(0), g.size(1), -1)
91-
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
113+
use_bf16 = self.bf16_support_map.get(g.device, False)
114+
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16)
92115
for i, p in enumerate(group_data["params"]):
93116
if group["weight_decay"] > 0:
94117
p.data.mul_(1 - group["lr"] * group["weight_decay"])

0 commit comments

Comments
 (0)