Commit 237a0e8
Remove bf16 support and simplify tensor casting
Drop bfloat16 detection and runtime BF16 paths: remove get_bf16_support_map and the bf16_support_map field, eliminate use_bf16 parameters from zeropower_via_newtonschulz5, gram_newton_schulz and mud_whiten, and stop passing use_bf16 from Muon.step. Simplify tensor casts to explicit float32/float16 usage and clean up related conditional logic. This streamlines the orthogonalization codepaths and avoids BF16-specific code (e.g. triangular_solve_cuda incompatibilities).
1 parent 931d824 commit 237a0e8

1 file changed

modules/optimizer/muon.py (8 additions & 34 deletions)
@@ -52,25 +52,7 @@
 ]
 
 
-def get_bf16_support_map():
-    bf16_support_map = {}
-
-    if not torch.cuda.is_available():
-        return bf16_support_map
-
-    device_count = torch.cuda.device_count()
-    if device_count == 0:
-        return bf16_support_map
-
-    for i in range(device_count):
-        device = torch.device(f'cuda:{i}')
-        major, minor = torch.cuda.get_device_capability(device)
-        bf16_support_map[device] = (major >= 8)
-
-    return bf16_support_map
-
-
-def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool, ns_coefficients: List[tuple]) -> Tensor:
+def zeropower_via_newtonschulz5(G: Tensor, steps: int, ns_coefficients: List[tuple]) -> Tensor:
     """
     Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
     quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
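
Note: the deleted helper gated BF16 on CUDA compute capability 8.0+ (Ampere and newer). A minimal sketch of an equivalent check, for context; PyTorch also exposes torch.cuda.is_bf16_supported() for the current device:

import torch

def bf16_ok(device: torch.device) -> bool:
    # Mirrors the removed get_bf16_support_map test: compute capability
    # major >= 8 (Ampere+) is where native bfloat16 kernels are available.
    if device.type != 'cuda':
        return False
    major, _minor = torch.cuda.get_device_capability(device)
    return major >= 8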
@@ -82,7 +64,7 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool, ns_coeffi
     """
     assert G.ndim == 3 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
 
-    X = G.to(dtype = torch.bfloat16 if use_bf16 else torch.float32)
+    X = G.to(torch.float32)
 
     # Ensure spectral norm is at most 1
     X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7)
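
Note: each Newton-Schulz step applies the quintic update X <- a*X + b*(X X^T) X + c*(X X^T)^2 X. A self-contained sketch of what this function computes, using the fixed coefficients from the original Muon implementation (illustrative only; this repo drives the loop with its own per-step ns_coefficients):

import torch

def ns5_sketch(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Illustrative coefficients; the repo passes a per-iteration list instead.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.to(torch.float32)
    # Frobenius-normalize so the spectral norm is <= 1 (needed for convergence).
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    if G.size(-2) > G.size(-1):
        X = X.mT  # iterate in the orientation with the smaller Gram matrix
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A   # polynomial term b*(XX^T) + c*(XX^T)^2
        X = a * X + B @ X       # quintic Newton-Schulz update
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X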
@@ -105,7 +87,7 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool, ns_coeffi
     return X
 
 
-def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations: List[int], ns_coefficients: List[tuple]) -> Tensor:
+def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int], ns_coefficients: List[tuple]) -> Tensor:
     """
     Gram Newton-Schulz iteration to compute the orthogonalization of G.
     Mathematically identical to standard Newton-Schulz but computes iterating
@@ -124,8 +106,8 @@ def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations:
     if should_transpose:
         X = X.mT
 
+    X = X.to(torch.float16)
     if X.size(-2) != X.size(-1):
-        X = X.to(torch.float16)
         R = torch.bmm(X, X.mT)
         Q = None
 
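Note: moving the float16 cast above the branch makes both the square and non-square paths run in half precision. The should_transpose flag (set earlier in the function) is the standard trick of orienting X so that R = X X^T is the smaller Gram matrix; a sketch of that logic, with a hypothetical helper name:

import torch

def orient_for_gram(X: torch.Tensor) -> tuple:
    # Hypothetical helper: for a (B, m, n) batch, transpose when m > n so
    # that torch.bmm(X, X.mT) builds the smaller (min(m, n))-sized Gram matrix.
    should_transpose = X.size(-2) > X.size(-1)
    return (X.mT if should_transpose else X), should_transpose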
@@ -150,7 +132,6 @@ def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations:
             X = torch.bmm(Q, X) if not should_transpose else torch.bmm(X.mT, Q)
 
     else:
-        X = X.to(dtype = torch.bfloat16 if use_bf16 else torch.float32)
        for i, (a_i, b_i, c_i) in enumerate(ns_coefficients):
             A = torch.bmm(X, X.mT)
             B = torch.baddbmm(A, A, A, beta=b_i, alpha=c_i)
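
Note: with beta=b_i and alpha=c_i, torch.baddbmm(A, A, A, ...) fuses the polynomial term b_i*A + c_i*(A @ A) into one kernel. Equivalent, unfused:

import torch

A = torch.randn(4, 8, 8)      # illustrative batch of Gram matrices
b_i, c_i = -4.7750, 2.0315    # example coefficients, not this repo's values
B_fused = torch.baddbmm(A, A, A, beta=b_i, alpha=c_i)
B_naive = b_i * A + c_i * torch.bmm(A, A)
assert torch.allclose(B_fused, B_naive, rtol=1e-4, atol=1e-5)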
@@ -159,7 +140,7 @@ def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations:
     return X.to(dtype).view(original_shape)
 
 
-def mud_whiten(G: Tensor, passes: int = 1, use_bf16: bool = False) -> Tensor:
+def mud_whiten(G: Tensor, passes: int = 1) -> Tensor:
     """
     MomentUm Decorrelation (MUD) iteration to compute the orthogonalization of G.
     A lightweight PyTorch implementation based on "Beyond Muon: MUD for Faster Transformer Training".
@@ -168,7 +149,6 @@ def mud_whiten(G: Tensor, passes: int = 1, use_bf16: bool = False) -> Tensor:
     assert G.ndim == 3
     dtype = G.dtype
 
-    # X = X.to(dtype = torch.bfloat16 if use_bf16 else torch.float32)
     # "triangular_solve_cuda" not implemented for 'BFloat16'
     X = G.to(torch.float32)
 
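Note: the surviving comment quotes the runtime error that motivated the float32 cast: the CUDA triangular-solve kernel has no BFloat16 implementation. The exact MUD update isn't shown in this diff, but a whitening pass of this shape would hit that kernel; a hedged sketch under that assumption:

import torch

def whiten_once(X: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Assumed structure only: Gram matrix -> Cholesky factor -> apply L^{-1}
    # via a triangular solve, which is the op lacking a BF16 CUDA kernel.
    R = X @ X.mT
    R = R + eps * torch.eye(R.size(-1), device=X.device, dtype=X.dtype)
    L = torch.linalg.cholesky(R)  # lower-triangular factor of R
    return torch.linalg.solve_triangular(L, X, upper=False)  # L^{-1} @ X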
@@ -241,7 +221,6 @@ def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=Tr
             method=method
         )
         super().__init__(params, defaults)
-        self.bf16_support_map = get_bf16_support_map()
 
     @torch.no_grad()
     def step(self, closure=None):
@@ -271,29 +250,24 @@ def step(self, closure=None):
                 if g.ndim >= 4: # for the case of conv filters
                     g = g.view(g.size(0), g.size(1), -1)
 
-                use_bf16 = self.bf16_support_map.get(g.device, False)
-
                 # Dynamic orthogonalization method invocation
                 method = group.get("method", "gram_ns")
                 if method == 'gram_ns':
                     g = gram_newton_schulz(
                         g,
-                        steps=group["ns_steps"],
-                        use_bf16=use_bf16,
+                        steps=group["ns_steps"],
                         reset_iterations=group["reset_iterations"],
                         ns_coefficients=group["ns_coefficients"]
                     )
                 elif method == 'mud':
                     g = mud_whiten(
                         g,
-                        passes=1,
-                        use_bf16=use_bf16
+                        passes=1
                     )
                 elif method == 'ns5':
                     g = zeropower_via_newtonschulz5(
                         g,
-                        steps=group["ns_steps"],
-                        use_bf16=use_bf16,
+                        steps=group["ns_steps"],
                         ns_coefficients=group["ns_coefficients"]
                     )
                 else:
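
Note: with use_bf16 gone, picking an orthogonalization backend is just the method entry ('gram_ns' by default, or 'mud' / 'ns5'). A hypothetical construction call, inferred from the __init__ signature in the hunk header above (the truncated nesterov=Tr... default is assumed to be True):

import torch.nn as nn
from modules.optimizer.muon import Muon

model = nn.Linear(512, 512)
opt = Muon(model.parameters(), lr=5e-4, weight_decay=0.1,
           momentum=0.95, nesterov=True, method='ns5')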
