 import warnings
+from collections.abc import Callable
 from typing import Optional

 import lightning as pl
 from sklearn.base import BaseEstimator
 from sklearn.metrics import accuracy_score, log_loss, mean_squared_error
 from skopt import gp_minimize
-from collections.abc import Callable
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
 from ..base_models.lightning_wrapper import TaskModel
 from ..data_utils.datamodule import MambularDataModule
 from ..preprocessing import Preprocessor
-from ..utils.config_mapper import (
-    activation_mapper,
-    get_search_space,
-    round_to_nearest_16,
-)
-from tqdm import tqdm
-from torch.utils.data import DataLoader
+from ..utils.config_mapper import activation_mapper, get_search_space, round_to_nearest_16


 class SklearnBaseClassifier(BaseEstimator):
@@ -42,15 +39,11 @@ def __init__(self, model, config, **kwargs):
         ]

         self.config_kwargs = {
-            k: v
-            for k, v in kwargs.items()
-            if k not in self.preprocessor_arg_names and not k.startswith("optimizer")
+            k: v for k, v in kwargs.items() if k not in self.preprocessor_arg_names and not k.startswith("optimizer")
         }
         self.config = config(**self.config_kwargs)

-        preprocessor_kwargs = {
-            k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names
-        }
+        preprocessor_kwargs = {k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names}

         self.preprocessor = Preprocessor(**preprocessor_kwargs)
         self.task_model = None
@@ -70,8 +63,7 @@ def __init__(self, model, config, **kwargs):
         self.optimizer_kwargs = {
             k: v
             for k, v in kwargs.items()
-            if k
-            not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"]
+            if k not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"]
             and k.startswith("optimizer_")
         }

@@ -92,10 +84,7 @@ def get_params(self, deep=True):
         params.update(self.config_kwargs)

         if deep:
-            preprocessor_params = {
-                "prepro__" + key: value
-                for key, value in self.preprocessor.get_params().items()
-            }
+            preprocessor_params = {"prepro__" + key: value for key, value in self.preprocessor.get_params().items()}
             params.update(preprocessor_params)

         return params
@@ -113,14 +102,8 @@ def set_params(self, **parameters):
         self : object
             Estimator instance.
         """
-        config_params = {
-            k: v for k, v in parameters.items() if not k.startswith("prepro__")
-        }
-        preprocessor_params = {
-            k.split("__")[1]: v
-            for k, v in parameters.items()
-            if k.startswith("prepro__")
-        }
+        config_params = {k: v for k, v in parameters.items() if not k.startswith("prepro__")}
+        preprocessor_params = {k.split("__")[1]: v for k, v in parameters.items() if k.startswith("prepro__")}

         if config_params:
             self.config_kwargs.update(config_params)
@@ -218,9 +201,7 @@ def build_model(
             **dataloader_kwargs,
         )

-        self.data_module.preprocess_data(
-            X, y, X_val, y_val, val_size=val_size, random_state=random_state
-        )
+        self.data_module.preprocess_data(X, y, X_val, y_val, val_size=val_size, random_state=random_state)

         num_classes = len(np.unique(np.array(y)))

@@ -230,14 +211,10 @@ def build_model(
             config=self.config,
             cat_feature_info=self.data_module.cat_feature_info,
             num_feature_info=self.data_module.num_feature_info,
-            lr_patience=(
-                lr_patience if lr_patience is not None else self.config.lr_patience
-            ),
+            lr_patience=(lr_patience if lr_patience is not None else self.config.lr_patience),
             lr=lr if lr is not None else self.config.lr,
             lr_factor=lr_factor if lr_factor is not None else self.config.lr_factor,
-            weight_decay=(
-                weight_decay if weight_decay is not None else self.config.weight_decay
-            ),
+            weight_decay=(weight_decay if weight_decay is not None else self.config.weight_decay),
             train_metrics=train_metrics,
             val_metrics=val_metrics,
             optimizer_type=self.optimizer_type,
@@ -268,9 +245,7 @@ def get_number_of_params(self, requires_grad=True):
             If the model has not been built prior to calling this method.
         """
         if not self.built:
-            raise ValueError(
-                "The model must be built before the number of parameters can be estimated"
-            )
+            raise ValueError("The model must be built before the number of parameters can be estimated")
         else:
             if requires_grad:
                 return sum(p.numel() for p in self.task_model.parameters() if p.requires_grad)  # type: ignore
@@ -442,7 +417,7 @@ def predict(self, X, device=None):
         logits_list = self.trainer.predict(self.task_model, self.data_module)

         # Concatenate predictions from all batches
