5252]
5353
5454
def get_bf16_support_map():
    """Map each visible CUDA device to whether it natively supports bfloat16.

    Returns:
        dict: keys are ``torch.device('cuda:i')`` objects, one per visible
        CUDA device; values are ``True`` when the device's compute capability
        major version is >= 8 (Ampere or newer), the threshold this codebase
        uses to decide bf16 eligibility. Empty dict when CUDA is unavailable
        or no devices are present.
    """
    bf16_support_map = {}

    # No CUDA runtime -> nothing to probe. (device_count() would be 0 anyway,
    # so no separate zero-device check is needed: range(0) is empty.)
    if not torch.cuda.is_available():
        return bf16_support_map

    for i in range(torch.cuda.device_count()):
        device = torch.device(f'cuda:{i}')
        major, _minor = torch.cuda.get_device_capability(device)
        # bf16 tensor math is hardware-supported from SM 8.0 (Ampere) onward.
        bf16_support_map[device] = (major >= 8)

    return bf16_support_map
71-
72-
73- def zeropower_via_newtonschulz5 (G : Tensor , steps : int , use_bf16 : bool , ns_coefficients : List [tuple ]) -> Tensor :
55+ def zeropower_via_newtonschulz5 (G : Tensor , steps : int , ns_coefficients : List [tuple ]) -> Tensor :
7456 """
7557 Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
7658 quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
@@ -82,7 +64,7 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool, ns_coeffi
8264 """
8365 assert G .ndim == 3 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
8466
85- X = G .to (dtype = torch . bfloat16 if use_bf16 else torch .float32 )
67+ X = G .to (torch .float32 )
8668
8769 # Ensure spectral norm is at most 1
8870 X = F .normalize (X , p = 2.0 , dim = (- 2 , - 1 ), eps = 1e-7 )
@@ -105,7 +87,7 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool, ns_coeffi
10587 return X
10688
10789
108- def gram_newton_schulz (G : Tensor , steps : int , use_bf16 : bool , reset_iterations : List [int ], ns_coefficients : List [tuple ]) -> Tensor :
90+ def gram_newton_schulz (G : Tensor , steps : int , reset_iterations : List [int ], ns_coefficients : List [tuple ]) -> Tensor :
10991 """
11092 Gram Newton-Schulz iteration to compute the orthogonalization of G.
11193 Mathematically identical to standard Newton-Schulz but computes iterating
@@ -124,8 +106,8 @@ def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations:
124106 if should_transpose :
125107 X = X .mT
126108
109+ X = X .to (torch .float16 )
127110 if X .size (- 2 ) != X .size (- 1 ):
128- X = X .to (torch .float16 )
129111 R = torch .bmm (X , X .mT )
130112 Q = None
131113
@@ -150,7 +132,6 @@ def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations:
150132 X = torch .bmm (Q , X ) if not should_transpose else torch .bmm (X .mT , Q )
151133
152134 else :
153- X = X .to (dtype = torch .bfloat16 if use_bf16 else torch .float32 )
154135 for i , (a_i , b_i , c_i ) in enumerate (ns_coefficients ):
155136 A = torch .bmm (X , X .mT )
156137 B = torch .baddbmm (A , A , A , beta = b_i , alpha = c_i )
@@ -159,7 +140,7 @@ def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations:
159140 return X .to (dtype ).view (original_shape )
160141
161142
162- def mud_whiten (G : Tensor , passes : int = 1 , use_bf16 : bool = False ) -> Tensor :
143+ def mud_whiten (G : Tensor , passes : int = 1 ) -> Tensor :
163144 """
164145 MomentUm Decorrelation (MUD) iteration to compute the orthogonalization of G.
165146 A lightweight PyTorch implementation based on "Beyond Muon: MUD for Faster Transformer Training".
@@ -168,7 +149,6 @@ def mud_whiten(G: Tensor, passes: int = 1, use_bf16: bool = False) -> Tensor:
168149 assert G .ndim == 3
169150 dtype = G .dtype
170151
171- # X = X.to(dtype = torch.bfloat16 if use_bf16 else torch.float32)
172152 # "triangular_solve_cuda" not implemented for 'BFloat16'
173153 X = G .to (torch .float32 )
174154
@@ -241,7 +221,6 @@ def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=Tr
241221 method = method
242222 )
243223 super ().__init__ (params , defaults )
244- self .bf16_support_map = get_bf16_support_map ()
245224
246225 @torch .no_grad ()
247226 def step (self , closure = None ):
@@ -271,29 +250,24 @@ def step(self, closure=None):
271250 if g .ndim >= 4 : # for the case of conv filters
272251 g = g .view (g .size (0 ), g .size (1 ), - 1 )
273252
274- use_bf16 = self .bf16_support_map .get (g .device , False )
275-
276253 # Dynamic orthogonalization method invocation
277254 method = group .get ("method" , "gram_ns" )
278255 if method == 'gram_ns' :
279256 g = gram_newton_schulz (
280257 g ,
281- steps = group ["ns_steps" ],
282- use_bf16 = use_bf16 ,
258+ steps = group ["ns_steps" ],
283259 reset_iterations = group ["reset_iterations" ],
284260 ns_coefficients = group ["ns_coefficients" ]
285261 )
286262 elif method == 'mud' :
287263 g = mud_whiten (
288264 g ,
289- passes = 1 ,
290- use_bf16 = use_bf16
265+ passes = 1
291266 )
292267 elif method == 'ns5' :
293268 g = zeropower_via_newtonschulz5 (
294269 g ,
295- steps = group ["ns_steps" ],
296- use_bf16 = use_bf16 ,
270+ steps = group ["ns_steps" ],
297271 ns_coefficients = group ["ns_coefficients" ]
298272 )
299273 else :
0 commit comments