99from lightning .pytorch .callbacks import EarlyStopping , ModelCheckpoint , ModelSummary
1010from sklearn .base import BaseEstimator
1111from sklearn .metrics import accuracy_score , mean_squared_error
12- from skopt import gp_minimize
1312from torch .utils .data import DataLoader
1413from tqdm import tqdm
1514
1615from ...base_models .utils .lightning_wrapper import TaskModel
1716from ...data_utils .datamodule import MambularDataModule
1817from ...preprocessing import Preprocessor
19- from ...utils .config_mapper import (
20- activation_mapper ,
21- get_search_space ,
22- round_to_nearest_16 ,
23- )
18+
2419from ...utils .distributional_metrics import (
2520 beta_brier_score ,
2621 dirichlet_error ,
@@ -78,7 +73,7 @@ def __init__(self, model, config, **kwargs):
7873
7974 self .preprocessor = Preprocessor (** preprocessor_kwargs )
8075 self .task_model = None
81- self .base_model = model
76+ self .estimator = model
8277 self .built = False
8378
8479 # Raise a warning if task is set to 'classification'
@@ -246,7 +241,7 @@ def build_model(
246241 )
247242
248243 self .task_model = TaskModel (
249- model_class = self .base_model , # type: ignore
244+ model_class = self .estimator , # type: ignore
250245 num_classes = self .family .param_count ,
251246 family = self .family ,
252247 config = self .config ,
@@ -268,7 +263,7 @@ def build_model(
268263 )
269264
270265 self .built = True
271- self .base_model = self .task_model .base_model
266+ self .estimator = self .task_model .estimator
272267
273268 return self
274269
@@ -497,7 +492,7 @@ def predict(self, X, raw=False, device=None):
497492 predictions = torch .cat (predictions_list , dim = 0 )
498493
499494 # Check if ensemble is used
500- if getattr (self .base_model , "returns_ensemble" , False ): # If using ensemble
495+ if getattr (self .estimator , "returns_ensemble" , False ): # If using ensemble
501496 predictions = predictions .mean (dim = 1 ) # Average over ensemble dimension
502497
503498 if not raw :
@@ -642,7 +637,7 @@ def encode(self, X, batch_size=64):
642637 # Process data in batches
643638 encoded_outputs = []
644639 for num_features , cat_features in tqdm (data_loader ):
645- embeddings = self .task_model .base_model .encode (
640+ embeddings = self .task_model .estimator .encode (
646641 num_features , cat_features
647642 ) # Call your encode function
648643 encoded_outputs .append (embeddings )
0 commit comments