Skip to content

Commit cc03c34

Browse files
authored
Merge pull request #260 from basf/lss_fix
Lss fix
2 parents 75c04f3 + a047fa0 commit cc03c34

19 files changed

Lines changed: 227 additions & 1865 deletions

.gitignore

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
# C extensions
7+
*.so
8+
9+
# Distribution / packaging
10+
.Python
11+
build/
12+
develop-eggs/
13+
dist/
14+
downloads/
15+
eggs/
16+
.eggs/
17+
lib/
18+
lib64/
19+
parts/
20+
sdist/
21+
var/
22+
wheels/
23+
share/python-wheels/
24+
*.egg-info/
25+
.installed.cfg
26+
*.egg
27+
MANIFEST
28+
29+
# PyInstaller
30+
# Usually these files are written by a python script from a template
31+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
32+
*.manifest
33+
*.spec
34+
35+
# Installer logs
36+
pip-log.txt
37+
pip-delete-this-directory.txt
38+
39+
# Unit test / coverage reports
40+
htmlcov/
41+
.tox/
42+
.nox/
43+
.coverage
44+
.coverage.*
45+
.cache
46+
nosetests.xml
47+
coverage.xml
48+
*.cover
49+
*.py,cover
50+
.hypothesis/
51+
.pytest_cache/
52+
cover/
53+
54+
# Translations
55+
*.mo
56+
*.pot
57+
58+
# Django stuff:
59+
*.log
60+
local_settings.py
61+
db.sqlite3
62+
db.sqlite3-journal
63+
64+
# Flask stuff:
65+
instance/
66+
.webassets-cache
67+
68+
# Scrapy stuff:
69+
.scrapy
70+
71+
# Sphinx documentation
72+
docs/_build/
73+
74+
# PyBuilder
75+
.pybuilder/
76+
target/
77+
78+
# Jupyter Notebook
79+
.ipynb_checkpoints
80+
81+
# IPython
82+
profile_default/
83+
ipython_config.py
84+
85+
# pyenv
86+
# For a library or package, you might want to ignore these files since the code is
87+
# intended to run in multiple environments; otherwise, check them in:
88+
# .python-version
89+
90+
# pipenv
91+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
93+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
94+
# install all needed dependencies.
95+
#Pipfile.lock
96+
97+
# poetry
98+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99+
# This is especially recommended for binary packages to ensure reproducibility, and is more
100+
# commonly ignored for libraries.
101+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102+
#poetry.lock
103+
104+
# pdm
105+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106+
#pdm.lock
107+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108+
# in version control.
109+
# https://pdm.fming.dev/#use-with-ide
110+
.pdm.toml
111+
112+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113+
__pypackages__/
114+
115+
# Celery stuff
116+
celerybeat-schedule
117+
celerybeat.pid
118+
119+
# SageMath parsed files
120+
*.sage.py
121+
122+
# Environments
123+
.env
124+
.venv
125+
env/
126+
venv/
127+
ENV/
128+
env.bak/
129+
venv.bak/
130+
131+
# Spyder project settings
132+
.spyderproject
133+
.spyproject
134+
135+
# Rope project settings
136+
.ropeproject
137+
138+
# mkdocs documentation
139+
/site
140+
141+
# mypy
142+
.mypy_cache/
143+
.dmypy.json
144+
dmypy.json
145+
146+
# Pyre type checker
147+
.pyre/
148+
149+
# pytype static type analyzer
150+
.pytype/
151+
152+
# Cython debug symbols
153+
cython_debug/
154+
155+
# PyCharm
156+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158+
# and can be added to the global gitignore or merged into this file. For a more nuclear
159+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
160+
#.idea/
161+
.DS_Store
162+
163+
dist/
164+
165+
# pkl files
166+
*.pkl
167+
168+
# logs and checkpoints
169+
examples/lightning_logs
170+
*.ckpt
171+
172+
docs/_build/doctrees/*
173+
docs/_build/html/*
174+
175+
176+
dev/*

README.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,11 @@ pip install mamba-ssm
120120

121121
<h2> Preprocessing </h2>
122122

123-
Mambular simplifies data preprocessing with a range of tools designed for easy transformation of tabular data.
124-
Specify a default method, or a dictionary defining individual preprocessing methods for each feature.
123+
Mambular uses pretab preprocessing: https://github.com/OpenTabular/PreTab
124+
125+
Hence, data types are detected automatically, and all preprocessing methods from pretab as well as from `sklearn.preprocessing` are available.
126+
Additionally, you can preprocess each feature differently, according to your requirements, by setting the `feature_preprocessing={}` argument during model initialization.
127+
For an overview of all available methods, see [pretab](https://github.com/OpenTabular/PreTab).
125128

126129
<h3> Data Type Detection and Transformation </h3>
127130

docs/api/preprocessing/Preprocessor.rst

Lines changed: 0 additions & 5 deletions
This file was deleted.

docs/api/preprocessing/index.rst

Lines changed: 0 additions & 20 deletions
This file was deleted.

mamba_tabular_summary.pdf

-79.3 KB
Binary file not shown.

mambular/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,5 @@
1717

1818
# The following line *must* be the last in the module, exactly as formatted:
1919

20-
__version__ = "1.4.0"
20+
__version__ = "1.5.0"
2121

mambular/models/utils/sklearn_base_lss.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from ...base_models.utils.lightning_wrapper import TaskModel
1616
from ...data_utils.datamodule import MambularDataModule
17-
from ...preprocessing import Preprocessor
17+
from pretab.preprocessor import Preprocessor
1818

1919
from ...utils.distributional_metrics import (
2020
beta_brier_score,
@@ -245,8 +245,11 @@ def build_model(
245245
num_classes=self.family.param_count,
246246
family=self.family,
247247
config=self.config,
248-
cat_feature_info=self.data_module.cat_feature_info,
249-
num_feature_info=self.data_module.num_feature_info,
248+
feature_information=(
249+
self.data_module.num_feature_info,
250+
self.data_module.cat_feature_info,
251+
self.data_module.embedding_feature_info,
252+
),
250253
lr=lr if lr is not None else self.config.lr,
251254
lr_patience=(
252255
lr_patience if lr_patience is not None else self.config.lr_patience
@@ -454,11 +457,13 @@ def fit(
454457
)
455458
self.trainer.fit(self.task_model, self.data_module) # type: ignore
456459

457-
best_model_path = checkpoint_callback.best_model_path
458-
if best_model_path:
459-
checkpoint = torch.load(best_model_path)
460+
self.best_model_path = checkpoint_callback.best_model_path
461+
if self.best_model_path:
462+
torch.serialization.add_safe_globals([type(self.config)])
463+
checkpoint = torch.load(self.best_model_path, weights_only=False)
460464
self.task_model.load_state_dict(checkpoint["state_dict"]) # type: ignore
461465

466+
self.is_fitted_ = True
462467
return self
463468

464469
def predict(self, X, raw=False, device=None):

mambular/models/utils/sklearn_parent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from ...base_models.utils.lightning_wrapper import TaskModel
1414
from ...data_utils.datamodule import MambularDataModule
15-
from ...preprocessing import Preprocessor
15+
from pretab.preprocessor import Preprocessor
1616
from ...utils.config_mapper import (
1717
activation_mapper,
1818
get_search_space,

mambular/preprocessing/__init__.py

Lines changed: 0 additions & 6 deletions
This file was deleted.

0 commit comments

Comments
 (0)