Skip to content

Commit d04e7e4

Browse files
authored
Merge pull request #39 from basf/restructure
adapt example custom model
2 parents 30fc1fb + f6e2341 commit d04e7e4

1 file changed

Lines changed: 26 additions & 4 deletions

File tree

README.md

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,28 @@ model.fit(
162162

163163

164164
### Implement your own model:
165-
mambular allows users to easily integrate their custom models into the existing logic. Simply create a pytorch model and define its forward pass. Instead of inheriting from nn.Module, inherit from mambulars BaseModel. Each mambular model takse three arguments. The number of classes, e.g. = 1 for regression or = 2 for binary classification. For distributional regression, while this argument must be provided, it is determined automatically depending on the chosen distribution. Additionally, it takes two arguments directly passed from preprocessor. The cat_feature_info and num_feature_info for categorical and numerical feature information of e.g. the provided shape. Additionally, you can provide a config argument, which you can either use simialr to the implemented models, or leave empty as shown below. A custom model could hence look just like this:
165+
mambular allows users to easily integrate their custom models into the existing logic. Simply create a pytorch model and define its forward pass. Instead of inheriting from nn.Module, inherit from mambulars BaseModel. Each mambular model takse three arguments. The number of classes, e.g. = 1 for regression or = 2 for binary classification. For distributional regression, while this argument must be provided, it is determined automatically depending on the chosen distribution. Additionally, it takes two arguments directly passed from preprocessor. The cat_feature_info and num_feature_info for categorical and numerical feature information of e.g. the provided shape. Additionally, you can provide a config argument, which you can either implement similarly to the implemented configs, or simply use one of the Default Configs provided. A custom model could hence look just like this:
166166

167167

168+
1. First, define your config
169+
170+
```python
171+
from dataclasses import dataclass
172+
173+
@dataclass
174+
class MyConfig:
175+
lr: float = 1e-04
176+
lr_patience: int = 10
177+
weight_decay: float = 1e-06
178+
lr_factor: float = 0.1
179+
```
180+
181+
2. Second, define your model just as you would for a nn.Module. Simply define the architecture and the forward pass
182+
168183
```python
169184
from mambular.base_models import BaseModel
185+
import torch
186+
import torch.nn
170187

171188
class MyCustomModel(BaseModel):
172189
def __init__(
@@ -197,19 +214,24 @@ class MyCustomModel(BaseModel):
197214
return output
198215
```
199216

200-
To leverage the mambular API, you can build a regression, classification or distributional regression model that can leverage all of mambulars built-in methods, by using the following:
217+
3. To leverage the mambular API, you can build a regression, classification or distributional regression model that can leverage all of mambulars built-in methods, by using the following:
201218

202219
```python
203220
from mambular.models import SklearnBaseRegressor
204221

205222
class MyRegressor(SklearnBaseRegressor):
206223
def __init__(self, **kwargs):
207-
super().__init__(model=MyCustomModel, config=None, **kwargs)
224+
super().__init__(model=MyCustomModel, config=MyConfig, **kwargs)
208225
```
209226

210-
Subsequently, you can fit, evaluate and predict with your model just like with any other mambualr model.
227+
4. Subsequently, you can fit, evaluate and predict with your model just like with any other mambualr model.
211228
To achieve the same for classification or disrtibutional regression, instead of inheriting from the SklearnbaseRegressor, simply inherit from the SklearnBaseClassifier and SklearnBaseLSS.
212229

230+
```python
231+
regressor = MyRegressor(numerical_preprocessing="ple")
232+
regressor.fit(X_train, y_train, max_epochs=50)
233+
```
234+
213235
## Citation
214236

215237
If you find this project useful in your research, please consider cite:

0 commit comments

Comments
 (0)