def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor:
    """
    Approximate orthogonalization (zeroth matrix power) of a batch of matrices
    via a quintic Newton-Schulz iteration.

    The coefficients (a, b, c) give a quintic map whose fixed point is near 1,
    so after `steps` iterations each matrix's singular values are pushed toward
    1 while its singular vectors are preserved — an approximate, not exact,
    semi-orthogonalization.

    Args:
        G: batch of matrices, shape (batch, m, n).
        steps: number of Newton-Schulz iterations to run.
        use_bf16: run the iteration in bfloat16 instead of float32.

    Returns:
        Tensor of shape (batch, m, n), in bfloat16 or float32 depending on
        `use_bf16` (NOT cast back to G's dtype — callers rely on in-place
        promotion when applying the update).
    """
    assert G.ndim == 3  # batched Muon implementation by @scottjmaddox, put into practice by @YouJiacheng
    a, b, c = (3.4445, -4.7750, 2.0315)

    X = G.to(dtype=torch.bfloat16 if use_bf16 else torch.float32)

    # Ensure spectral norm is at most 1: the per-matrix Frobenius norm bounds
    # the spectral norm from above, so dividing by it is sufficient.
    X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7)

    # Perform the NS iterations. Build the Gram matrix on the smaller side of
    # each (m, n) slice — X @ X^T when m < n, X^T @ X otherwise — which avoids
    # the transpose round-trip of an orientation-normalizing implementation.
    if X.size(-2) < X.size(-1):
        for _ in range(steps):
            A = torch.bmm(X, X.mT)
            A = torch.baddbmm(A, A, A, beta=b, alpha=c)  # A <- b*A + c*(A @ A)
            X = torch.baddbmm(X, A, X, beta=a, alpha=1)  # X <- a*X + A @ X
    else:
        for _ in range(steps):
            A = torch.bmm(X.mT, X)
            A = torch.baddbmm(A, A, A, beta=b, alpha=c)  # A <- b*A + c*(A @ A)
            X = torch.baddbmm(X, X, A, beta=a, alpha=1)  # X <- a*X + X @ A

    return X
6060
6161class Muon (torch .optim .Optimizer ):
@@ -85,7 +85,7 @@ def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=Tr
8585 defaults = dict (lr = lr , weight_decay = weight_decay , momentum = momentum , nesterov = nesterov , ns_steps = ns_steps )
8686 super ().__init__ (params , defaults )
8787 self .bf16_support_map = get_bf16_support_map ()
88-
88+
8989 @torch .no_grad ()
9090 def step (self , closure = None ):
9191 for group in self .param_groups :
@@ -95,28 +95,29 @@ def step(self, closure=None):
9595 state = self .state [p ]
9696 if "momentum_buffer" not in state :
9797 state ["momentum_buffer" ] = torch .zeros_like (g )
98- buf : Tensor = state ["momentum_buffer" ]
9998 key = (p .shape , p .device , p .dtype )
10099 if key not in shape_groups :
101100 shape_groups [key ] = {"params" : [], "grads" : [], "buffers" : []}
102101 shape_groups [key ]["params" ].append (p )
103102 shape_groups [key ]["grads" ].append (g )
104- shape_groups [key ]["buffers" ].append (buf )
103+ shape_groups [key ]["buffers" ].append (state [ "momentum_buffer" ] )
105104 for key in shape_groups :
106105 group_data = shape_groups [key ]
107- g = torch .stack (group_data ["grads" ])
108- buf = torch .stack (group_data ["buffers" ])
109- buf .lerp_ (g , 1 - group ["momentum" ])
110- g = g .lerp_ (buf , group ["momentum" ]) if group ["nesterov" ] else buf
106+ p , g , buf , m = group_data ["params" ], group_data ["grads" ], group_data ["buffers" ], group ["momentum" ]
107+ torch ._foreach_lerp_ (buf , g , 1 - m )
108+ if group ["nesterov" ]:
109+ torch ._foreach_lerp_ (g , buf , m )
110+ g = torch .stack (g )
111+ else :
112+ g = torch .stack (buf )
113+ original_shape = g .shape
111114 if g .ndim >= 4 : # for the case of conv filters
112115 g = g .view (g .size (0 ), g .size (1 ), - 1 )
113116 use_bf16 = self .bf16_support_map .get (g .device , False )
114117 g = zeropower_via_newtonschulz5 (g , steps = group ["ns_steps" ], use_bf16 = use_bf16 )
115- for i , p in enumerate (group_data ["params" ]):
116- if group ["weight_decay" ] > 0 :
117- p .data .mul_ (1 - group ["lr" ] * group ["weight_decay" ])
118- p .data .add_ (g [i ].view_as (p ), alpha = - group ["lr" ] * max (g [i ].size ()) ** 0.5 )
119- self .state [p ]["momentum_buffer" ] = buf [i ].clone ()
118+ if group ["weight_decay" ] > 0 :
119+ torch ._foreach_mul_ (p , 1 - group ["lr" ] * group ["weight_decay" ])
120+ torch ._foreach_add_ (p , g .view (original_shape ).unbind (0 ), alpha = - group ["lr" ] * max (g [0 ].size ()) ** 0.5 )
120121
121122
122123def get_params_for_muon (model ) -> List [Parameter ]:
0 commit comments