</div>
<div style="text-align: center;">
19
-
<h1>Mambular: Tabular Deep Learning (with Mamba)</h1>
19
+
<h1>Mambular: Tabular Deep Learning</h1>
</div>
Mambular is a Python library for tabular deep learning. It includes models that leverage the Mamba (State Space Model) architecture, as well as other popular models such as TabTransformer, FTTransformer, TabM, and tabular ResNets. Check out our paper `Mambular: A Sequential Model for Tabular Deep Learning`, available [here](https://arxiv.org/abs/2408.06291). Also check out our paper introducing [TabulaRNN](https://arxiv.org/pdf/2411.17207) and analyzing the efficiency of NLP-inspired tabular models.
<h3> Table of Contents </h3>
- [🏃 Quickstart](#-quickstart)
- [📖 Introduction](#-introduction)
- [🤖 Models](#-models)
- [📚 Documentation](#-documentation)
- [🛠️ Installation](#️-installation)
- [🚀 Usage](#-usage)
- [💻 Implement Your Own Model](#-implement-your-own-model)
- [Custom Training](#custom-training)
- [🏷️ Citation](#️-citation)
- [License](#license)
|`Mambular`| A sequential model using Mamba blocks, specifically designed for various tabular data tasks, introduced [here](https://arxiv.org/abs/2408.06291). |
|`TabM`| Batch ensembling for an MLP, as introduced by [Gorishniy et al.](https://arxiv.org/abs/2410.24210) |
|`NODE`| Neural Oblivious Decision Ensembles, as introduced by [Popov et al.](https://arxiv.org/abs/1909.06312) |
|`FTTransformer`| A model leveraging transformer encoders, as introduced by [Gorishniy et al.](https://arxiv.org/abs/2106.11959), for tabular data. |
|`MLP`| A classical Multi-Layer Perceptron (MLP) model for handling tabular data tasks. |
|`ResNet`| An adaptation of the ResNet architecture for tabular data applications. |
|`TabTransformer`| A transformer-based model for tabular data introduced by [Huang et al.](https://arxiv.org/abs/2012.06678), enhancing feature learning capabilities. |
|`MambaTab`| A tabular model using a Mamba block on a joint input representation, described [here](https://arxiv.org/abs/2401.08867). Not a sequential model. |
|`TabulaRNN`| A Recurrent Neural Network for tabular data, introduced [here](https://arxiv.org/pdf/2411.17207). |
|`MambAttention`| A combination of Mamba and Transformers, also introduced [here](https://arxiv.org/pdf/2411.17207). |
```python
preds = model.predict(X)
preds = model.predict_proba(X)
```
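Because every model follows the scikit-learn estimator interface, the probabilities returned by `predict_proba` plug directly into sklearn's metrics. A minimal sketch with a stand-in sklearn classifier and synthetic data (any estimator exposing the same interface works identically):

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

# Synthetic stand-in data for a tabular classification task
X, y = make_classification(n_samples=500, n_features=10, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Stand-in classifier; any sklearn-compatible estimator fits here
clf = LogisticRegression(max_iter=1000)
clf.fit(X_train, y_train)

# predict_proba returns one column per class; take the positive class
proba = clf.predict_proba(X_test)[:, 1]
auc = roc_auc_score(y_test, proba)
print("ROC AUC:", auc)
```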
<h3> Hyperparameter Optimization</h3>
Since all of the models are sklearn base estimators, you can use the built-in hyperparameter optimization from sklearn.
```python
from sklearn.model_selection import RandomizedSearchCV

# Example search space; adjust the keys to the parameters of your model's config
param_dist = {
    "d_model": [64, 128, 256],
    "n_layers": [2, 4, 6],
    "lr": [1e-4, 5e-4, 1e-3],
}

search = RandomizedSearchCV(model, param_distributions=param_dist, n_iter=10, cv=5)
search.fit(X, y)
print(search.best_params_)
```
Note that you can also optimize the preprocessing this way: just use the prefix ``prepro__`` when specifying the preprocessor arguments you want to optimize.
Since early stopping is integrated and the best model with respect to the validation loss is returned, setting ``max_epochs`` to a large number is sensible.
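The ``prepro__`` prefix is scikit-learn's standard double-underscore routing for nested parameters. A runnable sketch of the mechanism using a plain sklearn ``Pipeline`` with a step named ``prepro`` as a stand-in (the tunable keys of Mambular's actual preprocessor will differ; check its config for the real names):

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RandomizedSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

X, y = make_classification(n_samples=200, n_features=8, random_state=0)

# A step named "prepro", so "prepro__<param>" routes to it
pipe = Pipeline([("prepro", StandardScaler()), ("clf", LogisticRegression())])

param_dist = {
    "prepro__with_mean": [True, False],  # preprocessor parameter
    "clf__C": [0.1, 1.0, 10.0],          # model parameter
}

search = RandomizedSearchCV(pipe, param_dist, n_iter=4, cv=3, random_state=0)
search.fit(X, y)
print(search.best_params_)
```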
Or use the built-in Bayesian hyperparameter optimization simply by running:
```python
best_params = model.optimize_hparams(X, y)
```
This automatically sets the search space based on the default config from ``mambular.configs``. See the documentation for all parameters of ``optimize_hparams()``. Note that the preprocessor arguments are fixed here and cannot be optimized.
<h2> ⚖️ Distributional Regression with MambularLSS </h2>