@@ -74,7 +74,7 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor
7474 return X
7575
7676
77- def gram_newton_schulz (G : Tensor , steps : int , reset_iterations : List [int ]) -> Tensor :
77+ def gram_newton_schulz (G : Tensor , steps : int , use_bf16 : bool , reset_iterations : List [int ]) -> Tensor :
7878 """
7979 Gram Newton-Schulz iteration to compute the orthogonalization of G.
8080 Mathematically identical to standard Newton-Schulz but computes iterating
@@ -87,7 +87,7 @@ def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]) -> Te
8787 original_shape = G .shape
8888 dtype = G .dtype
8989
90- X = G .to (torch .float32 )
90+ X = G .to (dtype = torch . bfloat16 if use_bf16 else torch .float32 )
9191
9292 should_transpose = X .size (- 2 ) > X .size (- 1 )
9393 if should_transpose :
@@ -107,16 +107,19 @@ def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]) -> Te
107107 Q = None
108108
109109 Z = torch .baddbmm (R , R , R , beta = b_i , alpha = c_i )
110+
110111 if i != 0 and i not in reset_iterations :
111112 Q = torch .baddbmm (Q , Q , Z , beta = a_i , alpha = 1.0 )
112113 else :
113114 Q = Z .clone ()
114115 Q .diagonal (dim1 = - 2 , dim2 = - 1 ).add_ (a_i )
116+
115117 if i < steps - 1 and (i + 1 ) not in reset_iterations :
116118 RZ = torch .baddbmm (R , R , Z , beta = a_i , alpha = 1.0 )
117119 R = torch .baddbmm (RZ , Z , RZ , beta = a_i , alpha = 1.0 )
118120
119121 X = torch .bmm (Q , X ) if not should_transpose else torch .bmm (X .mT , Q )
122+
120123 else :
121124 for i , (a_i , b_i , c_i ) in enumerate (ns_coefficients ):
122125 A = torch .bmm (X , X .mT )
@@ -157,7 +160,7 @@ def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=Tr
157160 reset_iterations = [3 ]
158161 defaults = dict (lr = lr , weight_decay = weight_decay , momentum = momentum , nesterov = nesterov , ns_steps = ns_steps , reset_iterations = reset_iterations )
159162 super ().__init__ (params , defaults )
160- # self.bf16_support_map = get_bf16_support_map()
163+ self .bf16_support_map = get_bf16_support_map ()
161164
162165 @torch .no_grad ()
163166 def step (self , closure = None ):
@@ -186,9 +189,9 @@ def step(self, closure=None):
186189 original_shape = g .shape
187190 if g .ndim >= 4 : # for the case of conv filters
188191 g = g .view (g .size (0 ), g .size (1 ), - 1 )
189- # use_bf16 = self.bf16_support_map.get(g.device, False)
192+ use_bf16 = self .bf16_support_map .get (g .device , False )
190193 # g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16)
191- g = gram_newton_schulz (g , steps = group ["ns_steps" ], reset_iterations = group ["reset_iterations" ])
194+ g = gram_newton_schulz (g , steps = group ["ns_steps" ], use_bf16 = use_bf16 , reset_iterations = group ["reset_iterations" ])
192195 if group ["weight_decay" ] > 0 :
193196 torch ._foreach_mul_ (p , 1 - group ["lr" ] * group ["weight_decay" ])
194197 torch ._foreach_add_ (p , g .view (original_shape ).unbind (0 ), alpha = - group ["lr" ] * max (g [0 ].size ()) ** 0.5 )
0 commit comments