@@ -74,7 +74,7 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor
     return X
 
 
-def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]) -> Tensor:
+def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations: List[int]) -> Tensor:
     """
     Gram Newton-Schulz iteration to compute the orthogonalization of G.
     Mathematically identical to standard Newton-Schulz but computes iterating
@@ -87,7 +87,7 @@ def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]) -> Tensor:
     original_shape = G.shape
     dtype = G.dtype
 
-    X = G.to(torch.float32)
+    X = G.to(dtype=torch.bfloat16 if use_bf16 else torch.float32)
 
     X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7)
     X = X.to(torch.float16)
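
Note on the dtype flow above: the per-matrix Frobenius normalization bounds the spectral norm by 1, which is the convergence precondition for the Newton-Schulz iteration, and the cast afterwards runs the iteration itself in half precision. A minimal standalone check of the normalization step (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

G = torch.randn(2, 16, 32)                          # illustrative batch of matrices
X = F.normalize(G, p=2.0, dim=(-2, -1), eps=1e-7)   # per-matrix Frobenius norm -> 1
# The spectral norm is bounded by the Frobenius norm, so ||X||_2 <= 1 holds.
print(torch.linalg.matrix_norm(X, ord=2))           # every entry is <= 1.0
```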
@@ -107,16 +107,19 @@ def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]) -> Tensor:
                 Q = None
 
             Z = torch.baddbmm(R, R, R, beta=b_i, alpha=c_i)
+
             if i != 0 and i not in reset_iterations:
                 Q = torch.baddbmm(Q, Q, Z, beta=a_i, alpha=1.0)
             else:
                 Q = Z.clone()
                 Q.diagonal(dim1=-2, dim2=-1).add_(a_i)
+
             if i < steps - 1 and (i + 1) not in reset_iterations:
                 RZ = torch.baddbmm(R, R, Z, beta=a_i, alpha=1.0)
                 R = torch.baddbmm(RZ, Z, RZ, beta=a_i, alpha=1.0)
 
             X = torch.bmm(Q, X) if not should_transpose else torch.bmm(X.mT, Q)
+
     else:
         for i, (a_i, b_i, c_i) in enumerate(ns_coefficients):
             A = torch.bmm(X, X.mT)
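
For reviewers: the baddbmm chain in this hunk evaluates the same quintic update as the standard branch, but expressed through the Gram matrix R = X Xᵀ. A minimal sketch of the equivalence for a single non-reset step, with placeholder coefficients (the file's actual ns_coefficients are not shown in this hunk):

```python
import torch

# Placeholder quintic coefficients; the real ns_coefficients live elsewhere in the file.
a, b, c = 3.4445, -4.7750, 2.0315

def step_standard(X):
    A = torch.bmm(X, X.mT)                       # A = X X^T
    B = b * A + c * torch.bmm(A, A)              # B = b*A + c*A^2
    return a * X + torch.bmm(B, X)               # a*X + b*A X + c*A^2 X

def step_gram(X):
    R = torch.bmm(X, X.mT)                       # Gram matrix R = X X^T
    Z = torch.baddbmm(R, R, R, beta=b, alpha=c)  # Z = b*R + c*R@R
    Q = Z.clone()
    Q.diagonal(dim1=-2, dim2=-1).add_(a)         # Q = a*I + b*R + c*R^2
    return torch.bmm(Q, X)                       # Q @ X, same polynomial applied to X

X = torch.randn(2, 16, 32) / 10                  # keep the spectral norm small
print(torch.allclose(step_standard(X), step_gram(X), atol=1e-5))  # True
```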
@@ -154,7 +157,7 @@ def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=Tr
         reset_iterations = [3]
         defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, reset_iterations=reset_iterations)
         super().__init__(params, defaults)
-        # self.bf16_support_map = get_bf16_support_map()
+        self.bf16_support_map = get_bf16_support_map()
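
get_bf16_support_map() is defined elsewhere in this PR. For context, a hypothetical sketch of what such a helper could look like, mapping each device the optimizer may see to whether bf16 is supported (the real probing logic is not shown in this diff):

```python
import torch

def get_bf16_support_map() -> dict:
    # Hypothetical sketch only; the actual helper lives outside this hunk.
    # Keys are torch.device objects, matching g.device as used in step() below.
    support = {torch.device("cpu"): False}
    for idx in range(torch.cuda.device_count()):
        support[torch.device(f"cuda:{idx}")] = torch.cuda.is_bf16_supported()
    return support
```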
 
     @torch.no_grad()
     def step(self, closure=None):
@@ -183,9 +186,9 @@ def step(self, closure=None):
                 original_shape = g.shape
                 if g.ndim >= 4:  # for the case of conv filters
                     g = g.view(g.size(0), g.size(1), -1)
-                # use_bf16 = self.bf16_support_map.get(g.device, False)
+                use_bf16 = self.bf16_support_map.get(g.device, False)
                 # g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16)
-                g = gram_newton_schulz(g, steps=group["ns_steps"], reset_iterations=group["reset_iterations"])
+                g = gram_newton_schulz(g, steps=group["ns_steps"], use_bf16=use_bf16, reset_iterations=group["reset_iterations"])
                 if group["weight_decay"] > 0:
                     torch._foreach_mul_(p, 1 - group["lr"] * group["weight_decay"])
                     torch._foreach_add_(p, g.view(original_shape).unbind(0), alpha=-group["lr"] * max(g[0].size()) ** 0.5)
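
The _foreach calls at the end apply decoupled weight decay and a shape-scaled step across the whole parameter list in two fused operations. A toy illustration of the same pattern (lr, wd, and shapes are illustrative):

```python
import torch

lr, wd = 5e-4, 0.1
p = [torch.ones(4, 4) for _ in range(3)]   # stand-in for a list of parameters
g = torch.randn(3, 4, 4)                   # stand-in for batched orthogonalized grads

torch._foreach_mul_(p, 1 - lr * wd)        # decoupled weight decay, one fused call
torch._foreach_add_(p, g.unbind(0), alpha=-lr * max(g[0].size()) ** 0.5)
```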