Commit aa6b3be

Merge pull request #45 from basf/doc_fix

homepage update

2 parents 7284027 + 2463b01 commit aa6b3be

2 files changed

Lines changed: 106 additions & 3 deletions

File tree

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -171,3 +171,5 @@ examples/lightning_logs
 docs/_build/doctrees/*
 docs/_build/html/*
+
+dev/*

docs/homepage.md

Lines changed: 104 additions & 3 deletions
@@ -12,14 +12,39 @@ Mambular is a Python package that brings the power of Mamba architectures to tab
- **PyTorch Lightning Under the Hood**: Built on top of PyTorch Lightning, Mambular models benefit from streamlined training processes, easy customization, and advanced features like distributed training and 16-bit precision.

## Models

| Model            | Description |
|------------------|--------------------------------------------------------------------------------------------------|
| `Mambular`       | An advanced model using Mamba blocks specifically designed for various tabular data tasks. |
| `FTTransformer`  | A model leveraging transformer encoders, as introduced by [Gorishniy et al.](https://arxiv.org/abs/2106.11959), for tabular data. |
| `MLP`            | A classical Multi-Layer Perceptron (MLP) model for handling tabular data tasks. |
| `ResNet`         | An adaptation of the ResNet architecture for tabular data applications. |
| `TabTransformer` | A transformer-based model for tabular data introduced by [Huang et al.](https://arxiv.org/abs/2012.06678), enhancing feature learning capabilities. |

All models are available for `regression`, `classification`, and distributional regression, denoted by `LSS`. Hence, they are available as, e.g., `MambularRegressor`, `MambularClassifier`, or `MambularLSS`.

## Documentation

You can find the Mamba-Tabular API documentation [here](https://mamba-tabular.readthedocs.io/en/latest/index.html).

## Installation

Install Mambular using pip:

```sh
pip install mambular
```

## Preprocessing

Mambular simplifies the preprocessing stage of model development with a comprehensive set of techniques to prepare your data for Mamba architectures. Our preprocessing module is designed to be both powerful and easy to use, offering a variety of options to efficiently transform your tabular data.

### Data Type Detection and Transformation

Mambular automatically identifies the type of each feature in your dataset and applies the most appropriate transformations for numerical and categorical variables. This includes:

- **Ordinal Encoding**: Categorical features are seamlessly transformed into numerical values, preserving their inherent order and making them model-ready.
- **One-Hot Encoding**: For nominal data, Mambular employs one-hot encoding to capture the presence or absence of categories without imposing ordinality.
- **Binning**: Numerical features can be discretized into bins, a useful technique for handling continuous variables in certain modeling contexts.

@@ -102,7 +127,8 @@ from mambular.models import MambularLSS
model = MambularLSS(
    dropout=0.2,
    d_model=64,
    n_layers=8,
)

# Fit the model to your data
@@ -117,10 +143,81 @@ model.fit(

```


### Implement your own model

Mambular allows users to easily integrate their custom models into the existing logic. Simply create a PyTorch model and define its forward pass; instead of inheriting from `nn.Module`, inherit from Mambular's `BaseModel`. Each Mambular model takes three arguments: the number of classes (e.g., 1 for regression or 2 for binary classification; for distributional regression this argument must still be provided, but it is determined automatically from the chosen distribution), plus two arguments passed directly from the preprocessor, `cat_feature_info` and `num_feature_info`, which describe the categorical and numerical features (e.g., their shapes). Additionally, you can provide a `config` argument, which you can either implement similarly to the provided configs or take from one of the provided default configs. A custom model could hence look like this:

1. First, define your config:

```python
from dataclasses import dataclass

@dataclass
class MyConfig:
    lr: float = 1e-04
    lr_patience: int = 10
    weight_decay: float = 1e-06
    lr_factor: float = 0.1
```
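Since `MyConfig` is a plain dataclass, individual hyperparameters can be overridden per instance while the rest keep their defaults. A minimal self-contained sketch (repeating the definition from above):

```python
from dataclasses import dataclass

@dataclass
class MyConfig:
    lr: float = 1e-04
    lr_patience: int = 10
    weight_decay: float = 1e-06
    lr_factor: float = 0.1

# Override only the fields you care about; the rest keep their defaults.
cfg = MyConfig(lr=3e-04)
```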

2. Second, define your model just as you would for an `nn.Module`: simply define the architecture and the forward pass.

```python
from mambular.base_models import BaseModel
import torch
import torch.nn as nn


class MyCustomModel(BaseModel):
    def __init__(
        self,
        cat_feature_info,
        num_feature_info,
        num_classes: int = 1,
        config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])

        # Total input dimension: each numerical feature contributes its
        # shape, each categorical feature contributes one dimension.
        input_dim = 0
        for feature_name, input_shape in num_feature_info.items():
            input_dim += input_shape
        for feature_name, input_shape in cat_feature_info.items():
            input_dim += 1

        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, num_features, cat_features):
        # Join the lists of numerical and categorical feature tensors,
        # then concatenate them along the feature dimension.
        x = num_features + cat_features
        x = torch.cat(x, dim=1)

        # Pass through linear layer
        output = self.linear(x)
        return output
```
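To make the `input_dim` computation above concrete, here is a standalone sketch with hypothetical feature-info dictionaries (the keys and values are invented for illustration), mirroring the two loops in `MyCustomModel`:

```python
# Hypothetical feature-info dicts (invented for illustration): numerical
# entries contribute their input shape, categorical entries one dim each.
num_feature_info = {"age": 1, "income": 1}   # two numerical features, shape 1 each
cat_feature_info = {"city": 5, "job": 3}     # category counts (unused by the loop)

input_dim = 0
for feature_name, input_shape in num_feature_info.items():
    input_dim += input_shape
for feature_name, input_shape in cat_feature_info.items():
    input_dim += 1

# 1 + 1 (numerical) + 1 + 1 (categorical) = 4
```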

3. To leverage the Mambular API, build a regression, classification, or distributional regression model that can use all of Mambular's built-in methods:

```python
from mambular.models import SklearnBaseRegressor


class MyRegressor(SklearnBaseRegressor):
    def __init__(self, **kwargs):
        super().__init__(model=MyCustomModel, config=MyConfig, **kwargs)
```

4. Subsequently, you can fit, evaluate, and predict with your model just like with any other Mambular model. To achieve the same for classification or distributional regression, inherit from `SklearnBaseClassifier` or `SklearnBaseLSS` instead of `SklearnBaseRegressor`.

```python
regressor = MyRegressor(numerical_preprocessing="ple")
regressor.fit(X_train, y_train, max_epochs=50)
```

## Citation

If you find this project useful in your research, please consider citing:

```BibTeX
@misc{2024,
    title={Mambular: Tabular Deep Learning with Mamba Architectures},
@@ -129,3 +226,7 @@ If you find this project useful in your research, please consider citing:
    year={2024}
}
```

## License

The entire codebase is under the MIT license.
