Commit 3cdf998

Merge pull request #219 from basf/rdme_fix
Rdme fix
2 parents d155e22 + 3a769c1 commit 3cdf998

6 files changed

Lines changed: 365 additions & 460 deletions

File tree

README.md

Lines changed: 62 additions & 62 deletions
@@ -21,6 +21,17 @@
 
 Mambular is a Python library for tabular deep learning. It includes models that leverage the Mamba (State Space Model) architecture, as well as other popular models like TabTransformer, FTTransformer, TabM and tabular ResNets. Check out our paper `Mambular: A Sequential Model for Tabular Deep Learning`, available [here](https://arxiv.org/abs/2408.06291). Also check out our paper introducing [TabulaRNN](https://arxiv.org/pdf/2411.17207) and analyzing the efficiency of NLP-inspired tabular models.
 
+<h3>⚡ What's New ⚡</h3>
+<ul>
+<li>Individual preprocessing: preprocess each feature differently, use pre-trained models for categorical encoding</li>
+<li>Extract latent representations of tables</li>
+<li>Use embeddings as inputs</li>
+<li>Define custom training metrics</li>
+</ul>
+
 <h3> Table of Contents </h3>
 
 - [🏃 Quickstart](#-quickstart)
@@ -30,7 +41,6 @@ Mambular is a Python library for tabular deep learning. It includes models that
 - [🛠️ Installation](#️-installation)
 - [🚀 Usage](#-usage)
 - [💻 Implement Your Own Model](#-implement-your-own-model)
-- [Custom Training](#custom-training)
 - [🏷️ Citation](#️-citation)
 - [License](#license)
 
@@ -103,6 +113,7 @@ pip install mamba-ssm
 <h2> Preprocessing </h2>
 
 Mambular simplifies data preprocessing with a range of tools designed for easy transformation of tabular data.
+Specify a default method, or a dictionary defining individual preprocessing methods for each feature.
 
 <h3> Data Type Detection and Transformation </h3>
 
@@ -116,6 +127,7 @@ Mambular simplifies data preprocessing with a range of tools designed for easy t
 - **Polynomial Features**: Automatically generates polynomial and interaction terms for numerical features, enhancing the ability to capture higher-order relationships.
 - **Box-Cox & Yeo-Johnson Transformations**: Performs power transformations to stabilize variance and normalize distributions.
 - **Custom Binning**: Enables user-defined bin edges for precise discretization of numerical data.
+- **Pre-trained Encoding**: Use sentence transformers to encode categorical features.
 
 

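The dictionary-based, per-feature preprocessing mentioned in this section can be sketched as a plain mapping from feature name to method. This is my own illustration, not code from the commit: the feature names ("age", "income", "city") and the method strings are assumptions, so check Mambular's preprocessing documentation for the exact option names it accepts.

```python
# Hypothetical sketch of a per-feature preprocessing specification.
# Feature names and method strings below are illustrative assumptions,
# not guaranteed Mambular option names.
feature_preprocessing = {
    "age": "ple",         # piecewise linear encoding for a numeric feature
    "income": "box-cox",  # power transform to stabilize variance
    "city": "pretrained", # e.g. sentence-transformer categorical encoding
}

default_method = "ple"  # fallback for any feature not listed above


def method_for(feature: str) -> str:
    """Resolve the preprocessing method for a feature, falling back to the default."""
    return feature_preprocessing.get(feature, default_method)
```

A preprocessor configured this way applies the listed method to each named feature and the default everywhere else.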
@@ -147,6 +159,28 @@ preds = model.predict(X)
 preds = model.predict_proba(X)
 ```
 
+Get latent representations for each feature:
+```python
+# simple encoding
+model.encode(X)
+```
+
+Use unstructured data:
+```python
+# load pretrained models
+image_model = ...
+nlp_model = ...
+
+# create embeddings
+img_embs = image_model.encode(images)
+txt_embs = nlp_model.encode(texts)
+
+# fit model on tabular data and unstructured data
+model.fit(X_train, y_train, embeddings=[img_embs, txt_embs])
+```
+
 <h3> Hyperparameter Optimization</h3>
 Since all of the models are sklearn base estimators, you can use the built-in hyperparameter optimization from sklearn.

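Because the models follow the sklearn estimator API, the standard sklearn search utilities apply unchanged. The sketch below demonstrates the pattern with a plain sklearn `Ridge` so it runs standalone; a Mambular estimator and its own parameter names would be dropped in the same way (the grid values here are purely illustrative).

```python
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV

# Toy data; with Mambular you would pass your own X, y.
X, y = make_regression(n_samples=200, n_features=5, random_state=0)

# Any sklearn-compatible estimator plugs in here, including Mambular models.
search = GridSearchCV(
    Ridge(),
    param_grid={"alpha": [0.1, 1.0, 10.0]},
    cv=3,
)
search.fit(X, y)
print(search.best_params_)
```

`GridSearchCV` clones the estimator for each parameter combination, so the wrapped model only needs `get_params`/`set_params`, which sklearn-style estimators provide.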
@@ -222,9 +256,11 @@ MambularLSS allows you to model the full distribution of a response variable, no
 - **studentt**: For data with heavier tails, useful with small samples.
 - **negativebinom**: For over-dispersed count data.
 - **inversegamma**: Often used as a prior in Bayesian inference.
+- **johnsonsu**: Four-parameter distribution defining location, scale, skewness and kurtosis.
 - **categorical**: For data with more than two categories.
 - **quantile**: For quantile regression using the pinball loss.
 
+
 These distribution classes make MambularLSS versatile in modeling various data types and distributions.

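Since the quantile family minimizes the pinball loss, a short standalone NumPy sketch of that loss (my own illustration, not code from Mambular) makes the asymmetry concrete: the constant prediction that minimizes it is the tau-quantile of the targets.

```python
import numpy as np

def pinball_loss(y_true, y_pred, tau):
    """Pinball (quantile) loss: penalizes under- and over-prediction
    asymmetrically, so its minimizer is the tau-quantile of y_true."""
    diff = y_true - y_pred
    return float(np.mean(np.maximum(tau * diff, (tau - 1) * diff)))

# For tau = 0.9, under-predicting by 1 costs 9x more than over-predicting by 1.
print(pinball_loss(np.array([1.0]), np.array([0.0]), tau=0.9))  # 0.9
print(pinball_loss(np.array([0.0]), np.array([1.0]), tau=0.9))  # ~0.1
```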
@@ -269,13 +305,16 @@ Here's how you can implement a custom model with Mambular:
 
 ```python
 from dataclasses import dataclass
+from mambular.configs import BaseConfig
 
 @dataclass
-class MyConfig:
+class MyConfig(BaseConfig):
     lr: float = 1e-04
     lr_patience: int = 10
     weight_decay: float = 1e-06
-    lr_factor: float = 0.1
+    n_layers: int = 4
+    pooling_method: str = "avg"
+
 ```
 
 2. **Second, define your model:**
@@ -290,22 +329,32 @@ Here's how you can implement a custom model with Mambular:
 class MyCustomModel(BaseModel):
     def __init__(
         self,
-        cat_feature_info,
-        num_feature_info,
+        feature_information: tuple,
         num_classes: int = 1,
         config=None,
         **kwargs,
     ):
-        super().__init__(**kwargs)
-        self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
+        super().__init__(**kwargs)
+        self.save_hyperparameters(ignore=["feature_information"])
+        self.returns_ensemble = False
+
+        # embedding layer
+        self.embedding_layer = EmbeddingLayer(
+            *feature_information,
+            config=config,
+        )
 
-        input_dim = get_feature_dimensions(num_feature_info, cat_feature_info)
+        input_dim = np.sum(
+            [len(info) * self.hparams.d_model for info in feature_information]
+        )
 
         self.linear = nn.Linear(input_dim, num_classes)
 
-    def forward(self, num_features, cat_features):
-        x = num_features + cat_features
-        x = torch.cat(x, dim=1)
+    def forward(self, *data) -> torch.Tensor:
+        x = self.embedding_layer(*data)
+        B, S, D = x.shape
+        x = x.reshape(B, S * D)
 
         # Pass through linear layer
         output = self.linear(x)
@@ -329,60 +378,11 @@ Here's how you can implement a custom model with Mambular:
 ```python
 regressor = MyRegressor(numerical_preprocessing="ple")
 regressor.fit(X_train, y_train, max_epochs=50)
+
+regressor.evaluate(X_test, y_test)
 ```
 
-# Custom Training
-If you prefer to setup custom training, preprocessing and evaluation, you can simply use the `mambular.base_models`.
-Just be careful that all basemodels expect lists of features as inputs. More precisely as list for numerical features and a list for categorical features. A custom training loop, with random data could look like this.
 
-```python
-import torch
-import torch.nn as nn
-import torch.optim as optim
-from mambular.base_models import Mambular
-from mambular.configs import DefaultMambularConfig
-
-# Dummy data and configuration
-cat_feature_info = {
-    "cat1": {
-        "preprocessing": "imputer -> continuous_ordinal",
-        "dimension": 1,
-        "categories": 4,
-    }
-}  # Example categorical feature information
-num_feature_info = {
-    "num1": {"preprocessing": "imputer -> scaler", "dimension": 1, "categories": None}
-}  # Example numerical feature information
-num_classes = 1
-config = DefaultMambularConfig()  # Use the desired configuration
-
-# Initialize model, loss function, and optimizer
-model = Mambular(cat_feature_info, num_feature_info, num_classes, config)
-criterion = nn.MSELoss()  # Use MSE for regression; change as appropriate for your task
-optimizer = optim.Adam(model.parameters(), lr=0.001)
-
-# Example training loop
-for epoch in range(10):  # Number of epochs
-    model.train()
-    optimizer.zero_grad()
-
-    # Dummy Data
-    num_features = [torch.randn(32, 1) for _ in num_feature_info]
-    cat_features = [torch.randint(0, 5, (32,)) for _ in cat_feature_info]
-    labels = torch.randn(32, num_classes)
-
-    # Forward pass
-    outputs = model(num_features, cat_features)
-    loss = criterion(outputs, labels)
-
-    # Backward pass and optimization
-    loss.backward()
-    optimizer.step()
-
-    # Print loss for monitoring
-    print(f"Epoch [{epoch+1}/10], Loss: {loss.item():.4f}")
-```
 
 # 🏷️ Citation

mambular/configs/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,7 @@
 from .tabm_config import DefaultTabMConfig
 from .tabtransformer_config import DefaultTabTransformerConfig
 from .tabularnn_config import DefaultTabulaRNNConfig
+from .base_config import BaseConfig
 
 __all__ = [
     "DefaultFTTransformerConfig",
@@ -24,4 +25,5 @@
     "DefaultTabMConfig",
     "DefaultTabTransformerConfig",
     "DefaultTabulaRNNConfig",
+    "BaseConfig"
 ]
