Skip to content

Commit 5a5a7cd

Browse files
committed
adapt build_method
1 parent d636b28 commit 5a5a7cd

1 file changed

Lines changed: 5 additions & 2 deletions

File tree

mambular/models/sklearn_base_regressor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def build_model(
181181
y_val=y_val,
182182
val_size=val_size,
183183
random_state=random_state,
184-
regression=False,
184+
regression=True,
185185
**dataloader_kwargs
186186
)
187187

@@ -308,7 +308,7 @@ def fit(
308308
self : object
309309
The fitted regressor.
310310
"""
311-
if not self.built and not rebuild:
311+
if rebuild:
312312
if not isinstance(X, pd.DataFrame):
313313
X = pd.DataFrame(X)
314314
if isinstance(y, pd.Series):
@@ -346,6 +346,9 @@ def fit(
346346
weight_decay=weight_decay,
347347
)
348348

349+
else:
350+
assert self.built, "The model must be built before calling the fit method."
351+
349352
early_stop_callback = EarlyStopping(
350353
monitor=monitor, min_delta=0.00, patience=patience, verbose=False, mode=mode
351354
)

0 commit comments

Comments
 (0)