Skip to content

Commit af6641b

Browse files
authored
Merge pull request #110 from basf/hotfix
Hotfix
2 parents a327ef8 + 4385952 commit af6641b

2 files changed

Lines changed: 4 additions & 1 deletion

File tree

mambular/models/sklearn_base_classifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ def fit(
316316
self : object
317317
The fitted classifier.
318318
"""
319-
if not self.built and not rebuild:
319+
if rebuild:
320320
if not isinstance(X, pd.DataFrame):
321321
X = pd.DataFrame(X)
322322
if isinstance(y, pd.Series):

mambular/models/tabularnn.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ class TabulaRNNRegressor(SklearnBaseRegressor):
8686
The number of knots to be used in spline transformations.
8787
"""
8888

89+
def __init__(self, **kwargs):
90+
super().__init__(model=TabulaRNN, config=DefaultTabulaRNNConfig, **kwargs)
91+
8992

9093
class TabulaRNNClassifier(SklearnBaseClassifier):
9194
"""

0 commit comments

Comments
 (0)