@@ -8,15 +8,19 @@
import torch
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, ModelSummary
from sklearn.base import BaseEstimator
-from sklearn.metrics import accuracy_score, log_loss, mean_squared_error
+from sklearn.metrics import accuracy_score, log_loss
from skopt import gp_minimize
from torch.utils.data import DataLoader
from tqdm import tqdm

from ..base_models.lightning_wrapper import TaskModel
from ..data_utils.datamodule import MambularDataModule
from ..preprocessing import Preprocessor
-from ..utils.config_mapper import activation_mapper, get_search_space, round_to_nearest_16
+from ..utils.config_mapper import (
+    activation_mapper,
+    get_search_space,
+    round_to_nearest_16,
+)


class SklearnBaseClassifier(BaseEstimator):
@@ -39,11 +43,15 @@ def __init__(self, model, config, **kwargs):
        ]

        self.config_kwargs = {
-            k: v for k, v in kwargs.items() if k not in self.preprocessor_arg_names and not k.startswith("optimizer")
+            k: v
+            for k, v in kwargs.items()
+            if k not in self.preprocessor_arg_names and not k.startswith("optimizer")
        }
        self.config = config(**self.config_kwargs)

-        preprocessor_kwargs = {k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names}
+        preprocessor_kwargs = {
+            k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names
+        }

        self.preprocessor = Preprocessor(**preprocessor_kwargs)
        self.task_model = None
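The dict comprehensions above route the constructor's `**kwargs` to the config, the preprocessor, and (in the next hunk) the optimizer. A minimal standalone sketch of the same split logic; the argument names here are hypothetical, not Mambular's real ones:

```python
# Sketch of the kwargs routing in __init__ above; names are hypothetical.
preprocessor_arg_names = ["n_bins", "numerical_preprocessing"]
kwargs = {"d_model": 64, "n_bins": 50, "optimizer_betas": (0.9, 0.999)}

# Anything that is neither a preprocessor argument nor optimizer-prefixed
# becomes a config field.
config_kwargs = {
    k: v
    for k, v in kwargs.items()
    if k not in preprocessor_arg_names and not k.startswith("optimizer")
}
preprocessor_kwargs = {k: v for k, v in kwargs.items() if k in preprocessor_arg_names}

print(config_kwargs)        # {'d_model': 64}
print(preprocessor_kwargs)  # {'n_bins': 50}
```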
@@ -63,7 +71,8 @@ def __init__(self, model, config, **kwargs):
        self.optimizer_kwargs = {
            k: v
            for k, v in kwargs.items()
-            if k not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"]
+            if k
+            not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"]
            and k.startswith("optimizer_")
        }

@@ -84,7 +93,10 @@ def get_params(self, deep=True):
        params.update(self.config_kwargs)

        if deep:
-            preprocessor_params = {"prepro__" + key: value for key, value in self.preprocessor.get_params().items()}
+            preprocessor_params = {
+                "prepro__" + key: value
+                for key, value in self.preprocessor.get_params().items()
+            }
            params.update(preprocessor_params)

        return params
@@ -102,8 +114,14 @@ def set_params(self, **parameters):
        self : object
            Estimator instance.
        """
-        config_params = {k: v for k, v in parameters.items() if not k.startswith("prepro__")}
-        preprocessor_params = {k.split("__")[1]: v for k, v in parameters.items() if k.startswith("prepro__")}
+        config_params = {
+            k: v for k, v in parameters.items() if not k.startswith("prepro__")
+        }
+        preprocessor_params = {
+            k.split("__")[1]: v
+            for k, v in parameters.items()
+            if k.startswith("prepro__")
+        }

        if config_params:
            self.config_kwargs.update(config_params)
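Because `get_params`/`set_params` expose preprocessor fields under a `prepro__` prefix, the estimator composes with standard sklearn tooling. A hedged sketch, assuming the concrete `MambularClassifier` subclass and an `n_bins` preprocessor parameter:

```python
from mambular.models import MambularClassifier  # concrete subclass (assumed import path)

clf = MambularClassifier()

# Config fields are passed bare; preprocessor fields take the "prepro__"
# prefix, which set_params() strips before forwarding to the Preprocessor.
clf.set_params(d_model=64, prepro__n_bins=50)

# get_params(deep=True) re-adds the prefix when collecting preprocessor params.
print(clf.get_params(deep=True)["prepro__n_bins"])  # -> 50
```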
@@ -125,6 +143,8 @@ def build_model(
        val_size: float = 0.2,
        X_val=None,
        y_val=None,
+        embeddings=None,
+        embeddings_val=None,
        random_state: int = 101,
        batch_size: int = 128,
        shuffle: bool = True,
@@ -201,20 +221,36 @@ def build_model(
            **dataloader_kwargs,
        )

-        self.data_module.preprocess_data(X, y, X_val, y_val, val_size=val_size, random_state=random_state)
+        self.data_module.preprocess_data(
+            X,
+            y,
+            X_val=X_val,
+            y_val=y_val,
+            embeddings_train=embeddings,
+            embeddings_val=embeddings_val,
+            val_size=val_size,
+            random_state=random_state,
+        )

        num_classes = len(np.unique(np.array(y)))

        self.task_model = TaskModel(
            model_class=self.base_model,  # type: ignore
            num_classes=num_classes,
            config=self.config,
-            cat_feature_info=self.data_module.cat_feature_info,
-            num_feature_info=self.data_module.num_feature_info,
-            lr_patience=(lr_patience if lr_patience is not None else self.config.lr_patience),
+            feature_information=(
+                self.data_module.num_feature_info,
+                self.data_module.cat_feature_info,
+                self.data_module.embedding_feature_info,
+            ),
+            lr_patience=(
+                lr_patience if lr_patience is not None else self.config.lr_patience
+            ),
            lr=lr if lr is not None else self.config.lr,
            lr_factor=lr_factor if lr_factor is not None else self.config.lr_factor,
-            weight_decay=(weight_decay if weight_decay is not None else self.config.weight_decay),
+            weight_decay=(
+                weight_decay if weight_decay is not None else self.config.weight_decay
+            ),
            train_metrics=train_metrics,
            val_metrics=val_metrics,
            optimizer_type=self.optimizer_type,
@@ -245,7 +281,9 @@ def get_number_of_params(self, requires_grad=True):
            If the model has not been built prior to calling this method.
        """
        if not self.built:
-            raise ValueError("The model must be built before the number of parameters can be estimated")
+            raise ValueError(
+                "The model must be built before the number of parameters can be estimated"
+            )
        else:
            if requires_grad:
                return sum(p.numel() for p in self.task_model.parameters() if p.requires_grad)  # type: ignore
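`build_model` can also be called directly, without training, which makes `get_number_of_params` usable as a quick size check. A sketch on synthetic data (subclass name assumed, as above):

```python
import numpy as np
import pandas as pd
from mambular.models import MambularClassifier  # assumed concrete subclass

X = pd.DataFrame(np.random.randn(256, 8), columns=[f"f{i}" for i in range(8)])
y = np.random.randint(0, 2, size=256)

clf = MambularClassifier()
clf.build_model(X, y, batch_size=32)  # instantiates the TaskModel, no fitting
print(clf.get_number_of_params())                     # trainable parameters
print(clf.get_number_of_params(requires_grad=False))  # all parameters
```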
@@ -259,6 +297,8 @@ def fit(
        val_size: float = 0.2,
        X_val=None,
        y_val=None,
+        embeddings=None,
+        embeddings_val=None,
        max_epochs: int = 100,
        random_state: int = 101,
        batch_size: int = 128,
@@ -340,6 +380,8 @@ def fit(
            val_size=val_size,
            X_val=X_val,
            y_val=y_val,
+            embeddings=embeddings,
+            embeddings_val=embeddings_val,
            random_state=random_state,
            batch_size=batch_size,
            shuffle=shuffle,
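The new `embeddings`/`embeddings_val` arguments carry precomputed representations of unstructured inputs through `fit` into the data module, alongside the tabular features. A hedged end-to-end sketch (class name and embedding dimension assumed):

```python
import numpy as np
import pandas as pd
from mambular.models import MambularClassifier  # assumed concrete subclass

X = pd.DataFrame(np.random.randn(512, 6), columns=[f"f{i}" for i in range(6)])
y = np.random.randint(0, 3, size=512)
emb = np.random.randn(512, 32).astype(np.float32)  # e.g. precomputed text embeddings

clf = MambularClassifier()
clf.fit(X, y, embeddings=emb, max_epochs=5, batch_size=64)

# Embeddings must be supplied again at inference time.
preds = clf.predict(X, embeddings=emb)
proba = clf.predict_proba(X, embeddings=emb)
```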
@@ -390,7 +432,7 @@ def fit(

        return self

-    def predict(self, X, device=None):
+    def predict(self, X, embeddings=None, device=None):
        """Predicts target labels for the given input samples.

        Parameters
@@ -408,7 +450,7 @@ def predict(self, X, device=None):
            raise ValueError("The model or data module has not been fitted yet.")

        # Preprocess the data using the data module
-        self.data_module.assign_predict_dataset(X)
+        self.data_module.assign_predict_dataset(X, embeddings)

        # Set model to evaluation mode
        self.task_model.eval()
@@ -438,7 +480,7 @@ def predict(self, X, device=None):
        # Convert predictions to NumPy array and return
        return predictions.cpu().numpy()

-    def predict_proba(self, X, device=None):
+    def predict_proba(self, X, embeddings=None, device=None):
        """Predicts class probabilities for the given input samples.

        Parameters
@@ -482,7 +524,7 @@ def predict_proba(self, X, device=None):
        # Convert probabilities to NumPy array and return
        return probabilities.cpu().numpy()

-    def evaluate(self, X, y_true, metrics=None):
+    def evaluate(self, X, y_true, embeddings=None, metrics=None):
        """Evaluate the model on the given data using specified metrics.

        Parameters
@@ -491,6 +533,8 @@ def evaluate(self, X, y_true, metrics=None):
            The input samples to predict.
        y_true : array-like of shape (n_samples,)
            The true class labels against which to evaluate the predictions.
+        embeddings : array-like or list of shape (n_samples, dimension), optional
+            List or array with embeddings for unstructured data inputs.
        metrics : dict
            A dictionary where keys are metric names and values are tuples containing the metric function
            and a boolean indicating whether the metric requires probability scores (True) or class labels (False).
@@ -518,11 +562,11 @@ def evaluate(self, X, y_true, metrics=None):

        # Generate class probabilities if any metric requires them
        if any(use_proba for _, use_proba in metrics.values()):
-            probabilities = self.predict_proba(X)
+            probabilities = self.predict_proba(X, embeddings)

        # Generate class labels if any metric requires them
        if any(not use_proba for _, use_proba in metrics.values()):
-            predictions = self.predict(X)
+            predictions = self.predict(X, embeddings)

        # Compute each metric
        for metric_name, (metric_func, use_proba) in metrics.items():
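The `metrics` dict maps a display name to a `(metric_func, use_proba)` pair, and the two `any(...)` guards above mean probabilities and labels are each computed only when some metric needs them. Continuing the fitted `clf` from the sketch above:

```python
from sklearn.metrics import accuracy_score, log_loss

scores = clf.evaluate(
    X,
    y,
    embeddings=emb,
    metrics={
        "Accuracy": (accuracy_score, False),  # consumes class labels
        "LogLoss": (log_loss, True),          # consumes predicted probabilities
    },
)
print(scores)  # {'Accuracy': ..., 'LogLoss': ...}
```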
@@ -533,7 +577,7 @@ def evaluate(self, X, y_true, metrics=None):

        return scores

-    def score(self, X, y, metric=(log_loss, True)):
+    def score(self, X, y, embeddings=None, metric=(log_loss, True)):
        """Calculate the score of the model using the specified metric.

        Parameters
@@ -557,13 +601,13 @@ def score(self, X, y, metric=(log_loss, True)):
            X = pd.DataFrame(X)

        if use_proba:
-            probabilities = self.predict_proba(X)
+            probabilities = self.predict_proba(X, embeddings)
            return metric_func(y, probabilities)
        else:
-            predictions = self.predict(X)
+            predictions = self.predict(X, embeddings)
            return metric_func(y, predictions)

-    def encode(self, X, batch_size=64):
+    def encode(self, X, embeddings=None, batch_size=64):
        """
        Encodes input data using the trained model's embedding layer.

@@ -587,14 +631,16 @@ def encode(self, X, batch_size=64):
        # Ensure model and data module are initialized
        if self.task_model is None or self.data_module is None:
            raise ValueError("The model or data module has not been fitted yet.")
-        encoded_dataset = self.data_module.preprocess_new_data(X)
+        encoded_dataset = self.data_module.preprocess_new_data(X, embeddings)

        data_loader = DataLoader(encoded_dataset, batch_size=batch_size, shuffle=False)

        # Process data in batches
        encoded_outputs = []
-        for num_features, cat_features in tqdm(data_loader):
-            embeddings = self.task_model.base_model.encode(num_features, cat_features)  # Call your encode function
+        for batch in tqdm(data_loader):
+            embeddings = self.task_model.base_model.encode(
+                batch
+            )  # Call your encode function
            encoded_outputs.append(embeddings)

        # Concatenate all encoded outputs
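With the loop now taking whole batches, `encode` no longer assumes a `(num_features, cat_features)` pair per batch, so embedding-aware datasets pass through unchanged. Extracting latent representations, continuing the sketch above:

```python
latent = clf.encode(X, embeddings=emb, batch_size=128)
print(latent.shape)  # (n_samples, model-dependent encoding dimension)
```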
@@ -608,6 +654,8 @@ def optimize_hparams(
        y,
        X_val=None,
        y_val=None,
+        embeddings=None,
+        embeddings_val=None,
        time=100,
        max_epochs=200,
        prune_by_epoch=True,
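A hedged call sketch for the extended signature, continuing from the fitted example above. Note that `time` is the number of `gp_minimize` trials, not a duration:

```python
# Each of the `time` trials rebuilds and refits the model for up to
# `max_epochs` epochs, pruning trials whose loss exceeds the band below.
clf.optimize_hparams(
    X,
    y,
    embeddings=emb,
    time=20,
    max_epochs=50,
    prune_by_epoch=True,
)
```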
@@ -658,13 +706,25 @@ def optimize_hparams(
        )

        # Initial model fitting to get the baseline validation loss
-        self.fit(X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs)
+        self.fit(
+            X,
+            y,
+            X_val=X_val,
+            y_val=y_val,
+            embeddings=embeddings,
+            embeddings_val=embeddings_val,
+            max_epochs=max_epochs,
+        )
        best_val_loss = float("inf")

        if X_val is not None and y_val is not None:
-            val_loss = self.evaluate(X_val, y_val, metrics={"Accuracy": (accuracy_score, False)})["Accuracy"]
+            val_loss = self.evaluate(
+                X_val, y_val, metrics={"Accuracy": (accuracy_score, False)}
+            )["Accuracy"]
        else:
-            val_loss = self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"]
+            val_loss = self.trainer.validate(self.task_model, self.data_module)[0][
+                "val_loss"
+            ]

        best_val_loss = val_loss
        best_epoch_val_loss = self.task_model.epoch_val_loss_at(  # type: ignore
@@ -690,7 +750,9 @@ def _objective(hyperparams):
                    if param_value in activation_mapper:
                        setattr(self.config, key, activation_mapper[param_value])
                    else:
-                        raise ValueError(f"Unknown activation function: {param_value}")
+                        raise ValueError(
+                            f"Unknown activation function: {param_value}"
+                        )
                else:
                    setattr(self.config, key, param_value)

@@ -699,11 +761,15 @@ def _objective(hyperparams):
                self.config.head_layer_sizes = head_layer_sizes[:head_layer_size_length]

            # Build the model with updated hyperparameters
-            self.build_model(X, y, X_val=X_val, y_val=y_val, lr=self.config.lr, **optimize_kwargs)
+            self.build_model(
+                X, y, X_val=X_val, y_val=y_val, lr=self.config.lr, **optimize_kwargs
+            )

            # Dynamically set the early pruning threshold
            if prune_by_epoch:
-                early_pruning_threshold = best_epoch_val_loss * 1.5  # Prune based on specific epoch loss
+                early_pruning_threshold = (
+                    best_epoch_val_loss * 1.5
+                )  # Prune based on specific epoch loss
            else:
                # Prune based on the best overall validation loss
                early_pruning_threshold = best_val_loss * 1.5
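The pruning rule is a fixed 1.5× band around the best loss seen so far; any trial above the band is abandoned early. In sketch form, with hypothetical numbers:

```python
best_epoch_val_loss = 0.40  # best validation loss recorded at the probe epoch
early_pruning_threshold = best_epoch_val_loss * 1.5  # -> 0.60

epoch_val_loss = 0.75  # current trial's loss at the same epoch
if epoch_val_loss > early_pruning_threshold:
    print("prune: stop this trial and penalize its score")
```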
@@ -715,15 +781,26 @@ def _objective(hyperparams):
            # Fit the model (limit epochs for faster optimization)
            try:
                # Wrap the risky operation (model fitting) in a try-except block
-                self.fit(X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False)
+                self.fit(
+                    X,
+                    y,
+                    X_val=X_val,
+                    y_val=y_val,
+                    embeddings=embeddings,
+                    embeddings_val=embeddings_val,
+                    max_epochs=max_epochs,
+                    rebuild=False,
+                )

                # Evaluate validation loss
                if X_val is not None and y_val is not None:
-                    val_loss = self.evaluate(X_val, y_val, metrics={"Mean Squared Error": mean_squared_error})[  # type: ignore
-                        "Mean Squared Error"
-                    ]
+                    val_loss = self.evaluate(  # type: ignore
+                        X_val, y_val, metrics={"Accuracy": (accuracy_score, False)}
+                    )["Accuracy"]
                else:
-                    val_loss = self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"]
+                    val_loss = self.trainer.validate(self.task_model, self.data_module)[
+                        0
+                    ]["val_loss"]

                # Pruning based on validation loss at specific epoch
                epoch_val_loss = self.task_model.epoch_val_loss_at(  # type: ignore
@@ -740,15 +817,21 @@ def _objective(hyperparams):

            except Exception as e:
                # Penalize the hyperparameter configuration with a large value
-                print(f"Error encountered during fit with hyperparameters {hyperparams}: {e}")
-                return best_val_loss * 100  # Large value to discourage this configuration
+                print(
+                    f"Error encountered during fit with hyperparameters {hyperparams}: {e}"
+                )
+                return (
+                    best_val_loss * 100
+                )  # Large value to discourage this configuration

        # Perform Bayesian optimization using scikit-optimize
        result = gp_minimize(_objective, param_space, n_calls=time, random_state=42)

        # Update the model with the best-found hyperparameters
        best_hparams = result.x  # type: ignore
-        head_layer_sizes = [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None
+        head_layer_sizes = (
+            [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None
+        )
        layer_sizes = [] if "layer_sizes" in self.config.__dataclass_fields__ else None

        # Iterate over the best hyperparameters found by optimization