77from .chained_optimizer import ChainedOptimizer , OptimizerSpec
88
99
10- def zeropower_via_newtonschulz5 (G : Tensor , steps : int ) -> Tensor :
10+ def get_bf16_support_map ():
11+ bf16_support_map = {}
12+
13+ if not torch .cuda .is_available ():
14+ return bf16_support_map
15+
16+ device_count = torch .cuda .device_count ()
17+ if device_count == 0 :
18+ return bf16_support_map
19+
20+ for i in range (device_count ):
21+ device = torch .device (f'cuda:{ i } ' )
22+ major , minor = torch .cuda .get_device_capability (device )
23+ bf16_support_map [device ] = (major >= 8 )
24+
25+ return bf16_support_map
26+
27+
28+ def zeropower_via_newtonschulz5 (G : Tensor , steps : int , use_bf16 : bool ) -> Tensor :
1129 """
1230 Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
1331 quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
@@ -19,7 +37,10 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
1937 """
2038 assert G .ndim == 3 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
2139 a , b , c = (3.4445 , - 4.7750 , 2.0315 )
22- X = G .float ()
40+ if use_bf16 :
41+ X = G .bfloat16 ()
42+ else :
43+ X = G .float ()
2344 if G .size (- 2 ) > G .size (- 1 ):
2445 X = X .mT
2546
@@ -63,7 +84,8 @@ class Muon(torch.optim.Optimizer):
6384 def __init__ (self , params , lr = 5e-4 , weight_decay = 0.1 , momentum = 0.95 , nesterov = True , ns_steps = 5 ):
6485 defaults = dict (lr = lr , weight_decay = weight_decay , momentum = momentum , nesterov = nesterov , ns_steps = ns_steps )
6586 super ().__init__ (params , defaults )
66-
87+ self .bf16_support_map = get_bf16_support_map ()
88+
6789 @torch .no_grad ()
6890 def step (self , closure = None ):
6991 for group in self .param_groups :
@@ -88,7 +110,8 @@ def step(self, closure=None):
88110 g = g .lerp_ (buf , group ["momentum" ]) if group ["nesterov" ] else buf
89111 if g .ndim >= 4 : # for the case of conv filters
90112 g = g .view (g .size (0 ), g .size (1 ), - 1 )
91- g = zeropower_via_newtonschulz5 (g , steps = group ["ns_steps" ])
113+ use_bf16 = self .bf16_support_map .get (g .device , False )
114+ g = zeropower_via_newtonschulz5 (g , steps = group ["ns_steps" ], use_bf16 = use_bf16 )
92115 for i , p in enumerate (group_data ["params" ]):
93116 if group ["weight_decay" ] > 0 :
94117 p .data .mul_ (1 - group ["lr" ] * group ["weight_decay" ])
0 commit comments