|
| 1 | +import torch |
| 2 | +import torch.nn as nn |
| 3 | +from ..arch_utils.mlp_utils import MLP |
| 4 | +from ..configs.tabularnn_config import DefaultTabulaRNNConfig |
| 5 | +from .basemodel import BaseModel |
| 6 | +from ..arch_utils.embedding_layer import EmbeddingLayer |
| 7 | +from ..arch_utils.normalization_layers import ( |
| 8 | + RMSNorm, |
| 9 | + LayerNorm, |
| 10 | + LearnableLayerScaling, |
| 11 | + BatchNorm, |
| 12 | + InstanceNorm, |
| 13 | + GroupNorm, |
| 14 | +) |
| 15 | + |
| 16 | + |
class TabulaRNN(BaseModel):
    """RNN-based model for tabular data.

    Numerical and categorical features are embedded into a token sequence,
    passed through an RNN/LSTM/GRU, pooled over the sequence dimension,
    combined with a linear skip connection from the mean-pooled embeddings,
    optionally normalized, and mapped to predictions by an MLP head.

    Parameters
    ----------
    cat_feature_info : dict
        Metadata describing the categorical features (forwarded to
        ``EmbeddingLayer``).
    num_feature_info : dict
        Metadata describing the numerical features (forwarded to
        ``EmbeddingLayer``).
    num_classes : int, default=1
        Number of output units of the prediction head.
    config : DefaultTabulaRNNConfig, optional
        Architecture/optimization defaults; any value can be overridden
        via ``hparams``/``kwargs``.
    """

    def __init__(
        self,
        cat_feature_info,
        num_feature_info,
        num_classes=1,
        config: DefaultTabulaRNNConfig = DefaultTabulaRNNConfig(),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])

        # Optimizer-related hyperparameters (consumed by the training loop).
        self.lr = self.hparams.get("lr", config.lr)
        self.lr_patience = self.hparams.get("lr_patience", config.lr_patience)
        self.weight_decay = self.hparams.get("weight_decay", config.weight_decay)
        self.lr_factor = self.hparams.get("lr_factor", config.lr_factor)
        self.pooling_method = self.hparams.get("pooling_method", config.pooling_method)
        self.cat_feature_info = cat_feature_info
        self.num_feature_info = num_feature_info

        # Hoist the repeatedly-used feature dimension lookups.
        d_model = self.hparams.get("d_model", config.d_model)
        dim_feedforward = self.hparams.get("dim_feedforward", config.dim_feedforward)

        # Dispatch table instead of a six-way if/elif chain; an unknown
        # norm name disables normalization (matches the original fallback).
        norm_factories = {
            "RMSNorm": lambda: RMSNorm(dim_feedforward),
            "LayerNorm": lambda: LayerNorm(dim_feedforward),
            "BatchNorm": lambda: BatchNorm(dim_feedforward),
            "InstanceNorm": lambda: InstanceNorm(dim_feedforward),
            "GroupNorm": lambda: GroupNorm(1, dim_feedforward),
            "LearnableLayerScaling": lambda: LearnableLayerScaling(dim_feedforward),
        }
        norm_layer = self.hparams.get("norm", config.norm)
        factory = norm_factories.get(norm_layer)
        self.norm_f = factory() if factory is not None else None

        rnn_layer = {"RNN": nn.RNN, "LSTM": nn.LSTM, "GRU": nn.GRU}[config.model_type]
        rnn_kwargs = dict(
            input_size=d_model,
            hidden_size=dim_feedforward,
            num_layers=self.hparams.get("n_layers", config.n_layers),
            bidirectional=self.hparams.get("bidirectional", config.bidirectional),
            batch_first=True,
            dropout=self.hparams.get("rnn_dropout", config.rnn_dropout),
            bias=self.hparams.get("bias", config.bias),
        )
        # BUGFIX: only nn.RNN accepts a `nonlinearity` argument. Passing it
        # (even as None) to nn.LSTM / nn.GRU raises TypeError, so the
        # original crashed for any model_type other than "RNN".
        if config.model_type == "RNN":
            rnn_kwargs["nonlinearity"] = self.hparams.get(
                "rnn_activation", config.rnn_activation
            )
        self.rnn = rnn_layer(**rnn_kwargs)

        self.embedding_layer = EmbeddingLayer(
            num_feature_info=num_feature_info,
            cat_feature_info=cat_feature_info,
            d_model=d_model,
            embedding_activation=self.hparams.get(
                "embedding_activation", config.embedding_activation
            ),
            layer_norm_after_embedding=self.hparams.get(
                "layer_norm_after_embedding", config.layer_norm_after_embedding
            ),
            use_cls=False,
            cls_position=-1,
            cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding),
        )

        head_activation = self.hparams.get("head_activation", config.head_activation)

        self.tabular_head = MLP(
            dim_feedforward,
            hidden_units_list=self.hparams.get(
                "head_layer_sizes", config.head_layer_sizes
            ),
            dropout_rate=self.hparams.get("head_dropout", config.head_dropout),
            use_skip_layers=self.hparams.get(
                "head_skip_layers", config.head_skip_layers
            ),
            activation_fn=head_activation,
            use_batch_norm=self.hparams.get(
                "head_use_batch_norm", config.head_use_batch_norm
            ),
            n_output_units=num_classes,
        )

        # Skip connection: projects mean-pooled embeddings (d_model) into
        # the RNN's hidden width so it can be added to the pooled output.
        self.linear = nn.Linear(config.d_model, config.dim_feedforward)

    def forward(self, num_features, cat_features):
        """
        Defines the forward pass of the model.

        Parameters
        ----------
        num_features : Tensor
            Tensor containing the numerical features.
        cat_features : Tensor
            Tensor containing the categorical features.

        Returns
        -------
        Tensor
            The output predictions of the model.

        Raises
        ------
        ValueError
            If ``self.pooling_method`` is not one of
            ``"avg"``, ``"max"``, ``"sum"``, ``"last"``.
        """
        x = self.embedding_layer(num_features, cat_features)
        # RNN forward pass over the embedded token sequence.
        out, _ = self.rnn(x)
        # Linear skip connection from the mean-pooled raw embeddings.
        z = self.linear(torch.mean(x, dim=1))

        if self.pooling_method == "avg":
            x = torch.mean(out, dim=1)
        elif self.pooling_method == "max":
            x, _ = torch.max(out, dim=1)
        elif self.pooling_method == "sum":
            x = torch.sum(out, dim=1)
        elif self.pooling_method == "last":
            # BUGFIX: pool the RNN output `out`, not the embedding tensor
            # `x` — the original took the last embedding token (width
            # d_model), which is inconsistent with every other branch and
            # shape-incompatible with the skip connection below.
            x = out[:, -1, :]
        else:
            raise ValueError(f"Invalid pooling method: {self.pooling_method}")
        x = x + z
        if self.norm_f is not None:
            x = self.norm_f(x)
        preds = self.tabular_head(x)

        return preds
0 commit comments