@@ -43,6 +43,9 @@ def __init__(
         activation=F.silu,
         bidirectional=False,
         use_learnable_interaction=False,
+        layer_norm_eps=1e-05,
+        AB_weight_decay=False,
+        AB_layer_norm=True,
     ):
         super().__init__()
 
@@ -66,6 +69,9 @@ def __init__(
                     activation,
                     bidirectional,
                     use_learnable_interaction,
+                    layer_norm_eps,
+                    AB_weight_decay,
+                    AB_layer_norm,
                 )
                 for _ in range(n_layers)
             ]
@@ -105,6 +111,9 @@ def __init__(
         activation=F.silu,
         bidirectional=False,
         use_learnable_interaction=False,
+        layer_norm_eps=1e-05,
+        AB_weight_decay=False,
+        AB_layer_norm=False,
     ):
         super().__init__()
 
@@ -149,8 +158,11 @@ def __init__(
             activation=activation,
             bidirectional=bidirectional,
             use_learnable_interaction=use_learnable_interaction,
+            layer_norm_eps=layer_norm_eps,
+            AB_weight_decay=AB_weight_decay,
+            AB_layer_norm=AB_layer_norm,
         )
-        self.norm = norm(d_model)
+        self.norm = norm(d_model, eps=layer_norm_eps)
 
     def forward(self, x):
         output = self.layers(self.norm(x)) + x
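With this change, the `norm` factory handed to the residual block must accept an `eps` keyword, as must the `RMSNorm` used further down for delta/B/C. `RMSNorm` itself is not part of this diff; a minimal sketch of the standard formulation such a class would implement (an assumption, not the repository's actual code):

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Minimal RMSNorm sketch: rescale by the root mean square of the
    feature dimension, with a learnable gain and no bias or mean subtraction."""

    def __init__(self, d_model, eps=1e-05):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        # eps keeps the reciprocal square root finite for near-zero inputs
        rms = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(rms + self.eps) * self.weight
```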
@@ -189,6 +201,9 @@ def __init__(
         activation=F.silu,
         bidirectional=False,
         use_learnable_interaction=False,
+        layer_norm_eps=1e-05,
+        AB_weight_decay=False,
+        AB_layer_norm=False,
     ):
         super().__init__()
         self.d_inner = d_model * expand_factor
@@ -239,6 +254,7 @@ def __init__(
         elif dt_init == "random":
             nn.init.uniform_(self.dt_proj_fwd.weight, -dt_init_std, dt_init_std)
             if self.bidirectional:
+
                 nn.init.uniform_(self.dt_proj_bwd.weight, -dt_init_std, dt_init_std)
         else:
             raise NotImplementedError
@@ -262,17 +278,35 @@ def __init__(
 
         A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
         self.A_log_fwd = nn.Parameter(torch.log(A))
+        self.D_fwd = nn.Parameter(torch.ones(self.d_inner))
+
         if self.bidirectional:
             self.A_log_bwd = nn.Parameter(torch.log(A))
+            self.D_bwd = nn.Parameter(torch.ones(self.d_inner))
+
+        if not AB_weight_decay:
+            self.A_log_fwd._no_weight_decay = True
+            self.D_fwd._no_weight_decay = True
 
-        self.D_fwd = nn.Parameter(torch.ones(self.d_inner))
         if self.bidirectional:
-            self.D_bwd = nn.Parameter(torch.ones(self.d_inner))
+
+            if not AB_weight_decay:
+                self.A_log_bwd._no_weight_decay = True
+                self.D_bwd._no_weight_decay = True
 
         self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias)
         self.dt_rank = dt_rank
         self.d_state = d_state
 
+        if AB_layer_norm:
+            self.dt_layernorm = RMSNorm(self.dt_rank, eps=layer_norm_eps)
+            self.B_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps)
+            self.C_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps)
+        else:
+            self.dt_layernorm = None
+            self.B_layernorm = None
+            self.C_layernorm = None
+
     def forward(self, x):
         _, L, _ = x.shape
 
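Note that `_no_weight_decay` is only an annotation on the parameter object; it changes nothing unless the optimizer setup reads it. That setup is outside this diff, but the conventional consumer (the same pattern the original Mamba codebase uses) splits parameters into decay and no-decay groups, roughly:

```python
import torch


def build_optimizer(model, lr=1e-3, weight_decay=0.01):
    # Hypothetical helper: route parameters tagged with _no_weight_decay
    # in __init__ above into a zero-decay parameter group.
    decay, no_decay = [], []
    for param in model.parameters():
        if not param.requires_grad:
            continue
        if getattr(param, "_no_weight_decay", False):
            no_decay.append(param)
        else:
            decay.append(param)
    return torch.optim.AdamW(
        [
            {"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ],
        lr=lr,
    )
```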
@@ -316,6 +350,15 @@ def forward(self, x):
 
         return output
 
+    def _apply_layernorms(self, dt, B, C):
+        if self.dt_layernorm is not None:
+            dt = self.dt_layernorm(dt)
+        if self.B_layernorm is not None:
+            B = self.B_layernorm(B)
+        if self.C_layernorm is not None:
+            C = self.C_layernorm(C)
+        return dt, B, C
+
     def ssm(self, x, forward=True):
         if forward:
             A = -torch.exp(self.A_log_fwd.float())
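For reference, the tensors `_apply_layernorms` receives come out of `torch.split` along the last axis, so `dt_layernorm` normalizes over `dt_rank` features while `B_layernorm` and `C_layernorm` normalize over `d_state`. A quick shape check with assumed toy sizes:

```python
import torch

batch, seq_len, dt_rank, d_state = 4, 32, 8, 16  # assumed toy sizes
deltaBC = torch.randn(batch, seq_len, dt_rank + 2 * d_state)
delta, B, C = torch.split(deltaBC, [dt_rank, d_state, d_state], dim=-1)
print(delta.shape, B.shape, C.shape)
# torch.Size([4, 32, 8]) torch.Size([4, 32, 16]) torch.Size([4, 32, 16])
```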
@@ -324,6 +367,7 @@ def ssm(self, x, forward=True):
             delta, B, C = torch.split(
                 deltaBC, [self.dt_rank, self.d_state, self.d_state], dim=-1
             )
+            delta, B, C = self._apply_layernorms(delta, B, C)
             delta = F.softplus(self.dt_proj_fwd(delta))
         else:
             A = -torch.exp(self.A_log_bwd.float())
@@ -332,6 +376,7 @@ def ssm(self, x, forward=True):
             delta, B, C = torch.split(
                 deltaBC, [self.dt_rank, self.d_state, self.d_state], dim=-1
             )
+            delta, B, C = self._apply_layernorms(delta, B, C)
             delta = F.softplus(self.dt_proj_bwd(delta))
 
         y = self.selective_scan_seq(x, delta, A, B, C, D)
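Taken together, the diff adds three knobs: `layer_norm_eps` (forwarded as `eps=` into every norm), `AB_weight_decay` (when False, the default, tags the `A_log_*`/`D_*` parameters so an aware optimizer skips decaying them), and `AB_layer_norm` (enables RMSNorm on delta, B, and C inside `ssm`). A hedged usage sketch; the class name `MambaBlock` and the remaining constructor defaults are assumptions, since the full signature is not visible in this diff:

```python
import torch

# Hypothetical construction; only the kwargs added in this diff are shown.
block = MambaBlock(
    d_model=64,
    bidirectional=True,
    use_learnable_interaction=False,
    layer_norm_eps=1e-05,   # forwarded as eps= to every norm layer
    AB_weight_decay=False,  # tags A_log_*/D_* with _no_weight_decay
    AB_layer_norm=True,     # RMSNorm on delta, B, C inside ssm()
)
x = torch.randn(2, 16, 64)  # (batch, sequence length, d_model)
y = block(x)                # output shape matches x
```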