Commit b8bc5e9

restructure configs to create parent config-class
1 parent: a2c7845

13 files changed: 120 additions & 501 deletions

mambular/configs/base_config.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+from dataclasses import dataclass, field
+from collections.abc import Callable
+import torch.nn as nn
+
+
+@dataclass
+class BaseConfig:
+    """
+    Base configuration class with shared hyperparameters for models.
+
+    This configuration class provides common hyperparameters for optimization,
+    embeddings, and categorical encoding, which can be inherited by specific
+    model configurations.
+
+    Parameters
+    ----------
+    lr : float, default=1e-04
+        Learning rate for the optimizer.
+    lr_patience : int, default=10
+        Number of epochs with no improvement before reducing the learning rate.
+    weight_decay : float, default=1e-06
+        L2 regularization parameter for weight decay in the optimizer.
+    lr_factor : float, default=0.1
+        Factor by which the learning rate is reduced when patience is exceeded.
+    activation : Callable, default=nn.ReLU()
+        Activation function to use in the model's layers.
+    cat_encoding : str, default="int"
+        Method for encoding categorical features ('int', 'one-hot', or 'linear').
+
+    Embedding Parameters
+    --------------------
+    use_embeddings : bool, default=False
+        Whether to use embeddings for categorical or numerical features.
+    embedding_activation : Callable, default=nn.Identity()
+        Activation function applied to embeddings.
+    embedding_type : str, default="linear"
+        Type of embedding to use ('linear', 'plr', etc.).
+    embedding_bias : bool, default=False
+        Whether to use bias in embedding layers.
+    layer_norm_after_embedding : bool, default=False
+        Whether to apply layer normalization after embedding layers.
+    d_model : int, default=32
+        Dimensionality of embeddings or model representations.
+    plr_lite : bool, default=False
+        Whether to use a lightweight version of Piecewise Linear Regression (PLR).
+    n_frequencies : int, default=48
+        Number of frequency components for embeddings.
+    frequencies_init_scale : float, default=0.01
+        Initial scale for frequency components in embeddings.
+    embedding_projection : bool, default=True
+        Whether to apply a projection layer after embeddings.
+
+    Notes
+    -----
+    - This base class is meant to be inherited by other configurations.
+    - Provides default values that can be overridden in derived configurations.
+
+    """
+
+    # Training Parameters
+    lr: float = 1e-04
+    lr_patience: int = 10
+    weight_decay: float = 1e-06
+    lr_factor: float = 0.1
+
+    # Embedding Parameters
+    use_embeddings: bool = False
+    embedding_activation: Callable = nn.Identity()  # noqa: RUF009
+    embedding_type: str = "linear"
+    embedding_bias: bool = False
+    layer_norm_after_embedding: bool = False
+    d_model: int = 32
+    plr_lite: bool = False
+    n_frequencies: int = 48
+    frequencies_init_scale: float = 0.01
+    embedding_projection: bool = True
+
+    # Architecture Parameters
+    batch_norm: bool = False
+    layer_norm: bool = False
+    layer_norm_eps: float = 1e-05
+    activation: Callable = nn.ReLU()  # noqa: RUF009
+    cat_encoding: str = "int"
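
Editor's note: the pattern this commit introduces is plain dataclass inheritance. Each per-model config now subclasses BaseConfig, keeps only its model-specific fields, and may override a shared default by redeclaring it. A minimal sketch of that pattern (the subclass name SimpleMLPConfig and the import path are illustrative assumptions, not from the repo; the BaseConfig fields and defaults are the ones added above):

from dataclasses import dataclass, field

from mambular.configs.base_config import BaseConfig  # path as added in this commit


@dataclass
class SimpleMLPConfig(BaseConfig):
    # Hypothetical model-specific fields; only these live in the subclass.
    n_layers: int = 3
    head_layer_sizes: list = field(default_factory=list)
    # Shared defaults can still be overridden per model.
    d_model: int = 128  # BaseConfig's default is 32


cfg = SimpleMLPConfig(lr=3e-4)
print(cfg.lr, cfg.weight_decay, cfg.d_model)  # 0.0003 1e-06 128 (lr set, rest inherited)

One design note: defaults like nn.ReLU() and nn.Identity() are module instances created once at class-definition time, so every config instance shares the same object; ruff flags function calls in dataclass defaults as RUF009, and the `# noqa: RUF009` markers silence that warning, presumably deliberately.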

mambular/configs/fttransformer_config.py

Lines changed: 2 additions & 40 deletions
@@ -1,25 +1,16 @@
 from collections.abc import Callable
 from dataclasses import dataclass, field
-
 import torch.nn as nn
-
 from ..arch_utils.transformer_utils import ReGLU
+from .base_config import BaseConfig
 
 
 @dataclass
-class DefaultFTTransformerConfig:
+class DefaultFTTransformerConfig(BaseConfig):
     """Configuration class for the FT Transformer model with predefined hyperparameters.
 
     Parameters
     ----------
-    lr : float, default=1e-04
-        Learning rate for the optimizer.
-    lr_patience : int, default=10
-        Number of epochs with no improvement after which the learning rate will be reduced.
-    weight_decay : float, default=1e-06
-        Weight decay (L2 regularization) for the optimizer.
-    lr_factor : float, default=0.1
-        Factor by which the learning rate will be reduced.
     d_model : int, default=128
         Dimensionality of the transformer model.
     n_layers : int, default=4
@@ -44,20 +35,6 @@ class DefaultFTTransformerConfig:
         Whether to apply normalization before other operations in each transformer block.
     bias : bool, default=True
         Whether to use bias in linear layers.
-    embedding_activation : callable, default=nn.Identity()
-        Activation function for embeddings.
-    embedding_type : str, default="linear"
-        Type of embedding to use ('linear', 'plr', etc.).
-    plr_lite : bool, default=False
-        Whether to use a lightweight version of Piecewise Linear Regression (PLR).
-    n_frequencies : int, default=48
-        Number of frequencies for PLR embeddings.
-    frequencies_init_scale : float, default=0.01
-        Initial scale for frequency parameters in embeddings.
-    embedding_bias : bool, default=False
-        Whether to use bias in embedding layers.
-    layer_norm_after_embedding : bool, default=False
-        Whether to apply layer normalization after embedding layers.
     head_layer_sizes : list, default=()
         Sizes of the fully connected layers in the model's head.
     head_dropout : float, default=0.5
@@ -76,12 +53,6 @@ class DefaultFTTransformerConfig:
         Method for encoding categorical features ('int', 'one-hot', or 'linear').
     """
 
-    # Optimizer Parameters
-    lr: float = 1e-04
-    lr_patience: int = 10
-    weight_decay: float = 1e-06
-    lr_factor: float = 0.1
-
     # Architecture Parameters
     d_model: int = 128
     n_layers: int = 4
@@ -96,15 +67,6 @@ class DefaultFTTransformerConfig:
     norm_first: bool = False
     bias: bool = True
 
-    # Embedding Parameters
-    embedding_activation: Callable = nn.Identity()  # noqa: RUF009
-    embedding_type: str = "linear"
-    plr_lite: bool = False
-    n_frequencies: int = 48
-    frequencies_init_scale: float = 0.01
-    embedding_bias: bool = False
-    layer_norm_after_embedding: bool = False
-
     # Head Parameters
     head_layer_sizes: list = field(default_factory=list)
     head_dropout: float = 0.5
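
Editor's note: since the removed optimizer fields now come from the parent class, construction should be behavior-preserving. A quick sanity check one could run (a sketch, assuming the package imports as laid out in this commit):

from mambular.configs.fttransformer_config import DefaultFTTransformerConfig

cfg = DefaultFTTransformerConfig()
assert cfg.lr == 1e-04 and cfg.lr_patience == 10  # inherited from BaseConfig
assert cfg.d_model == 128  # subclass override of BaseConfig's d_model=32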

mambular/configs/mambatab_config.py

Lines changed: 2 additions & 36 deletions
@@ -1,23 +1,15 @@
 from collections.abc import Callable
 from dataclasses import dataclass, field
-
 import torch.nn as nn
+from .base_config import BaseConfig
 
 
 @dataclass
-class DefaultMambaTabConfig:
+class DefaultMambaTabConfig(BaseConfig):
     """Configuration class for the Default MambaTab model with predefined hyperparameters.
 
     Parameters
     ----------
-    lr : float, default=1e-04
-        Learning rate for the optimizer.
-    lr_patience : int, default=10
-        Number of epochs with no improvement after which the learning rate will be reduced.
-    weight_decay : float, default=1e-06
-        Weight decay (L2 regularization) for the optimizer.
-    lr_factor : float, default=0.1
-        Factor by which the learning rate will be reduced.
     d_model : int, default=64
         Dimensionality of the model.
     n_layers : int, default=1
@@ -50,18 +42,6 @@ class DefaultMambaTabConfig:
         Activation function for the model.
     axis : int, default=1
         Axis along which operations are applied, if applicable.
-    num_embedding_activation : callable, default=nn.ReLU()
-        Activation function for numerical embeddings.
-    embedding_type : str, default="linear"
-        Type of embedding to use ('linear', etc.).
-    embedding_bias : bool, default=False
-        Whether to use bias in the embedding layers.
-    plr_lite : bool, default=False
-        Whether to use a lightweight version of Piecewise Linear Regression (PLR).
-    n_frequencies : int, default=48
-        Number of frequencies for PLR embeddings.
-    frequencies_init_scale : float, default=0.01
-        Initial scale for frequency parameters in embeddings.
     head_layer_sizes : list, default=()
         Sizes of the fully connected layers in the model's head.
     head_dropout : float, default=0.0
@@ -82,12 +62,6 @@ class DefaultMambaTabConfig:
         Whether to process data bidirectionally.
     """
 
-    # Optimizer Parameters
-    lr: float = 1e-04
-    lr_patience: int = 10
-    weight_decay: float = 1e-06
-    lr_factor: float = 0.1
-
     # Architecture Parameters
     d_model: int = 64
     n_layers: int = 1
@@ -106,14 +80,6 @@ class DefaultMambaTabConfig:
     activation: Callable = nn.ReLU()  # noqa: RUF009
     axis: int = 1
 
-    # Embedding Parameters
-    num_embedding_activation: Callable = nn.ReLU()  # noqa: RUF009
-    embedding_type: str = "linear"
-    embedding_bias: bool = False
-    plr_lite: bool = False
-    n_frequencies: int = 48
-    frequencies_init_scale: float = 0.01
-
    # Head Parameters
    head_layer_sizes: list = field(default_factory=list)
    head_dropout: float = 0.0
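
Editor's note: one subtlety worth flagging in this file. The field deleted here is num_embedding_activation with default nn.ReLU(), while BaseConfig supplies embedding_activation with default nn.Identity(), so this reads as a rename plus a default change rather than a pure move (assuming the model code now reads the shared field). Callers that set the old field by name would migrate roughly like this (hypothetical usage, not from the repo):

import torch.nn as nn

from mambular.configs.mambatab_config import DefaultMambaTabConfig

# before this commit: DefaultMambaTabConfig(num_embedding_activation=nn.ReLU())
# after, using the shared field name and passing the old default explicitly:
cfg = DefaultMambaTabConfig(embedding_activation=nn.ReLU())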

mambular/configs/mambattention_config.py

Lines changed: 2 additions & 42 deletions
@@ -1,23 +1,15 @@
 from collections.abc import Callable
 from dataclasses import dataclass, field
-
 import torch.nn as nn
+from .base_config import BaseConfig
 
 
 @dataclass
-class DefaultMambAttentionConfig:
+class DefaultMambAttentionConfig(BaseConfig):
     """Configuration class for the Default Mambular Attention model with predefined hyperparameters.
 
     Parameters
     ----------
-    lr : float, default=1e-04
-        Learning rate for the optimizer.
-    lr_patience : int, default=10
-        Number of epochs with no improvement after which learning rate will be reduced.
-    weight_decay : float, default=1e-06
-        Weight decay (L2 penalty) for the optimizer.
-    lr_factor : float, default=0.1
-        Factor by which the learning rate will be reduced.
     d_model : int, default=64
         Dimensionality of the model.
     n_layers : int, default=4
@@ -58,22 +50,6 @@ class DefaultMambAttentionConfig:
         Type of normalization used in the model.
     activation : callable, default=nn.SiLU()
         Activation function for the model.
-    layer_norm_eps : float, default=1e-05
-        Epsilon value for layer normalization.
-    num_embedding_activation : callable, default=nn.ReLU()
-        Activation function for numerical embeddings.
-    embedding_type : str, default="linear"
-        Type of embedding to use ('linear', etc.).
-    embedding_bias : bool, default=False
-        Whether to use bias in the embedding layers.
-    plr_lite : bool, default=False
-        Whether to use a lightweight version of Piecewise Linear Regression (PLR).
-    n_frequencies : int, default=48
-        Number of frequencies for PLR embeddings.
-    frequencies_init_scale : float, default=0.01
-        Initial scale for frequency parameters in embeddings.
-    layer_norm_after_embedding : bool, default=False
-        Whether to apply layer normalization after embedding layers.
     head_layer_sizes : list, default=()
         Sizes of the fully connected layers in the model's head.
     head_dropout : float, default=0.5
@@ -106,12 +82,6 @@ class DefaultMambAttentionConfig:
         Number of attention layers in the model.
     """
 
-    # Optimizer Parameters
-    lr: float = 1e-04
-    lr_patience: int = 10
-    weight_decay: float = 1e-06
-    lr_factor: float = 0.1
-
     # Architecture Parameters
     d_model: int = 64
     n_layers: int = 4
@@ -133,16 +103,6 @@ class DefaultMambAttentionConfig:
     dt_init_floor: float = 1e-04
     norm: str = "LayerNorm"
     activation: Callable = nn.SiLU()  # noqa: RUF009
-    layer_norm_eps: float = 1e-05
-
-    # Embedding Parameters
-    num_embedding_activation: Callable = nn.ReLU()  # noqa: RUF009
-    embedding_type: str = "linear"
-    embedding_bias: bool = False
-    plr_lite: bool = False
-    n_frequencies: int = 48
-    frequencies_init_scale: float = 0.01
-    layer_norm_after_embedding: bool = False
 
     # Head Parameters
     head_layer_sizes: list = field(default_factory=list)
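
Editor's note: layer_norm_eps is deleted here but reappears in BaseConfig with the same 1e-05 default, and activation stays overridden to nn.SiLU() in the subclass, so the effective defaults should be unchanged. A sketch of a check (assumed import path):

import torch.nn as nn

from mambular.configs.mambattention_config import DefaultMambAttentionConfig

cfg = DefaultMambAttentionConfig()
assert cfg.layer_norm_eps == 1e-05  # now inherited from BaseConfig, same default
assert isinstance(cfg.activation, nn.SiLU)  # subclass override of BaseConfig's nn.ReLU()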
