@@ -43,11 +43,14 @@ def __init__(
4343 norm = get_normalization_layer (config ), # type: ignore
4444 activation = getattr (config , "activation" , nn .SiLU ()),
4545 bidirectional = getattr (config , "bidirectional" , False ),
46- use_learnable_interaction = getattr (config , "use_learnable_interaction" , False ),
46+ use_learnable_interaction = getattr (
47+ config , "use_learnable_interaction" , False
48+ ),
4749 layer_norm_eps = getattr (config , "layer_norm_eps" , 1e-5 ),
4850 AD_weight_decay = getattr (config , "AD_weight_decay" , True ),
4951 BC_layer_norm = getattr (config , "BC_layer_norm" , False ),
5052 use_pscan = getattr (config , "use_pscan" , False ),
53+ dilation = getattr (config , "dilation" , 1 ),
5154 )
5255 for _ in range (getattr (config , "n_layers" , 6 ))
5356 ]
@@ -149,6 +152,7 @@ def __init__(
149152 AD_weight_decay = False ,
150153 BC_layer_norm = False ,
151154 use_pscan = False ,
155+ dilation = 1 ,
152156 ):
153157 super ().__init__ ()
154158
@@ -194,6 +198,7 @@ def __init__(
194198 AD_weight_decay = AD_weight_decay ,
195199 BC_layer_norm = BC_layer_norm ,
196200 use_pscan = use_pscan ,
201+ dilation = dilation ,
197202 )
198203 self .norm = norm
199204
@@ -307,6 +312,7 @@ def __init__(
307312 AD_weight_decay = False ,
308313 BC_layer_norm = False ,
309314 use_pscan = False ,
315+ dilation = 1 ,
310316 ):
311317 super ().__init__ ()
312318
@@ -319,7 +325,10 @@ def __init__(
319325 self .pscan = pscan # Store the imported pscan function
320326 except ImportError :
321327 self .pscan = None # Set to None if pscan is not available
322- print ("The 'mambapy' package is not installed. Please install it by running:\n " "pip install mambapy" )
328+ print (
329+ "The 'mambapy' package is not installed. Please install it by running:\n "
330+ "pip install mambapy"
331+ )
323332 else :
324333 self .pscan = None
325334
@@ -347,6 +356,7 @@ def __init__(
347356 bias = conv_bias ,
348357 groups = self .d_inner ,
349358 padding = d_conv - 1 ,
359+ dilation = dilation ,
350360 )
351361
352362 self .dropout = nn .Dropout (dropout )
@@ -375,16 +385,18 @@ def __init__(
375385 else :
376386 raise NotImplementedError
377387
378- dt_fwd = torch .exp (torch .rand (self .d_inner ) * (math .log (dt_max ) - math .log (dt_min )) + math .log (dt_min )).clamp (
379- min = dt_init_floor
380- )
388+ dt_fwd = torch .exp (
389+ torch .rand (self .d_inner ) * (math .log (dt_max ) - math .log (dt_min ))
390+ + math .log (dt_min )
391+ ).clamp (min = dt_init_floor )
381392 inv_dt_fwd = dt_fwd + torch .log (- torch .expm1 (- dt_fwd ))
382393 with torch .no_grad ():
383394 self .dt_proj_fwd .bias .copy_ (inv_dt_fwd )
384395
385396 if self .bidirectional :
386397 dt_bwd = torch .exp (
387- torch .rand (self .d_inner ) * (math .log (dt_max ) - math .log (dt_min )) + math .log (dt_min )
398+ torch .rand (self .d_inner ) * (math .log (dt_max ) - math .log (dt_min ))
399+ + math .log (dt_min )
388400 ).clamp (min = dt_init_floor )
389401 inv_dt_bwd = dt_bwd + torch .log (- torch .expm1 (- dt_bwd ))
390402 with torch .no_grad ():
0 commit comments