@@ -74,7 +74,7 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor
7474 return X
7575
7676
77- def gram_newton_schulz (G : Tensor , steps : int , use_bf16 : bool , reset_iterations : List [int ]) -> Tensor :
77+ def gram_newton_schulz (G : Tensor , steps : int , reset_iterations : List [int ]) -> Tensor :
7878 """
7979 Gram Newton-Schulz iteration to compute the orthogonalization of G.
8080 Mathematically identical to standard Newton-Schulz but computes iterating
@@ -87,7 +87,7 @@ def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations:
8787 original_shape = G .shape
8888 dtype = G .dtype
8989
90- X = G .to (dtype = torch . bfloat16 if use_bf16 else torch .float32 )
90+ X = G .to (torch .float32 )
9191
9292 should_transpose = X .size (- 2 ) > X .size (- 1 )
9393 if should_transpose :
@@ -107,19 +107,16 @@ def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations:
107107 Q = None
108108
109109 Z = torch .baddbmm (R , R , R , beta = b_i , alpha = c_i )
110-
111110 if i != 0 and i not in reset_iterations :
112111 Q = torch .baddbmm (Q , Q , Z , beta = a_i , alpha = 1.0 )
113112 else :
114113 Q = Z .clone ()
115114 Q .diagonal (dim1 = - 2 , dim2 = - 1 ).add_ (a_i )
116-
117115 if i < steps - 1 and (i + 1 ) not in reset_iterations :
118116 RZ = torch .baddbmm (R , R , Z , beta = a_i , alpha = 1.0 )
119117 R = torch .baddbmm (RZ , Z , RZ , beta = a_i , alpha = 1.0 )
120118
121119 X = torch .bmm (Q , X ) if not should_transpose else torch .bmm (X .mT , Q )
122-
123120 else :
124121 for i , (a_i , b_i , c_i ) in enumerate (ns_coefficients ):
125122 A = torch .bmm (X , X .mT )
@@ -160,7 +157,7 @@ def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=Tr
160157 reset_iterations = [3 ]
161158 defaults = dict (lr = lr , weight_decay = weight_decay , momentum = momentum , nesterov = nesterov , ns_steps = ns_steps , reset_iterations = reset_iterations )
162159 super ().__init__ (params , defaults )
163- self .bf16_support_map = get_bf16_support_map ()
160+ # self.bf16_support_map = get_bf16_support_map()
164161
165162 @torch .no_grad ()
166163 def step (self , closure = None ):
@@ -189,9 +186,9 @@ def step(self, closure=None):
189186 original_shape = g .shape
190187 if g .ndim >= 4 : # for the case of conv filters
191188 g = g .view (g .size (0 ), g .size (1 ), - 1 )
192- use_bf16 = self .bf16_support_map .get (g .device , False )
189+ # use_bf16 = self.bf16_support_map.get(g.device, False)
193190 # g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16)
194- g = gram_newton_schulz (g , steps = group ["ns_steps" ], use_bf16 = use_bf16 , reset_iterations = group ["reset_iterations" ])
191+ g = gram_newton_schulz (g , steps = group ["ns_steps" ], reset_iterations = group ["reset_iterations" ])
195192 if group ["weight_decay" ] > 0 :
196193 torch ._foreach_mul_ (p , 1 - group ["lr" ] * group ["weight_decay" ])
197194 torch ._foreach_add_ (p , g .view (original_shape ).unbind (0 ), alpha = - group ["lr" ] * max (g [0 ].size ()) ** 0.5 )
0 commit comments