|
14 | 14 |
|
15 | 15 | from ...base_models.utils.lightning_wrapper import TaskModel |
16 | 16 | from ...data_utils.datamodule import MambularDataModule |
17 | | -from ...preprocessing import Preprocessor |
| 17 | +from pretab.preprocessor import Preprocessor |
18 | 18 |
|
19 | 19 | from ...utils.distributional_metrics import ( |
20 | 20 | beta_brier_score, |
@@ -245,8 +245,11 @@ def build_model( |
245 | 245 | num_classes=self.family.param_count, |
246 | 246 | family=self.family, |
247 | 247 | config=self.config, |
248 | | - cat_feature_info=self.data_module.cat_feature_info, |
249 | | - num_feature_info=self.data_module.num_feature_info, |
| 248 | + feature_information=( |
| 249 | + self.data_module.num_feature_info, |
| 250 | + self.data_module.cat_feature_info, |
| 251 | + self.data_module.embedding_feature_info, |
| 252 | + ), |
250 | 253 | lr=lr if lr is not None else self.config.lr, |
251 | 254 | lr_patience=( |
252 | 255 | lr_patience if lr_patience is not None else self.config.lr_patience |
@@ -454,11 +457,13 @@ def fit( |
454 | 457 | ) |
455 | 458 | self.trainer.fit(self.task_model, self.data_module) # type: ignore |
456 | 459 |
|
457 | | - best_model_path = checkpoint_callback.best_model_path |
458 | | - if best_model_path: |
459 | | - checkpoint = torch.load(best_model_path) |
| 460 | + self.best_model_path = checkpoint_callback.best_model_path |
| 461 | + if self.best_model_path: |
| 462 | + torch.serialization.add_safe_globals([type(self.config)]) |
| 463 | + checkpoint = torch.load(self.best_model_path, weights_only=False) |
460 | 464 | self.task_model.load_state_dict(checkpoint["state_dict"]) # type: ignore |
461 | 465 |
|
| 466 | + self.is_fitted_ = True |
462 | 467 | return self |
463 | 468 |
|
464 | 469 | def predict(self, X, raw=False, device=None): |
|
0 commit comments