@@ -44,8 +44,8 @@ def __init__(
4444 bidirectional = False ,
4545 use_learnable_interaction = False ,
4646 layer_norm_eps = 1e-05 ,
47- AB_weight_decay = False ,
48- AB_layer_norm = True ,
47+ AD_weight_decay = False ,
48+ BC_layer_norm = True ,
4949 ):
5050 super ().__init__ ()
5151
@@ -70,8 +70,8 @@ def __init__(
7070 bidirectional ,
7171 use_learnable_interaction ,
7272 layer_norm_eps ,
73- AB_weight_decay ,
74- AB_layer_norm ,
73+ AD_weight_decay ,
74+ BC_layer_norm ,
7575 )
7676 for _ in range (n_layers )
7777 ]
@@ -112,8 +112,8 @@ def __init__(
112112 bidirectional = False ,
113113 use_learnable_interaction = False ,
114114 layer_norm_eps = 1e-05 ,
115- AB_weight_decay = False ,
116- AB_layer_norm = False ,
115+ AD_weight_decay = False ,
116+ BC_layer_norm = False ,
117117 ):
118118 super ().__init__ ()
119119
@@ -159,8 +159,8 @@ def __init__(
159159 bidirectional = bidirectional ,
160160 use_learnable_interaction = use_learnable_interaction ,
161161 layer_norm_eps = layer_norm_eps ,
162- AB_weight_decay = AB_weight_decay ,
163- AB_layer_norm = AB_layer_norm ,
162+ AD_weight_decay = AD_weight_decay ,
163+ BC_layer_norm = BC_layer_norm ,
164164 )
165165 self .norm = norm (d_model , eps = layer_norm_eps )
166166
@@ -202,8 +202,8 @@ def __init__(
202202 bidirectional = False ,
203203 use_learnable_interaction = False ,
204204 layer_norm_eps = 1e-05 ,
205- AB_weight_decay = False ,
206- AB_layer_norm = False ,
205+ AD_weight_decay = False ,
206+ BC_layer_norm = False ,
207207 ):
208208 super ().__init__ ()
209209 self .d_inner = d_model * expand_factor
@@ -284,21 +284,21 @@ def __init__(
284284 self .A_log_bwd = nn .Parameter (torch .log (A ))
285285 self .D_bwd = nn .Parameter (torch .ones (self .d_inner ))
286286
287- if not AB_weight_decay :
287+ if not AD_weight_decay :
288288 self .A_log_fwd ._no_weight_decay = True
289289 self .D_fwd ._no_weight_decay = True
290290
291291 if self .bidirectional :
292292
293- if not AB_weight_decay :
293+ if not AD_weight_decay :
294294 self .A_log_bwd ._no_weight_decay = True
295295 self .D_bwd ._no_weight_decay = True
296296
297297 self .out_proj = nn .Linear (self .d_inner , d_model , bias = bias )
298298 self .dt_rank = dt_rank
299299 self .d_state = d_state
300300
301- if AB_layer_norm :
301+ if BC_layer_norm :
302302 self .dt_layernorm = RMSNorm (self .dt_rank , eps = layer_norm_eps )
303303 self .B_layernorm = RMSNorm (self .d_state , eps = layer_norm_eps )
304304 self .C_layernorm = RMSNorm (self .d_state , eps = layer_norm_eps )
0 commit comments