Commit 978c49e

Merge pull request #166 from basf/gridsearch
add prepro-args to sklearn hpo
2 parents e5c2892 + 99a309b commit 978c49e

5 files changed

Lines changed: 123 additions & 27 deletions

README.md

Lines changed: 68 additions & 15 deletions
@@ -16,21 +16,21 @@
 </div>

 <div style="text-align: center;">
-<h1>Mambular: Tabular Deep Learning (with Mamba)</h1>
+<h1>Mambular: Tabular Deep Learning</h1>
 </div>

-Mambular is a Python library for tabular deep learning. It includes models that leverage the Mamba (State Space Model) architecture, as well as other popular models like TabTransformer, FTTransformer, and tabular ResNets. Check out our paper `Mambular: A Sequential Model for Tabular Deep Learning`, available [here](https://arxiv.org/abs/2408.06291). Also check out our paper introducing [TabulaRNN](https://arxiv.org/pdf/2411.17207) and analyzing the efficiency of NLP inspired tabular models.
+Mambular is a Python library for tabular deep learning. It includes models that leverage the Mamba (State Space Model) architecture, as well as other popular models like TabTransformer, FTTransformer, TabM, and tabular ResNets. Check out our paper `Mambular: A Sequential Model for Tabular Deep Learning`, available [here](https://arxiv.org/abs/2408.06291). Also check out our paper introducing [TabulaRNN](https://arxiv.org/pdf/2411.17207) and analyzing the efficiency of NLP-inspired tabular models.

 <h3> Table of Contents </h3>

 - [🏃 Quickstart](#-quickstart)
 - [📖 Introduction](#-introduction)
 - [🤖 Models](#-models)
-- [🏆 Results](#-results)
 - [📚 Documentation](#-documentation)
 - [🛠️ Installation](#️-installation)
 - [🚀 Usage](#-usage)
 - [💻 Implement Your Own Model](#-implement-your-own-model)
+- [Custom Training](#custom-training)
 - [🏷️ Citation](#️-citation)
 - [License](#license)

@@ -53,18 +53,18 @@ Mambular is a Python package that brings the power of advanced deep learning arc

 # 🤖 Models

-| Model            | Description                                                                                                                                               |
-| ---------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `Mambular`       | A sequential model using Mamba blocks [Gu and Dao](https://arxiv.org/pdf/2312.00752) specifically designed for various tabular data tasks.                |
-| `TabM`           | Batch Ensembling for a MLP as introduced by [Gorishniy et al.](https://arxiv.org/abs/2410.24210)                                                          |
-| `NODE`           | Neural Oblivious Decision Ensembles as introduced by [Popov et al.](https://arxiv.org/abs/1909.06312)                                                     |
-| `FTTransformer`  | A model leveraging transformer encoders, as introduced by [Gorishniy et al.](https://arxiv.org/abs/2106.11959), for tabular data.                         |
-| `MLP`            | A classical Multi-Layer Perceptron (MLP) model for handling tabular data tasks.                                                                           |
-| `ResNet`         | An adaptation of the ResNet architecture for tabular data applications.                                                                                   |
-| `TabTransformer` | A transformer-based model for tabular data introduced by [Huang et al.](https://arxiv.org/abs/2012.06678), enhancing feature learning capabilities.       |
-| `MambaTab`       | A tabular model using a Mamba-Block on a joint input representation described [here](https://arxiv.org/abs/2401.08867) . Not a sequential model.          |
-| `TabulaRNN`      | A Recurrent Neural Network for Tabular data. Not yet included in the benchmarks introduced [here](https://arxiv.org/pdf/2411.17207).                      |
-| `MambAttention`  | A combination between Mamba and Transformers, similar to Jamba by [Lieber et al.](https://arxiv.org/abs/2403.19887). Not yet included in the benchmarks   |
+| Model            | Description                                                                                                                                           |
+| ---------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `Mambular`       | A sequential model using Mamba blocks, specifically designed for various tabular data tasks, introduced [here](https://arxiv.org/abs/2408.06291).     |
+| `TabM`           | Batch ensembling for an MLP as introduced by [Gorishniy et al.](https://arxiv.org/abs/2410.24210)                                                     |
+| `NODE`           | Neural Oblivious Decision Ensembles as introduced by [Popov et al.](https://arxiv.org/abs/1909.06312)                                                 |
+| `FTTransformer`  | A model leveraging transformer encoders, as introduced by [Gorishniy et al.](https://arxiv.org/abs/2106.11959), for tabular data.                     |
+| `MLP`            | A classical Multi-Layer Perceptron (MLP) model for handling tabular data tasks.                                                                       |
+| `ResNet`         | An adaptation of the ResNet architecture for tabular data applications.                                                                               |
+| `TabTransformer` | A transformer-based model for tabular data introduced by [Huang et al.](https://arxiv.org/abs/2012.06678), enhancing feature learning capabilities.   |
+| `MambaTab`       | A tabular model using a Mamba block on a joint input representation, described [here](https://arxiv.org/abs/2401.08867). Not a sequential model.      |
+| `TabulaRNN`      | A Recurrent Neural Network for tabular data, introduced [here](https://arxiv.org/pdf/2411.17207).                                                     |
+| `MambAttention`  | A combination of Mamba and Transformers, also introduced [here](https://arxiv.org/pdf/2411.17207).                                                    |

@@ -145,6 +145,59 @@ preds = model.predict(X)
 preds = model.predict_proba(X)
 ```

+<h3> Hyperparameter Optimization </h3>
+Since all of the models are sklearn base estimators, you can use scikit-learn's built-in hyperparameter optimization.
+
+```python
+from scipy.stats import randint, uniform
+from sklearn.model_selection import RandomizedSearchCV
+
+param_dist = {
+    'd_model': randint(32, 128),
+    'n_layers': randint(2, 10),
+    'lr': uniform(1e-5, 1e-3)
+}
+
+random_search = RandomizedSearchCV(
+    estimator=model,
+    param_distributions=param_dist,
+    n_iter=50,            # Number of parameter settings sampled
+    cv=5,                 # 5-fold cross-validation
+    scoring='accuracy',   # Metric to optimize
+    random_state=42
+)
+
+fit_params = {"max_epochs": 5, "rebuild": False}
+
+# Fit the model
+random_search.fit(X, y, **fit_params)
+
+# Best parameters and score
+print("Best Parameters:", random_search.best_params_)
+print("Best Score:", random_search.best_score_)
+```
+
+Note that this also lets you optimize the preprocessing. Just use the prefix ``prepro__`` when specifying the preprocessor arguments you want to optimize:
+
+```python
+param_dist = {
+    'd_model': randint(32, 128),
+    'n_layers': randint(2, 10),
+    'lr': uniform(1e-5, 1e-3),
+    "prepro__numerical_preprocessing": ["ple", "standardization", "box-cox"]
+}
+```
+
+Since early stopping is integrated and the best model with respect to the validation loss is returned, setting ``max_epochs`` to a large number is sensible.
+
+Alternatively, use the built-in Bayesian HPO simply by running:
+
+```python
+best_params = model.optimize_hparams(X, y)
+```
+
+This automatically sets the search space based on the default config from ``mambular.configs``. See the documentation for all parameters of ``optimize_hparams()``. Note, however, that the preprocessor arguments are fixed there and cannot be optimized.

 <h2> ⚖️ Distributional Regression with MambularLSS </h2>
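The point of this PR is that ``prepro__``-prefixed entries in a search space are now routed to the preprocessor, so model and preprocessing hyperparameters can be tuned in a single search. A minimal end-to-end sketch with ``GridSearchCV`` follows; the ``MambularClassifier`` import path and the ``X, y`` data are assumptions, and the grid values are illustrative only:

```python
from sklearn.model_selection import GridSearchCV
from mambular.models import MambularClassifier  # assumed import path

model = MambularClassifier()

# One grid mixes model-config arguments with ``prepro__``-prefixed
# preprocessor arguments; set_params dispatches each to the right target.
param_grid = {
    "n_layers": [2, 4, 6],
    "prepro__numerical_preprocessing": ["ple", "standardization"],
}

grid_search = GridSearchCV(model, param_grid, cv=3, scoring="accuracy")
grid_search.fit(X, y, max_epochs=200)  # early stopping returns the best epoch

print(grid_search.best_params_)
```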

mambular/models/sklearn_base_classifier.py

Lines changed: 3 additions & 3 deletions
@@ -87,7 +87,7 @@ def get_params(self, deep=True):

         if deep:
             preprocessor_params = {
-                "preprocessor__" + key: value
+                "prepro__" + key: value
                 for key, value in self.preprocessor.get_params().items()
             }
             params.update(preprocessor_params)
@@ -109,12 +109,12 @@ def set_params(self, **parameters):
             Estimator instance.
         """
         config_params = {
-            k: v for k, v in parameters.items() if not k.startswith("preprocessor__")
+            k: v for k, v in parameters.items() if not k.startswith("prepro__")
         }
         preprocessor_params = {
             k.split("__")[1]: v
             for k, v in parameters.items()
-            if k.startswith("preprocessor__")
+            if k.startswith("prepro__")
         }

         if config_params:
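The renamed prefix follows sklearn's usual nested-parameter convention: keys starting with ``prepro__`` are stripped of the prefix and forwarded to the preprocessor, everything else stays with the model config. A standalone sketch of the split performed by ``set_params`` above (plain Python; the example dictionary is made up):

```python
def split_params(parameters: dict) -> tuple[dict, dict]:
    # Mirrors the partition in set_params after this commit.
    config_params = {
        k: v for k, v in parameters.items() if not k.startswith("prepro__")
    }
    preprocessor_params = {
        k.split("__")[1]: v
        for k, v in parameters.items()
        if k.startswith("prepro__")
    }
    return config_params, preprocessor_params


config, prepro = split_params(
    {"n_layers": 4, "prepro__n_bins": 64, "prepro__task": "classification"}
)
assert config == {"n_layers": 4}
assert prepro == {"n_bins": 64, "task": "classification"}
```

One caveat of ``k.split("__")[1]``: it keeps only the first segment after the prefix, so a preprocessor argument whose name itself contains a double underscore would be truncated.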

mambular/models/sklearn_base_lss.py

Lines changed: 3 additions & 3 deletions
@@ -109,7 +109,7 @@ def get_params(self, deep=True):

         if deep:
             preprocessor_params = {
-                "preprocessor__" + key: value
+                "prepro__" + key: value
                 for key, value in self.preprocessor.get_params().items()
             }
             params.update(preprocessor_params)
@@ -131,12 +131,12 @@ def set_params(self, **parameters):
             Estimator instance.
         """
         config_params = {
-            k: v for k, v in parameters.items() if not k.startswith("preprocessor__")
+            k: v for k, v in parameters.items() if not k.startswith("prepro__")
         }
         preprocessor_params = {
             k.split("__")[1]: v
             for k, v in parameters.items()
-            if k.startswith("preprocessor__")
+            if k.startswith("prepro__")
         }

         if config_params:

mambular/models/sklearn_base_regressor.py

Lines changed: 3 additions & 3 deletions
@@ -88,7 +88,7 @@ def get_params(self, deep=True):

         if deep:
             preprocessor_params = {
-                "preprocessor__" + key: value
+                "prepro__" + key: value
                 for key, value in self.preprocessor.get_params().items()
             }
             params.update(preprocessor_params)
@@ -110,12 +110,12 @@ def set_params(self, **parameters):
             Estimator instance.
         """
         config_params = {
-            k: v for k, v in parameters.items() if not k.startswith("preprocessor__")
+            k: v for k, v in parameters.items() if not k.startswith("prepro__")
        }
         preprocessor_params = {
             k.split("__")[1]: v
             for k, v in parameters.items()
-            if k.startswith("preprocessor__")
+            if k.startswith("prepro__")
         }

         if config_params:

mambular/preprocessing/preprocessor.py

Lines changed: 46 additions & 3 deletions
@@ -131,7 +131,48 @@ def __init__(
         self.degree = degree
         self.n_knots = knots

+    def get_params(self, deep=True):
+        """
+        Get parameters for the preprocessor.
+
+        Parameters
+        ----------
+        deep : bool, default=True
+            If True, will return parameters of subobjects that are estimators.
+
+        Returns
+        -------
+        params : dict
+            Parameter names mapped to their values.
+        """
+        params = {
+            "n_bins": self.n_bins,
+            "numerical_preprocessing": self.numerical_preprocessing,
+            "categorical_preprocessing": self.categorical_preprocessing,
+            "use_decision_tree_bins": self.use_decision_tree_bins,
+            "binning_strategy": self.binning_strategy,
+            "task": self.task,
+            "cat_cutoff": self.cat_cutoff,
+            "treat_all_integers_as_numerical": self.treat_all_integers_as_numerical,
+            "degree": self.degree,
+            "knots": self.n_knots,
+        }
+        return params
+
     def set_params(self, **params):
+        """
+        Set parameters for the preprocessor.
+
+        Parameters
+        ----------
+        **params : dict
+            Parameter names mapped to their new values.
+
+        Returns
+        -------
+        self : object
+            Preprocessor instance.
+        """
         for key, value in params.items():
             setattr(self, key, value)
         return self
@@ -222,9 +263,11 @@ def fit(self, X, y=None):
                     (
                         "discretizer",
                         KBinsDiscretizer(
-                            n_bins=bins
-                            if isinstance(bins, int)
-                            else len(bins) - 1,
+                            n_bins=(
+                                bins
+                                if isinstance(bins, int)
+                                else len(bins) - 1
+                            ),
                             encode="ordinal",
                             strategy=self.binning_strategy,
                             subsample=200_000 if len(X) > 200_000 else None,
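The reformatted ``n_bins`` argument keeps its original behavior: an integer ``bins`` is used directly as the bin count, while a sequence of bin edges yields ``len(bins) - 1`` bins. A quick worked check of that expression (the values are illustrative):

```python
def resolve_n_bins(bins):
    # Mirrors the expression passed to KBinsDiscretizer in fit above.
    return bins if isinstance(bins, int) else len(bins) - 1

assert resolve_n_bins(10) == 10                    # integer: use as-is
assert resolve_n_bins([0.0, 0.5, 1.0, 2.0]) == 3   # 4 edges enclose 3 bins
```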