-        logits = torch.cat(logits_list, dim=0)
+        logits = torch.cat(logits_list, dim=0)  # type: ignore

         # Check if ensemble is used
         if getattr(self.base_model, "returns_ensemble", False):  # If using ensemble
@@ -619,9 +594,7 @@ def encode(self, X, batch_size=64):
         # Process data in batches
         encoded_outputs = []
         for num_features, cat_features in tqdm(data_loader):
-            embeddings = self.task_model.base_model.encode(
-                num_features, cat_features
-            )  # Call your encode function
+            embeddings = self.task_model.base_model.encode(num_features, cat_features)  # Call your encode function
             encoded_outputs.append(embeddings)

         # Concatenate all encoded outputs
@@ -689,13 +662,9 @@ def optimize_hparams(
         best_val_loss = float("inf")

         if X_val is not None and y_val is not None:
-            val_loss = self.evaluate(
-                X_val, y_val, metrics={"Accuracy": (accuracy_score, False)}
-            )["Accuracy"]
+            val_loss = self.evaluate(X_val, y_val, metrics={"Accuracy": (accuracy_score, False)})["Accuracy"]
         else:
-            val_loss = self.trainer.validate(self.task_model, self.data_module)[0][
-                "val_loss"
-            ]
+            val_loss = self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"]

         best_val_loss = val_loss
         best_epoch_val_loss = self.task_model.epoch_val_loss_at(  # type: ignore
@@ -721,9 +690,7 @@ def _objective(hyperparams):
                     if param_value in activation_mapper:
                         setattr(self.config, key, activation_mapper[param_value])
                     else:
-                        raise ValueError(
-                            f"Unknown activation function: {param_value}"
-                        )
+                        raise ValueError(f"Unknown activation function: {param_value}")
                 else:
                     setattr(self.config, key, param_value)

@@ -732,15 +699,11 @@ def _objective(hyperparams):
                 self.config.head_layer_sizes = head_layer_sizes[:head_layer_size_length]

             # Build the model with updated hyperparameters
-            self.build_model(
-                X, y, X_val=X_val, y_val=y_val, lr=self.config.lr, **optimize_kwargs
-            )
+            self.build_model(X, y, X_val=X_val, y_val=y_val, lr=self.config.lr, **optimize_kwargs)

             # Dynamically set the early pruning threshold
             if prune_by_epoch:
-                early_pruning_threshold = (
-                    best_epoch_val_loss * 1.5
-                )  # Prune based on specific epoch loss
+                early_pruning_threshold = best_epoch_val_loss * 1.5  # Prune based on specific epoch loss
             else:
                 # Prune based on the best overall validation loss
                 early_pruning_threshold = best_val_loss * 1.5
@@ -752,19 +715,15 @@ def _objective(hyperparams):
             # Fit the model (limit epochs for faster optimization)
             try:
                 # Wrap the risky operation (model fitting) in a try-except block
-                self.fit(
-                    X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False
-                )
+                self.fit(X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False)

                 # Evaluate validation loss
                 if X_val is not None and y_val is not None:
                     val_loss = self.evaluate(X_val, y_val, metrics={"Mean Squared Error": mean_squared_error})[  # type: ignore
                         "Mean Squared Error"
                     ]
                 else:
-                    val_loss = self.trainer.validate(self.task_model, self.data_module)[
-                        0
-                    ]["val_loss"]
+                    val_loss = self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"]

                 # Pruning based on validation loss at specific epoch
                 epoch_val_loss = self.task_model.epoch_val_loss_at(  # type: ignore
@@ -781,21 +740,15 @@ def _objective(hyperparams):

             except Exception as e:
                 # Penalize the hyperparameter configuration with a large value
-                print(
-                    f"Error encountered during fit with hyperparameters {hyperparams}: {e}"
-                )
-                return (
-                    best_val_loss * 100
-                )  # Large value to discourage this configuration
+                print(f"Error encountered during fit with hyperparameters {hyperparams}: {e}")
+                return best_val_loss * 100  # Large value to discourage this configuration

         # Perform Bayesian optimization using scikit-optimize
         result = gp_minimize(_objective, param_space, n_calls=time, random_state=42)

         # Update the model with the best-found hyperparameters
         best_hparams = result.x  # type: ignore
-        head_layer_sizes = (
-            [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None
-        )
+        head_layer_sizes = [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None
         layer_sizes = [] if "layer_sizes" in self.config.__dataclass_fields__ else None

         # Iterate over the best hyperparameters found by optimization