99from ..data_utils .datamodule import MambularDataModule
1010from ..preprocessing import Preprocessor
1111import numpy as np
12+ from lightning .pytorch .callbacks import ModelSummary
13+ from sklearn .metrics import log_loss
1214
1315
1416class SklearnBaseClassifier (BaseEstimator ):
@@ -49,23 +51,22 @@ def __init__(self, model, config, **kwargs):
4951
5052 def get_params (self , deep = True ):
5153 """
52- Get parameters for this estimator. Overrides the BaseEstimator method.
54+ Get parameters for this estimator.
5355
5456 Parameters
5557 ----------
5658 deep : bool, default=True
57- If True, returns the parameters for this estimator and contained sub-objects that are estimators.
59+ If True, will return the parameters for this estimator and contained subobjects that are estimators.
5860
5961 Returns
6062 -------
6163 params : dict
6264 Parameter names mapped to their values.
6365 """
64- params = self .config_kwargs # Parameters used to initialize DefaultConfig
66+ params = {}
67+ params .update (self .config_kwargs )
6568
66- # If deep=True, include parameters from nested components like preprocessor
6769 if deep :
68- # Assuming Preprocessor has a get_params method
6970 preprocessor_params = {
7071 "preprocessor__" + key : value
7172 for key , value in self .preprocessor .get_params ().items ()
@@ -76,35 +77,36 @@ def get_params(self, deep=True):
7677
7778 def set_params (self , ** parameters ):
7879 """
79- Set the parameters of this estimator. Overrides the BaseEstimator method.
80+ Set the parameters of this estimator.
8081
8182 Parameters
8283 ----------
8384 **parameters : dict
84- Estimator parameters to be set .
85+ Estimator parameters.
8586
8687 Returns
8788 -------
8889 self : object
89- The instance with updated parameters .
90+ Estimator instance.
9091 """
91- # Update config_kwargs with provided parameters
92- valid_config_keys = self .config_kwargs .keys ()
93- config_updates = {k : v for k , v in parameters .items () if k in valid_config_keys }
94- self .config_kwargs .update (config_updates )
95-
96- # Update the config object
97- for key , value in config_updates .items ():
98- setattr (self .config , key , value )
99-
100- # Handle preprocessor parameters (prefixed with 'preprocessor__')
92+ config_params = {
93+ k : v for k , v in parameters .items () if not k .startswith ("preprocessor__" )
94+ }
10195 preprocessor_params = {
10296 k .split ("__" )[1 ]: v
10397 for k , v in parameters .items ()
10498 if k .startswith ("preprocessor__" )
10599 }
100+
101+ if config_params :
102+ self .config_kwargs .update (config_params )
103+ if self .config is not None :
104+ for key , value in config_params .items ():
105+ setattr (self .config , key , value )
106+ else :
107+ self .config = self .config_class (** self .config_kwargs )
108+
106109 if preprocessor_params :
107- # Assuming Preprocessor has a set_params method
108110 self .preprocessor .set_params (** preprocessor_params )
109111
110112 return self
@@ -368,12 +370,16 @@ def fit(
368370 )
369371
370372 # Initialize the trainer and train the model
371- trainer = pl .Trainer (
373+ self . trainer = pl .Trainer (
372374 max_epochs = max_epochs ,
373- callbacks = [early_stop_callback , checkpoint_callback ],
375+ callbacks = [
376+ early_stop_callback ,
377+ checkpoint_callback ,
378+ ModelSummary (max_depth = 2 ),
379+ ],
374380 ** trainer_kwargs
375381 )
376- trainer .fit (self .model , self .data_module )
382+ self . trainer .fit (self .model , self .data_module )
377383
378384 best_model_path = checkpoint_callback .best_model_path
379385 if best_model_path :
@@ -555,3 +561,33 @@ def evaluate(self, X, y_true, metrics=None):
555561 scores [metric_name ] = metric_func (y_true , predictions )
556562
557563 return scores
564+
565+ def score (self , X , y , metric = (log_loss , True )):
566+ """
567+ Calculate the score of the model using the specified metric.
568+
569+ Parameters
570+ ----------
571+ X : array-like or pd.DataFrame of shape (n_samples, n_features)
572+ The input samples to predict.
573+ y : array-like of shape (n_samples,)
574+ The true class labels against which to evaluate the predictions.
575+ metric : tuple, default=(log_loss, True)
576+ A tuple containing the metric function and a boolean indicating whether the metric requires probability scores (True) or class labels (False).
577+
578+ Returns
579+ -------
580+ score : float
581+ The score calculated using the specified metric.
582+ """
583+ metric_func , use_proba = metric
584+
585+ if not isinstance (X , pd .DataFrame ):
586+ X = pd .DataFrame (X )
587+
588+ if use_proba :
589+ probabilities = self .predict_proba (X )
590+ return metric_func (y , probabilities )
591+ else :
592+ predictions = self .predict (X )
593+ return metric_func (y , predictions )
0 commit comments