
Commit 330c1a0

Merge pull request #102 from basf/rnn_branch
include tabularRNN
2 parents 91ab62c + 43d2758 commit 330c1a0

4 files changed

Lines changed: 495 additions & 0 deletions

File tree

mambular/base_models/tabularnn.py
mambular/configs/tabularnn_config.py
mambular/models/__init__.py
mambular/models/tabularnn.py

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
import torch
import torch.nn as nn

from ..arch_utils.mlp_utils import MLP
from ..configs.tabularnn_config import DefaultTabulaRNNConfig
from .basemodel import BaseModel
from ..arch_utils.embedding_layer import EmbeddingLayer
from ..arch_utils.normalization_layers import (
    RMSNorm,
    LayerNorm,
    LearnableLayerScaling,
    BatchNorm,
    InstanceNorm,
    GroupNorm,
)


class TabulaRNN(BaseModel):
    def __init__(
        self,
        cat_feature_info,
        num_feature_info,
        num_classes=1,
        config: DefaultTabulaRNNConfig = DefaultTabulaRNNConfig(),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])

        self.lr = self.hparams.get("lr", config.lr)
        self.lr_patience = self.hparams.get("lr_patience", config.lr_patience)
        self.weight_decay = self.hparams.get("weight_decay", config.weight_decay)
        self.lr_factor = self.hparams.get("lr_factor", config.lr_factor)
        self.pooling_method = self.hparams.get("pooling_method", config.pooling_method)
        self.cat_feature_info = cat_feature_info
        self.num_feature_info = num_feature_info

        # Map the configured norm name to the matching normalization layer;
        # it is applied to the pooled RNN output, which is dim_feedforward wide.
        norm_layer = self.hparams.get("norm", config.norm)
        dim_feedforward = self.hparams.get("dim_feedforward", config.dim_feedforward)
        if norm_layer == "RMSNorm":
            self.norm_f = RMSNorm(dim_feedforward)
        elif norm_layer == "LayerNorm":
            self.norm_f = LayerNorm(dim_feedforward)
        elif norm_layer == "BatchNorm":
            self.norm_f = BatchNorm(dim_feedforward)
        elif norm_layer == "InstanceNorm":
            self.norm_f = InstanceNorm(dim_feedforward)
        elif norm_layer == "GroupNorm":
            self.norm_f = GroupNorm(1, dim_feedforward)
        elif norm_layer == "LearnableLayerScaling":
            self.norm_f = LearnableLayerScaling(dim_feedforward)
        else:
            self.norm_f = None

        rnn_layer = {"RNN": nn.RNN, "LSTM": nn.LSTM, "GRU": nn.GRU}[config.model_type]
        rnn_kwargs = dict(
            input_size=self.hparams.get("d_model", config.d_model),
            hidden_size=dim_feedforward,
            num_layers=self.hparams.get("n_layers", config.n_layers),
            bidirectional=self.hparams.get("bidirectional", config.bidirectional),
            batch_first=True,
            dropout=self.hparams.get("rnn_dropout", config.rnn_dropout),
            bias=self.hparams.get("bias", config.bias),
        )
        # Only the vanilla nn.RNN accepts a `nonlinearity` argument; passing it
        # to nn.LSTM or nn.GRU (even as None) would raise a TypeError.
        if config.model_type == "RNN":
            rnn_kwargs["nonlinearity"] = self.hparams.get(
                "rnn_activation", config.rnn_activation
            )
        self.rnn = rnn_layer(**rnn_kwargs)

        self.embedding_layer = EmbeddingLayer(
            num_feature_info=num_feature_info,
            cat_feature_info=cat_feature_info,
            d_model=self.hparams.get("d_model", config.d_model),
            embedding_activation=self.hparams.get(
                "embedding_activation", config.embedding_activation
            ),
            layer_norm_after_embedding=self.hparams.get(
                "layer_norm_after_embedding", config.layer_norm_after_embedding
            ),
            use_cls=False,
            cls_position=-1,
            cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding),
        )

        head_activation = self.hparams.get("head_activation", config.head_activation)

        self.tabular_head = MLP(
            dim_feedforward,
            hidden_units_list=self.hparams.get(
                "head_layer_sizes", config.head_layer_sizes
            ),
            dropout_rate=self.hparams.get("head_dropout", config.head_dropout),
            use_skip_layers=self.hparams.get(
                "head_skip_layers", config.head_skip_layers
            ),
            activation_fn=head_activation,
            use_batch_norm=self.hparams.get(
                "head_use_batch_norm", config.head_use_batch_norm
            ),
            n_output_units=num_classes,
        )

        # Projects the mean-pooled embeddings to the RNN output width so they
        # can be added as a residual connection in `forward`.
        self.linear = nn.Linear(config.d_model, config.dim_feedforward)

    def forward(self, num_features, cat_features):
        """
        Defines the forward pass of the model.

        Parameters
        ----------
        num_features : Tensor
            Tensor containing the numerical features.
        cat_features : Tensor
            Tensor containing the categorical features.

        Returns
        -------
        Tensor
            The output predictions of the model.
        """
        x = self.embedding_layer(num_features, cat_features)
        # RNN forward pass
        out, _ = self.rnn(x)
        # Residual branch: mean-pooled embeddings projected to the RNN width.
        z = self.linear(torch.mean(x, dim=1))

        if self.pooling_method == "avg":
            x = torch.mean(out, dim=1)
        elif self.pooling_method == "max":
            x, _ = torch.max(out, dim=1)
        elif self.pooling_method == "sum":
            x = torch.sum(out, dim=1)
        elif self.pooling_method == "last":
            x = out[:, -1, :]  # last hidden state of the RNN output, not the embeddings
        else:
            raise ValueError(f"Invalid pooling method: {self.pooling_method}")
        x = x + z
        if self.norm_f is not None:
            x = self.norm_f(x)
        preds = self.tabular_head(x)

        return preds
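For readers skimming the diff: the forward pass embeds the features, runs the RNN over them, pools, and adds a residual computed from the mean-pooled embeddings. Below is a minimal standalone sketch of that pooling-plus-residual step using plain tensors in place of the package's embedding layer and head; the shapes and the GRU stand-in are illustrative only.

import torch
import torch.nn as nn

# Illustrative shapes: 32 rows, 10 embedded features; d_model=128 and
# dim_feedforward=256 match the config defaults below.
batch, seq_len, d_model, dim_feedforward = 32, 10, 128, 256

x = torch.randn(batch, seq_len, d_model)  # stand-in for the embedding output
rnn = nn.GRU(d_model, dim_feedforward, num_layers=4, batch_first=True, dropout=0.2)
linear = nn.Linear(d_model, dim_feedforward)

out, _ = rnn(x)                    # (batch, seq_len, dim_feedforward)
z = linear(torch.mean(x, dim=1))   # residual branch from mean-pooled embeddings
pooled = torch.mean(out, dim=1)    # "avg" pooling over the feature axis
combined = pooled + z              # (batch, dim_feedforward), fed to the head
print(combined.shape)              # torch.Size([32, 256])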
mambular/configs/tabularnn_config.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
from dataclasses import dataclass
import torch.nn as nn


@dataclass
class DefaultTabulaRNNConfig:
    """
    Configuration class for the default TabulaRNN model with predefined hyperparameters.

    Parameters
    ----------
    lr : float, default=1e-04
        Learning rate for the optimizer.
    model_type : str, default="RNN"
        Type of model, one of "RNN", "LSTM", "GRU".
    lr_patience : int, default=10
        Number of epochs with no improvement after which the learning rate will be reduced.
    weight_decay : float, default=1e-06
        Weight decay (L2 penalty) for the optimizer.
    lr_factor : float, default=0.1
        Factor by which the learning rate will be reduced.
    d_model : int, default=128
        Dimensionality of the embeddings fed to the RNN.
    n_layers : int, default=4
        Number of stacked RNN layers.
    rnn_dropout : float, default=0.2
        Dropout rate between stacked RNN layers.
    norm : str, default="RMSNorm"
        Normalization method applied to the pooled RNN output.
    activation : callable, default=nn.SELU()
        Activation function for the model.
    embedding_activation : callable, default=nn.Identity()
        Activation function for numerical embeddings.
    head_layer_sizes : list, default=()
        Sizes of the layers in the head of the model.
    head_dropout : float, default=0.5
        Dropout rate for the head layers.
    head_skip_layers : bool, default=False
        Whether to use skip connections in the head layers.
    head_activation : callable, default=nn.SELU()
        Activation function for the head layers.
    head_use_batch_norm : bool, default=False
        Whether to use batch normalization in the head layers.
    layer_norm_after_embedding : bool, default=False
        Whether to apply layer normalization after embedding.
    pooling_method : str, default="avg"
        Pooling method to be used ('avg', 'max', 'sum', 'last').
    norm_first : bool, default=False
        Whether to apply normalization before other operations in each block.
    bias : bool, default=True
        Whether to use bias in the linear layers.
    rnn_activation : str, default="relu"
        Nonlinearity for the vanilla RNN ("tanh" or "relu"); ignored for LSTM and GRU.
    layer_norm_eps : float, default=1e-05
        Epsilon value for layer normalization.
    dim_feedforward : int, default=256
        Hidden size of the RNN and width of the pooled representation.
    numerical_embedding : str, default="ple"
        Embedding method for numerical features.
    bidirectional : bool, default=False
        Whether to process data bidirectionally.
    cat_encoding : str, default="int"
        Encoding method for categorical features.
    """

    lr: float = 1e-04
    model_type: str = "RNN"
    lr_patience: int = 10
    weight_decay: float = 1e-06
    lr_factor: float = 0.1
    d_model: int = 128
    n_layers: int = 4
    rnn_dropout: float = 0.2
    norm: str = "RMSNorm"
    activation: callable = nn.SELU()
    embedding_activation: callable = nn.Identity()
    head_layer_sizes: list = ()
    head_dropout: float = 0.5
    head_skip_layers: bool = False
    head_activation: callable = nn.SELU()
    head_use_batch_norm: bool = False
    layer_norm_after_embedding: bool = False
    pooling_method: str = "avg"
    norm_first: bool = False
    bias: bool = True
    rnn_activation: str = "relu"
    layer_norm_eps: float = 1e-05
    dim_feedforward: int = 256
    numerical_embedding: str = "ple"
    bidirectional: bool = False
    cat_encoding: str = "int"
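Since the config is a plain dataclass, hyperparameters can be overridden at construction time. A small sketch; the override values are examples only, and the import path is inferred from the relative import in tabularnn.py:

from mambular.configs.tabularnn_config import DefaultTabulaRNNConfig

# Illustrative overrides: any field above can be replaced the same way.
config = DefaultTabulaRNNConfig(model_type="LSTM", d_model=64, n_layers=2)
print(config.model_type, config.d_model, config.dim_feedforward)  # LSTM 64 256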

mambular/models/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -16,6 +16,7 @@
 )

 from .mambatab import MambaTabClassifier, MambaTabRegressor, MambaTabLSS
+from .tabularnn import TabulaRNNClassifier, TabulaRNNRegressor, TabulaRNNLSS


 __all__ = [
@@ -40,4 +41,7 @@
     "MambaTabRegressor",
     "MambaTabClassifier",
     "MambaTabLSS",
+    "TabulaRNNClassifier",
+    "TabulaRNNRegressor",
+    "TabulaRNNLSS",
 ]
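With these exports in place, the new estimators are importable from mambular.models. A hedged usage sketch, assuming the same scikit-learn-style fit/predict interface as the package's other estimators (e.g. MambaTabRegressor); the synthetic data is illustrative and constructor arguments may differ in practice:

import numpy as np
from mambular.models import TabulaRNNRegressor

# Synthetic toy data, purely for illustration.
X = np.random.randn(200, 10)
y = 2.0 * X[:, 0] + 0.1 * np.random.randn(200)

# Assumes the fit/predict interface of the package's other estimators.
model = TabulaRNNRegressor()
model.fit(X, y)
preds = model.predict(X)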
