 from ..base_models.lightning_wrapper import TaskModel
 from ..data_utils.datamodule import MambularDataModule
 from ..preprocessing import Preprocessor
-from ..utils.config_mapper import (
-    activation_mapper,
-    get_search_space,
-    round_to_nearest_16,
-)
+from ..utils.config_mapper import activation_mapper, get_search_space, round_to_nearest_16


 class SklearnBaseRegressor(BaseEstimator):
@@ -42,15 +38,11 @@ def __init__(self, model, config, **kwargs):
         ]

         self.config_kwargs = {
-            k: v
-            for k, v in kwargs.items()
-            if k not in self.preprocessor_arg_names and not k.startswith("optimizer")
+            k: v for k, v in kwargs.items() if k not in self.preprocessor_arg_names and not k.startswith("optimizer")
         }
         self.config = config(**self.config_kwargs)

-        preprocessor_kwargs = {
-            k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names
-        }
+        preprocessor_kwargs = {k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names}

         self.preprocessor = Preprocessor(**preprocessor_kwargs)
         self.base_model = model
@@ -70,8 +62,7 @@ def __init__(self, model, config, **kwargs):
         self.optimizer_kwargs = {
             k: v
             for k, v in kwargs.items()
-            if k
-            not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"]
+            if k not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"]
             and k.startswith("optimizer_")
         }

@@ -92,10 +83,7 @@ def get_params(self, deep=True):
         params.update(self.config_kwargs)

         if deep:
-            preprocessor_params = {
-                "prepro__" + key: value
-                for key, value in self.preprocessor.get_params().items()
-            }
+            preprocessor_params = {"prepro__" + key: value for key, value in self.preprocessor.get_params().items()}
             params.update(preprocessor_params)

         return params
@@ -113,14 +101,8 @@ def set_params(self, **parameters):
         self : object
             Estimator instance.
         """
-        config_params = {
-            k: v for k, v in parameters.items() if not k.startswith("prepro__")
-        }
-        preprocessor_params = {
-            k.split("__")[1]: v
-            for k, v in parameters.items()
-            if k.startswith("prepro__")
-        }
+        config_params = {k: v for k, v in parameters.items() if not k.startswith("prepro__")}
+        preprocessor_params = {k.split("__")[1]: v for k, v in parameters.items() if k.startswith("prepro__")}

         if config_params:
             self.config_kwargs.update(config_params)
@@ -240,13 +222,9 @@ def build_model(
                 self.data_module.embedding_feature_info,
             ),
             lr=lr if lr is not None else self.config.lr,
-            lr_patience=(
-                lr_patience if lr_patience is not None else self.config.lr_patience
-            ),
+            lr_patience=(lr_patience if lr_patience is not None else self.config.lr_patience),
             lr_factor=lr_factor if lr_factor is not None else self.config.lr_factor,
-            weight_decay=(
-                weight_decay if weight_decay is not None else self.config.weight_decay
-            ),
+            weight_decay=(weight_decay if weight_decay is not None else self.config.weight_decay),
             train_metrics=train_metrics,
             val_metrics=val_metrics,
             optimizer_type=self.optimizer_type,
@@ -277,9 +255,7 @@ def get_number_of_params(self, requires_grad=True):
             If the model has not been built prior to calling this method.
         """
         if not self.built:
-            raise ValueError(
-                "The model must be built before the number of parameters can be estimated"
-            )
+            raise ValueError("The model must be built before the number of parameters can be estimated")
         else:
             if requires_grad:
                 return sum(p.numel() for p in self.task_model.parameters() if p.requires_grad)  # type: ignore
@@ -456,7 +432,7 @@ def predict(self, X, embeddings=None, device=None):
         predictions_list = self.trainer.predict(self.task_model, self.data_module)

         # Concatenate predictions from all batches
-        predictions = torch.cat(predictions_list, dim=0)
+        predictions = torch.cat(predictions_list, dim=0)  # type: ignore

         # Check if ensemble is used
         if getattr(self.base_model, "returns_ensemble", False):  # If using ensemble
@@ -553,9 +529,7 @@ def encode(self, X, embeddings=None, batch_size=64):
         # Process data in batches
         encoded_outputs = []
         for batch in tqdm(data_loader):
-            embeddings = self.task_model.base_model.encode(
-                batch
-            )  # Call your encode function
+            embeddings = self.task_model.base_model.encode(batch)  # Call your encode function
             encoded_outputs.append(embeddings)

         # Concatenate all encoded outputs
@@ -633,13 +607,11 @@ def optimize_hparams(
         best_val_loss = float("inf")

         if X_val is not None and y_val is not None:
-            val_loss = self.evaluate(
-                X_val, y_val, metrics={"Mean Squared Error": mean_squared_error}
-            )["Mean Squared Error"]
-        else:
-            val_loss = self.trainer.validate(self.task_model, self.data_module)[0][
-                "val_loss"
+            val_loss = self.evaluate(X_val, y_val, metrics={"Mean Squared Error": mean_squared_error})[
+                "Mean Squared Error"
             ]
+        else:
+            val_loss = self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"]

         best_val_loss = val_loss
         best_epoch_val_loss = self.task_model.epoch_val_loss_at(  # type: ignore
@@ -665,9 +637,7 @@ def _objective(hyperparams):
                     if param_value in activation_mapper:
                         setattr(self.config, key, activation_mapper[param_value])
                     else:
-                        raise ValueError(
-                            f"Unknown activation function: {param_value}"
-                        )
+                        raise ValueError(f"Unknown activation function: {param_value}")
                 else:
                     setattr(self.config, key, param_value)

@@ -689,9 +659,7 @@ def _objective(hyperparams):

             # Dynamically set the early pruning threshold
             if prune_by_epoch:
-                early_pruning_threshold = (
-                    best_epoch_val_loss * 1.5
-                )  # Prune based on specific epoch loss
+                early_pruning_threshold = best_epoch_val_loss * 1.5  # Prune based on specific epoch loss
             else:
                 # Prune based on the best overall validation loss
                 early_pruning_threshold = best_val_loss * 1.5
@@ -702,19 +670,15 @@ def _objective(hyperparams):

             try:
                 # Wrap the risky operation (model fitting) in a try-except block
-                self.fit(
-                    X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False
-                )
+                self.fit(X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False)

                 # Evaluate validation loss
                 if X_val is not None and y_val is not None:
-                    val_loss = self.evaluate(
-                        X_val, y_val, metrics={"Mean Squared Error": mean_squared_error}
-                    )["Mean Squared Error"]
+                    val_loss = self.evaluate(X_val, y_val, metrics={"Mean Squared Error": mean_squared_error})[
+                        "Mean Squared Error"
+                    ]
                 else:
-                    val_loss = self.trainer.validate(self.task_model, self.data_module)[
-                        0
-                    ]["val_loss"]
+                    val_loss = self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"]

                 # Pruning based on validation loss at specific epoch
                 epoch_val_loss = self.task_model.epoch_val_loss_at(  # type: ignore
@@ -731,21 +695,15 @@ def _objective(hyperparams):

             except Exception as e:
                 # Penalize the hyperparameter configuration with a large value
-                print(
-                    f"Error encountered during fit with hyperparameters {hyperparams}: {e}"
-                )
-                return (
-                    best_val_loss * 100
-                )  # Large value to discourage this configuration
+                print(f"Error encountered during fit with hyperparameters {hyperparams}: {e}")
+                return best_val_loss * 100  # Large value to discourage this configuration

         # Perform Bayesian optimization using scikit-optimize
         result = gp_minimize(_objective, param_space, n_calls=time, random_state=42)

         # Update the model with the best-found hyperparameters
         best_hparams = result.x  # type: ignore
-        head_layer_sizes = (
-            [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None
-        )
+        head_layer_sizes = [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None
        layer_sizes = [] if "layer_sizes" in self.config.__dataclass_fields__ else None

         # Iterate over the best hyperparameters found by optimization
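
Below is a minimal, self-contained sketch of the "prepro__" prefix convention that the reformatted get_params/set_params code above relies on: keys prefixed with "prepro__" are routed to the Preprocessor, everything else is treated as model-config kwargs. The option name n_bins is a hypothetical preprocessor parameter used purely for illustration, not a confirmed library option.

# Sketch only, not library code: replicates the prefix-splitting comprehensions
# from set_params in the diff above. "n_bins" is a hypothetical option.
parameters = {"lr": 1e-3, "prepro__n_bins": 20}

# Keys without the "prepro__" prefix are treated as model-config parameters.
config_params = {k: v for k, v in parameters.items() if not k.startswith("prepro__")}

# Prefixed keys are stripped of the prefix and forwarded to the preprocessor.
preprocessor_params = {k.split("__")[1]: v for k, v in parameters.items() if k.startswith("prepro__")}

print(config_params)        # {'lr': 0.001}
print(preprocessor_params)  # {'n_bins': 20}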