Commit 30fc1fb

Merge pull request #38 from basf/restructure
adapt readme and inits
2 parents fe5f279 + 03e345b commit 30fc1fb

3 files changed

Lines changed: 74 additions & 0 deletions

File tree

README.md

Lines changed: 66 additions & 0 deletions
@@ -28,6 +28,22 @@ Mambular is a Python package that brings the power of Mamba architectures to tab
- **Sklearn-like API**: The familiar scikit-learn `fit`, `predict`, and `predict_proba` methods mean a minimal learning curve for those already accustomed to scikit-learn.
- **PyTorch Lightning Under the Hood**: Built on top of PyTorch Lightning, Mambular models benefit from streamlined training processes, easy customization, and advanced features like distributed training and 16-bit precision.

## Models

| Model            | Description                                                                                        |
|------------------|----------------------------------------------------------------------------------------------------|
| `Mambular`       | An advanced model using Mamba blocks specifically designed for various tabular data tasks.          |
| `FTTransformer`  | A model leveraging transformer encoders, as introduced by [Gorishniy et al.](https://arxiv.org/abs/2106.11959), for tabular data. |
| `MLP`            | A classical Multi-Layer Perceptron (MLP) model for handling tabular data tasks.                     |
| `ResNet`         | An adaptation of the ResNet architecture for tabular data applications.                             |
| `TabTransformer` | A transformer-based model for tabular data introduced by [Huang et al.](https://arxiv.org/abs/2012.06678), enhancing feature learning capabilities. |

All models are available for `regression`, `classification`, and distributional regression, denoted by `LSS`. Hence, they are available as, e.g., `MambularRegressor`, `MambularClassifier`, or `MambularLSS`.
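As a plain-Python illustration of this naming convention (illustrative only, not an import from the package), each architecture name is combined with a task suffix to form the estimator class names:

```python
# Illustrative only: combine each architecture with a task suffix
# to obtain the estimator class names described above.
architectures = ["Mambular", "FTTransformer", "MLP", "ResNet", "TabTransformer"]
suffixes = ["Regressor", "Classifier", "LSS"]

class_names = [arch + suffix for arch in architectures for suffix in suffixes]
print("MambularRegressor" in class_names)  # True
```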
## Documentation

You can find the Mamba-Tabular API documentation [here](https://mamba-tabular.readthedocs.io/en/latest/index.html).
@@ -144,6 +160,56 @@ model.fit(
```

### Implement your own model

Mambular allows users to easily integrate their custom models into the existing logic. Simply create a PyTorch model and define its forward pass. Instead of inheriting from `nn.Module`, inherit from Mambular's `BaseModel`. Each Mambular model takes three arguments: the number of classes (e.g. 1 for regression or 2 for binary classification), plus `cat_feature_info` and `num_feature_info`, two arguments passed directly from the preprocessor that hold information (e.g. the provided shapes) about the categorical and numerical features. For distributional regression, the number of classes must still be provided, but it is determined automatically depending on the chosen distribution. Additionally, you can provide a `config` argument, which you can either use similarly to the implemented models or leave empty, as shown below. A custom model could hence look like this:
```python
import torch
import torch.nn as nn

from mambular.base_models import BaseModel


class MyCustomModel(BaseModel):
    def __init__(
        self,
        cat_feature_info,
        num_feature_info,
        num_classes: int = 1,
        config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])

        # Numerical features contribute their input shape,
        # categorical features contribute one column each.
        input_dim = 0
        for feature_name, input_shape in num_feature_info.items():
            input_dim += input_shape
        for feature_name, input_shape in cat_feature_info.items():
            input_dim += 1

        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, num_features, cat_features):
        # num_features and cat_features are lists of tensors;
        # concatenate them along the feature dimension.
        x = num_features + cat_features
        x = torch.cat(x, dim=1)

        # Pass through linear layer
        output = self.linear(x)
        return output
```
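To make the `input_dim` computation concrete, here is a standalone sketch with hypothetical feature-info dictionaries (the real dictionaries are supplied by the preprocessor, so the names and shapes here are purely illustrative):

```python
# Hypothetical feature info, mimicking what the preprocessor passes in.
num_feature_info = {"age": 1, "income": 1}    # numerical feature -> input shape
cat_feature_info = {"city": 10, "gender": 2}  # categorical feature -> info (unused here)

input_dim = 0
for feature_name, input_shape in num_feature_info.items():
    input_dim += input_shape  # each numerical feature adds its shape
for feature_name, input_shape in cat_feature_info.items():
    input_dim += 1            # each categorical feature adds one column

print(input_dim)  # 4
```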
To leverage the Mambular API, you can build a regression, classification, or distributional regression model that can use all of Mambular's built-in methods, as follows:
```python
from mambular.models import SklearnBaseRegressor


class MyRegressor(SklearnBaseRegressor):
    def __init__(self, **kwargs):
        super().__init__(model=MyCustomModel, config=None, **kwargs)
```
Subsequently, you can fit, evaluate, and predict with your model just like with any other Mambular model. To achieve the same for classification or distributional regression, simply inherit from `SklearnBaseClassifier` or `SklearnBaseLSS` instead of `SklearnBaseRegressor`.
## Citation

If you find this project useful in your research, please consider citing:

mambular/base_models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@
 from .mlp import MLP
 from .tabtransformer import TabTransformer
 from .resnet import ResNet
+from .basemodel import BaseModel

 __all__ = [
     "TaskModel",
@@ -12,4 +13,5 @@
     "FTTransformer",
     "TabTransformer",
     "MLP",
+    "BaseModel",
 ]

mambular/models/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -11,6 +11,9 @@
     TabTransformerLSS,
 )
 from .resnet import ResNetClassifier, ResNetRegressor, ResNetLSS
+from .sklearn_base_classifier import SklearnBaseClassifier
+from .sklearn_base_lss import SklearnBaseLSS
+from .sklearn_base_regressor import SklearnBaseRegressor


 __all__ = [
@@ -29,4 +32,7 @@
     "ResNetClassifier",
     "ResNetRegressor",
     "ResNetLSS",
+    "SklearnBaseClassifier",
+    "SklearnBaseLSS",
+    "SklearnBaseRegressor",
 ]
