Commit f954a31

include dilation to 1DConv layers

1 parent a5beaed · commit f954a31

4 files changed, 35 additions & 8 deletions

mambular/arch_utils/mamba_utils/mamba_arch.py

Lines changed: 18 additions & 6 deletions
@@ -43,11 +43,14 @@ def __init__(
                 norm=get_normalization_layer(config),  # type: ignore
                 activation=getattr(config, "activation", nn.SiLU()),
                 bidirectional=getattr(config, "bidirectional", False),
-                use_learnable_interaction=getattr(config, "use_learnable_interaction", False),
+                use_learnable_interaction=getattr(
+                    config, "use_learnable_interaction", False
+                ),
                 layer_norm_eps=getattr(config, "layer_norm_eps", 1e-5),
                 AD_weight_decay=getattr(config, "AD_weight_decay", True),
                 BC_layer_norm=getattr(config, "BC_layer_norm", False),
                 use_pscan=getattr(config, "use_pscan", False),
+                dilation=getattr(config, "dilation", 1),
             )
             for _ in range(getattr(config, "n_layers", 6))
         ]

@@ -149,6 +152,7 @@ def __init__(
         AD_weight_decay=False,
         BC_layer_norm=False,
         use_pscan=False,
+        dilation=1,
     ):
         super().__init__()

@@ -194,6 +198,7 @@ def __init__(
             AD_weight_decay=AD_weight_decay,
             BC_layer_norm=BC_layer_norm,
             use_pscan=use_pscan,
+            dilation=dilation,
         )
         self.norm = norm

@@ -307,6 +312,7 @@ def __init__(
         AD_weight_decay=False,
         BC_layer_norm=False,
         use_pscan=False,
+        dilation=1,
     ):
         super().__init__()

@@ -319,7 +325,10 @@ def __init__(
                 self.pscan = pscan  # Store the imported pscan function
             except ImportError:
                 self.pscan = None  # Set to None if pscan is not available
-                print("The 'mambapy' package is not installed. Please install it by running:\n" "pip install mambapy")
+                print(
+                    "The 'mambapy' package is not installed. Please install it by running:\n"
+                    "pip install mambapy"
+                )
         else:
             self.pscan = None

@@ -347,6 +356,7 @@ def __init__(
             bias=conv_bias,
             groups=self.d_inner,
             padding=d_conv - 1,
+            dilation=dilation,
         )

         self.dropout = nn.Dropout(dropout)

@@ -375,16 +385,18 @@ def __init__(
         else:
             raise NotImplementedError

-        dt_fwd = torch.exp(torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)).clamp(
-            min=dt_init_floor
-        )
+        dt_fwd = torch.exp(
+            torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min))
+            + math.log(dt_min)
+        ).clamp(min=dt_init_floor)
         inv_dt_fwd = dt_fwd + torch.log(-torch.expm1(-dt_fwd))
         with torch.no_grad():
             self.dt_proj_fwd.bias.copy_(inv_dt_fwd)

         if self.bidirectional:
             dt_bwd = torch.exp(
-                torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
+                torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min))
+                + math.log(dt_min)
             ).clamp(min=dt_init_floor)
             inv_dt_bwd = dt_bwd + torch.log(-torch.expm1(-dt_bwd))
             with torch.no_grad():

mambular/arch_utils/rnn_utils.py

Lines changed: 11 additions & 2 deletions
@@ -21,6 +21,7 @@ def __init__(self, config):
         self.rnn_activation = getattr(config, "rnn_activation", "relu")
         self.d_conv = getattr(config, "d_conv", 4)
         self.residuals = getattr(config, "residuals", False)
+        self.dilation = getattr(config, "dilation", 1)

         # Choose RNN layer based on model_type
         rnn_layer = {

@@ -37,7 +38,10 @@ def __init__(self, config):

         if self.residuals:
             self.residual_matrix = nn.ParameterList(
-                [nn.Parameter(torch.randn(self.hidden_size, self.hidden_size)) for _ in range(self.num_layers)]
+                [
+                    nn.Parameter(torch.randn(self.hidden_size, self.hidden_size))
+                    for _ in range(self.num_layers)
+                ]
             )

         # First Conv1d layer uses input_size

@@ -49,6 +53,7 @@ def __init__(self, config):
                 padding=self.d_conv - 1,
                 bias=self.conv_bias,
                 groups=self.input_size,
+                dilation=self.dilation,
             )
         )
         self.layernorms_conv.append(nn.LayerNorm(self.input_size))

@@ -63,6 +68,7 @@ def __init__(self, config):
                 padding=self.d_conv - 1,
                 bias=self.conv_bias,
                 groups=self.hidden_size,
+                dilation=self.dilation,
             )
         )
         self.layernorms_conv.append(nn.LayerNorm(self.hidden_size))

@@ -159,7 +165,10 @@ def __init__(

         if self.residuals:
             self.residual_matrix = nn.ParameterList(
-                [nn.Parameter(torch.randn(self.hidden_size, self.hidden_size)) for _ in range(self.num_layers)]
+                [
+                    nn.Parameter(torch.randn(self.hidden_size, self.hidden_size))
+                    for _ in range(self.num_layers)
+                ]
             )

         # First Conv1d layer uses input_size
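Both Conv1d layers here read the same config value, so a single dilation setting widens how many neighboring positions each depthwise filter spans without adding any weights: the effective kernel extent grows to (d_conv - 1) * dilation + 1. A standalone sketch (not repo code) verifying that the parameter count is unchanged:

import torch.nn as nn

# Standalone sketch: dilation stretches the effective kernel extent while
# the parameter count of the depthwise conv stays constant.
def effective_kernel(d_conv: int, dilation: int) -> int:
    return (d_conv - 1) * dilation + 1

for dilation in (1, 2, 4):
    conv = nn.Conv1d(32, 32, kernel_size=4, groups=32, dilation=dilation)
    n_params = sum(p.numel() for p in conv.parameters())
    print(dilation, effective_kernel(4, dilation), n_params)
# spans 4, 7, 13; n_params is 160 in every case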

mambular/configs/mambular_config.py

Lines changed: 3 additions & 0 deletions
@@ -22,6 +22,8 @@ class DefaultMambularConfig(BaseConfig):
         Dropout rate for regularization.
     d_conv : int, default=4
         Size of convolution over columns.
+    dilation : int, default=1
+        Dilation factor for the convolution.
     dt_rank : str, default="auto"
         Rank of the decision tree used in the model.
     d_state : int, default=128

@@ -76,6 +78,7 @@ class DefaultMambularConfig(BaseConfig):
     d_model: int = 64
     n_layers: int = 4
     d_conv: int = 4
+    dilation: int = 1
     expand_factor: int = 2
     bias: bool = False
     dropout: float = 0.0
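Downstream code reads the new field with getattr(config, "dilation", 1), so configs created before this commit keep working unchanged. A usage sketch, assuming the config class accepts keyword overrides as its annotated defaults suggest:

from mambular.configs.mambular_config import DefaultMambularConfig

# Assumed dataclass-style construction; dilation=1 reproduces the old behavior.
cfg = DefaultMambularConfig(dilation=2)
print(cfg.d_conv, cfg.dilation)  # expected: 4 2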

mambular/configs/tabularnn_config.py

Lines changed: 3 additions & 0 deletions
@@ -48,6 +48,8 @@ class DefaultTabulaRNNConfig(BaseConfig):
         Size of the feedforward network.
     d_conv : int, default=4
         Size of the convolutional layer for embedding features.
+    dilation : int, default=1
+        Dilation factor for the convolution.
     conv_bias : bool, default=True
         Whether to use bias in the convolutional layers.
     """

@@ -78,4 +80,5 @@ class DefaultTabulaRNNConfig(BaseConfig):
     rnn_activation: str = "relu"
     dim_feedforward: int = 256
     d_conv: int = 4
+    dilation: int = 1
     conv_bias: bool = True
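The RNN config change follows the same pattern, again with a backward-compatible default of 1 (same keyword-override assumption as above):

from mambular.configs.tabularnn_config import DefaultTabulaRNNConfig

# Omitting dilation leaves the conv layers exactly as before this commit.
cfg = DefaultTabulaRNNConfig(dilation=3)
print(cfg.d_conv, cfg.dilation)  # expected: 4 3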
