Commit 53de1df

Merge pull request #57 from basf/develop
new version release 0.1.5
2 parents 22cac8e + 150fd17 commit 53de1df

234 files changed

Lines changed: 5676 additions & 36296 deletions

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+name: PyPI Builder and Releaser
+
+on:
+  push:
+    tags:
+      - "v*.*.*"
+
+jobs:
+  release:
+    name: Publishes release candidate to PyPI
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 0
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.8"
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install setuptools wheel twine
+
+      - name: Build and publish package
+        env:
+          TWINE_USERNAME: __token__
+          TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
+        run: |
+          python setup.py sdist bdist_wheel
+          twine upload dist/*

.github/workflows/publish.yml

Lines changed: 0 additions & 46 deletions
This file was deleted.

.gitignore

Lines changed: 12 additions & 1 deletion
@@ -161,4 +161,15 @@ cython_debug/
 .DS_Store
 
 dist/
-docs/_build/*
+
+# pkl files
+*.pkl
+
+# logs and checkpoints
+examples/lightning_logs
+*.ckpt
+
+docs/_build/doctrees/*
+docs/_build/html/*
+
+dev/*

README.md

Lines changed: 111 additions & 7 deletions
@@ -17,17 +17,34 @@
 
 # Mambular: Tabular Deep Learning with Mamba Architectures
 
-Mambular is a Python package that brings the power of Mamba architectures to tabular data, offering a suite of deep learning models for regression, classification, and distributional regression tasks. Designed with ease of use in mind, Mambular models adhere to scikit-learn's `BaseEstimator` interface, making them highly compatible with the familiar scikit-learn ecosystem. This means you can fit, predict, and evaluate using Mambular models just as you would with any traditional scikit-learn model, but with the added performance and flexibility of deep learning.
+Mambular is a Python package that brings the power of advanced deep learning architectures to tabular data, offering a suite of models for regression, classification, and distributional regression tasks. Designed with ease of use in mind, Mambular models adhere to scikit-learn's `BaseEstimator` interface, making them highly compatible with the familiar scikit-learn ecosystem. This means you can fit, predict, and evaluate using Mambular models just as you would with any traditional scikit-learn model, but with the added performance and flexibility of deep learning.
 
 ## Features
 
-- **Comprehensive Model Suite**: Includes modules for regression (`MambularRegressor`), classification (`MambularClassifier`), and distributional regression (`MambularLSS`), catering to a wide range of tabular data tasks.
-- **State-of-the-Art Architectures**: Leverages the Mamba architecture, known for its effectiveness in handling sequential and time-series data within a state-space modeling framework, adapted here for tabular data.
+- **Comprehensive Model Suite**: Includes modules for regression, classification, and distributional regression, catering to a wide range of tabular data tasks.
+- **State-of-the-Art Architectures**: Leverages various advanced architectures known for their effectiveness in handling tabular data. Mambular models include powerful Mamba blocks [Gu and Dao](https://arxiv.org/pdf/2312.00752) and can include bidirectional processing as well as feature interaction layers.
 - **Seamless Integration**: Designed to work effortlessly with scikit-learn, allowing for easy inclusion in existing machine learning pipelines, cross-validation, and hyperparameter tuning workflows.
 - **Extensive Preprocessing**: Comes with a powerful preprocessing module that supports a broad array of data transformation techniques, ensuring that your data is optimally prepared for model training.
 - **Sklearn-like API**: The familiar scikit-learn `fit`, `predict`, and `predict_proba` methods mean minimal learning curve for those already accustomed to scikit-learn.
 - **PyTorch Lightning Under the Hood**: Built on top of PyTorch Lightning, Mambular models benefit from streamlined training processes, easy customization, and advanced features like distributed training and 16-bit precision.
 
+
+
+## Models
+
+| Model | Description |
+|---------------------|--------------------------------------------------------------------------------------------------|
+| `Mambular` | An advanced model using Mamba blocks [Gu and Dao](https://arxiv.org/pdf/2312.00752) specifically designed for various tabular data tasks. |
+| `FTTransformer` | A model leveraging transformer encoders, as introduced by [Gorishniy et al.](https://arxiv.org/abs/2106.11959), for tabular data. |
+| `MLP` | A classical Multi-Layer Perceptron (MLP) model for handling tabular data tasks. |
+| `ResNet` | An adaptation of the ResNet architecture for tabular data applications. |
+| `TabTransformer` | A transformer-based model for tabular data introduced by [Huang et al.](https://arxiv.org/abs/2012.06678), enhancing feature learning capabilities. |
+
+All models are available for `regression`, `classification`, and distributional regression, denoted by `LSS`.
+Hence, they are available as e.g. `MambularRegressor`, `MambularClassifier`, or `MambularLSS`.
+
+
+
 ## Documentation
 
 You can find the Mamba-Tabular API documentation [here](https://mambular.readthedocs.io/en/latest/).
@@ -50,9 +67,13 @@ Mambular automatically identifies the type of each feature in your dataset and a
 - **One-Hot Encoding**: For nominal data, Mambular employs one-hot encoding to capture the presence or absence of categories without imposing ordinality.
 - **Binning**: Numerical features can be discretized into bins, a useful technique for handling continuous variables in certain modeling contexts.
 - **Decision Tree Binning**: Optionally, Mambular can use decision trees to find the optimal binning strategy for numerical features, enhancing model interpretability and performance.
-- **Normalization**: Mambular can easily handle numerical features without specifically turning them into categorical features. Standard preprocessing steps such as normalization per feature are possible
-- **Standardization**: Similarly, Standardization instead of Normalization can be used.
-- **PLE**: Periodic Linear Encodings for numerical features can enhance performance for tabular DL methods.
+- **Normalization**: Mambular can easily handle numerical features without specifically turning them into categorical features. Standard preprocessing steps such as normalization per feature are possible.
+- **Standardization**: Similarly, standardization instead of normalization can be used to scale features based on the mean and standard deviation.
+- **PLE (Periodic Linear Encoding)**: This technique can be applied to numerical features to enhance the performance of tabular deep learning methods by encoding periodicity.
+- **Quantile Transformation**: Numerical features can be transformed to follow a uniform or normal distribution, improving model robustness to outliers.
+- **Spline Transformation**: Applies piecewise polynomial functions to numerical features, capturing nonlinear relationships more effectively.
+- **Polynomial Features**: Generates polynomial and interaction features, increasing the feature space to capture more complex relationships within the data.
+
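As a side note on the PLE bullet above, the idea behind periodic encodings can be sketched in a few lines of standalone NumPy. This is an illustrative sketch, not Mambular's actual preprocessing code: each scalar feature is mapped to sine/cosine features at several frequencies, so a downstream linear layer can pick up nonlinear, periodic structure. The frequencies here are fixed for illustration; real implementations typically learn or tune them.

```python
import numpy as np

def periodic_encoding(x, frequencies):
    """Map a 1-D numeric feature to [sin(2*pi*f*x), cos(2*pi*f*x)] per frequency."""
    x = np.asarray(x, dtype=float).reshape(-1, 1)             # shape (n, 1)
    f = np.asarray(frequencies, dtype=float).reshape(1, -1)   # shape (1, k)
    angles = 2 * np.pi * f * x                                # shape (n, k)
    # Concatenate sine and cosine features: shape (n, 2k)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)

enc = periodic_encoding([0.0, 0.25, 0.5], frequencies=[1.0, 2.0])
print(enc.shape)  # (3, 4)
```

Each input value thus becomes a small vector of smooth, bounded features instead of a single raw number.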
 
 
 ### Handling Missing Values
@@ -90,7 +111,7 @@ preds = model.predict_proba(X)
 
 ## Distributional Regression with MambularLSS
 
-Mambular introduces a cutting-edge approach to distributional regression through its `MambularLSS` module, empowering users to model the full distribution of a response variable, not just its mean. This method is particularly valuable in scenarios where understanding the variability, skewness, or kurtosis of the response distribution is as crucial as predicting its central tendency.
+Mambular introduces an approach to distributional regression through its `MambularLSS` module, allowing users to model the full distribution of a response variable, not just its mean. This method is particularly valuable in scenarios where understanding the variability, skewness, or kurtosis of the response distribution is as crucial as predicting its central tendency. All available models in Mambular are also available as distributional models.
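For intuition, the objective behind this kind of distributional model can be sketched without any framework: the network outputs distribution parameters (here a Gaussian mean and standard deviation) and is trained on the negative log-likelihood rather than squared error, so it is rewarded for well-calibrated spread as well as accurate means. This is an illustration of the principle, not Mambular's internal loss implementation.

```python
import math

def gaussian_nll(y, mu, sigma):
    """Negative log-likelihood of observation y under N(mu, sigma^2)."""
    return 0.5 * math.log(2 * math.pi * sigma**2) + (y - mu) ** 2 / (2 * sigma**2)

# Same predicted mean, different predicted spread: for an observation far
# from the mean, the prediction with larger sigma is penalized less.
far = 3.0
print(gaussian_nll(far, mu=0.0, sigma=0.5) > gaussian_nll(far, mu=0.0, sigma=2.0))  # True
```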

 ### Key Features of MambularLSS:
 
@@ -100,6 +121,7 @@ Mambular introduces a cutting-edge approach to distributional regression through
 - **Enhanced Predictive Uncertainty**: By modeling the full distribution, `MambularLSS` provides richer information on predictive uncertainty, enabling more robust decision-making processes in uncertain environments.
 
 
+
 ### Available Distribution Classes:
 
 `MambularLSS` offers a wide range of distribution classes to cater to various statistical modeling needs. The available distribution classes include:
@@ -144,6 +166,86 @@ model.fit(
 
 ```
 
+
+### Implement Your Own Model
+
+Mambular allows users to easily integrate their custom models into the existing logic. This process is designed to be straightforward, making it simple to create a PyTorch model and define its forward pass. Instead of inheriting from `nn.Module`, you inherit from Mambular's `BaseModel`. Each Mambular model takes three main arguments: the number of classes (e.g., 1 for regression or 2 for binary classification), `cat_feature_info`, and `num_feature_info` for categorical and numerical feature information, respectively. Additionally, you can provide a config argument, which can either be a custom configuration or one of the provided default configs.
+
+One of the key advantages of using Mambular is that the inputs to the forward passes are lists of tensors. While this might be unconventional, it is highly beneficial for models that treat different data types differently. For example, the TabTransformer model leverages this feature to handle categorical and numerical data separately, applying different transformations and processing steps to each type of data.
+
+Here's how you can implement a custom model with Mambular:
+
+
+1. First, define your config:
+The configuration class allows you to specify hyperparameters and other settings for your model. This can be done using a simple dataclass.
+
+```python
+from dataclasses import dataclass
+
+@dataclass
+class MyConfig:
+    lr: float = 1e-04
+    lr_patience: int = 10
+    weight_decay: float = 1e-06
+    lr_factor: float = 0.1
+```
+
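Since the config is a plain Python dataclass, its defaults and overrides can be sanity-checked on their own before wiring it into a model. A minimal check (the class is repeated here so the snippet is self-contained):

```python
from dataclasses import dataclass

@dataclass
class MyConfig:
    lr: float = 1e-04
    lr_patience: int = 10
    weight_decay: float = 1e-06
    lr_factor: float = 0.1

# Override a single hyperparameter; the rest keep their defaults.
config = MyConfig(lr=3e-04)
print(config.lr, config.lr_patience)  # 0.0003 10
```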
+2. Second, define your model:
+Define your custom model just as you would for an `nn.Module`. The main difference is that you will inherit from `BaseModel` and use the provided feature information to construct your layers. To integrate your model into the existing API, you only need to define the architecture and the forward pass.
+
+```python
+from mambular.base_models import BaseModel
+import torch
+import torch.nn as nn
+
+class MyCustomModel(BaseModel):
+    def __init__(
+        self,
+        cat_feature_info,
+        num_feature_info,
+        num_classes: int = 1,
+        config=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
+
+        input_dim = 0
+        for feature_name, input_shape in num_feature_info.items():
+            input_dim += input_shape
+        for feature_name, input_shape in cat_feature_info.items():
+            input_dim += 1
+
+        self.linear = nn.Linear(input_dim, num_classes)
+
+    def forward(self, num_features, cat_features):
+        x = num_features + cat_features
+        x = torch.cat(x, dim=1)
+
+        # Pass through linear layer
+        output = self.linear(x)
+        return output
+```
+
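The input-dimension bookkeeping in the `__init__` above can be checked in isolation. The feature-info dicts below are hypothetical stand-ins mapping feature names to their encoded widths; the helper mirrors the loops in the model (full encoded width per numerical feature, one column per categorical feature in this sketch):

```python
def compute_input_dim(num_feature_info, cat_feature_info):
    """Mirror of the dimension bookkeeping in MyCustomModel.__init__."""
    input_dim = 0
    for feature_name, input_shape in num_feature_info.items():
        input_dim += input_shape   # numerical: full encoded width
    for feature_name, input_shape in cat_feature_info.items():
        input_dim += 1             # categorical: one column each
    return input_dim

# Hypothetical example: two numerical features (widths 3 and 1),
# two categorical features (cardinalities are ignored by this rule).
num_info = {"age": 3, "income": 1}
cat_info = {"city": 5, "gender": 2}
print(compute_input_dim(num_info, cat_info))  # 6
```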
+3. Leverage the Mambular API:
+You can build a regression, classification, or distributional regression model that can leverage all of Mambular's built-in methods by using the following:
+
+```python
+from mambular.models import SklearnBaseRegressor
+
+class MyRegressor(SklearnBaseRegressor):
+    def __init__(self, **kwargs):
+        super().__init__(model=MyCustomModel, config=MyConfig, **kwargs)
+```
+
+4. Train and evaluate your model:
+You can now fit, evaluate, and predict with your custom model just like with any other Mambular model. For classification or distributional regression, inherit from `SklearnBaseClassifier` or `SklearnBaseLSS` respectively.
+
+```python
+regressor = MyRegressor(numerical_preprocessing="ple")
+regressor.fit(X_train, y_train, max_epochs=50)
+```
+
 ## Citation
 
 If you find this project useful in your research, please consider citing:
@@ -157,3 +259,5 @@ If you find this project useful in your research, please consider citing:
 ```
 
 ## License
+
+The entire codebase is under the MIT license.
-284 KB
Binary file not shown.
-7.76 KB
Binary file not shown.
-696 KB
Binary file not shown.
-7.89 KB
Binary file not shown.
-46.3 KB
Binary file not shown.
-4.6 KB
Binary file not shown.
