Skip to content

Commit 85f465c

Browse files
committed
fix lss bug
1 parent 75c04f3 commit 85f465c

1 file changed

Lines changed: 11 additions & 6 deletions

File tree

mambular/models/utils/sklearn_base_lss.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from ...base_models.utils.lightning_wrapper import TaskModel
1616
from ...data_utils.datamodule import MambularDataModule
17-
from ...preprocessing import Preprocessor
17+
from pretab.preprocessor import Preprocessor
1818

1919
from ...utils.distributional_metrics import (
2020
beta_brier_score,
@@ -245,8 +245,11 @@ def build_model(
245245
num_classes=self.family.param_count,
246246
family=self.family,
247247
config=self.config,
248-
cat_feature_info=self.data_module.cat_feature_info,
249-
num_feature_info=self.data_module.num_feature_info,
248+
feature_information=(
249+
self.data_module.num_feature_info,
250+
self.data_module.cat_feature_info,
251+
self.data_module.embedding_feature_info,
252+
),
250253
lr=lr if lr is not None else self.config.lr,
251254
lr_patience=(
252255
lr_patience if lr_patience is not None else self.config.lr_patience
@@ -454,11 +457,13 @@ def fit(
454457
)
455458
self.trainer.fit(self.task_model, self.data_module) # type: ignore
456459

457-
best_model_path = checkpoint_callback.best_model_path
458-
if best_model_path:
459-
checkpoint = torch.load(best_model_path)
460+
self.best_model_path = checkpoint_callback.best_model_path
461+
if self.best_model_path:
462+
torch.serialization.add_safe_globals([type(self.config)])
463+
checkpoint = torch.load(self.best_model_path, weights_only=False)
460464
self.task_model.load_state_dict(checkpoint["state_dict"]) # type: ignore
461465

466+
self.is_fitted_ = True
462467
return self
463468

464469
def predict(self, X, raw=False, device=None):

0 commit comments

Comments
 (0)