99from ..data_utils .datamodule import MambularDataModule
1010from ..preprocessing import Preprocessor
1111from lightning .pytorch .callbacks import ModelSummary
12+ from dataclasses import asdict , is_dataclass
1213
1314
1415class SklearnBaseRegressor (BaseEstimator ):
1516 def __init__ (self , model , config , ** kwargs ):
16- preprocessor_arg_names = [
17+ self . preprocessor_arg_names = [
1718 "n_bins" ,
1819 "numerical_preprocessing" ,
1920 "use_decision_tree_bins" ,
@@ -26,16 +27,18 @@ def __init__(self, model, config, **kwargs):
2627 ]
2728
2829 self .config_kwargs = {
29- k : v for k , v in kwargs .items () if k not in preprocessor_arg_names
30+ k : v for k , v in kwargs .items () if k not in self . preprocessor_arg_names
3031 }
3132 self .config = config (** self .config_kwargs )
3233
3334 preprocessor_kwargs = {
34- k : v for k , v in kwargs .items () if k in preprocessor_arg_names
35+ k : v for k , v in kwargs .items () if k in self . preprocessor_arg_names
3536 }
3637
3738 self .preprocessor = Preprocessor (** preprocessor_kwargs )
39+ self .base_model = model
3840 self .model = None
41+ self .built = False
3942
4043 # Raise a warning if task is set to 'classification'
4144 if preprocessor_kwargs .get ("task" ) == "classification" :
@@ -44,27 +47,24 @@ def __init__(self, model, config, **kwargs):
4447 UserWarning ,
4548 )
4649
47- self .base_model = model
48-
4950 def get_params (self , deep = True ):
5051 """
51- Get parameters for this estimator. Overrides the BaseEstimator method.
52+ Get parameters for this estimator.
5253
5354 Parameters
5455 ----------
5556 deep : bool, default=True
56- If True, returns the parameters for this estimator and contained sub-objects that are estimators.
57+ If True, will return the parameters for this estimator and contained subobjects that are estimators.
5758
5859 Returns
5960 -------
6061 params : dict
6162 Parameter names mapped to their values.
6263 """
63- params = self .config_kwargs # Parameters used to initialize DefaultConfig
64+ params = {}
65+ params .update (self .config_kwargs )
6466
65- # If deep=True, include parameters from nested components like preprocessor
6667 if deep :
67- # Assuming Preprocessor has a get_params method
6868 preprocessor_params = {
6969 "preprocessor__" + key : value
7070 for key , value in self .preprocessor .get_params ().items ()
@@ -75,35 +75,36 @@ def get_params(self, deep=True):
7575
7676 def set_params (self , ** parameters ):
7777 """
78- Set the parameters of this estimator. Overrides the BaseEstimator method.
78+ Set the parameters of this estimator.
7979
8080 Parameters
8181 ----------
8282 **parameters : dict
83- Estimator parameters to be set .
83+ Estimator parameters.
8484
8585 Returns
8686 -------
8787 self : object
88- The instance with updated parameters .
88+ Estimator instance.
8989 """
90- # Update config_kwargs with provided parameters
91- valid_config_keys = self .config_kwargs .keys ()
92- config_updates = {k : v for k , v in parameters .items () if k in valid_config_keys }
93- self .config_kwargs .update (config_updates )
94-
95- # Update the config object
96- for key , value in config_updates .items ():
97- setattr (self .config , key , value )
98-
99- # Handle preprocessor parameters (prefixed with 'preprocessor__')
90+ config_params = {
91+ k : v for k , v in parameters .items () if not k .startswith ("preprocessor__" )
92+ }
10093 preprocessor_params = {
10194 k .split ("__" )[1 ]: v
10295 for k , v in parameters .items ()
10396 if k .startswith ("preprocessor__" )
10497 }
98+
99+ if config_params :
100+ self .config_kwargs .update (config_params )
101+ if self .config is not None :
102+ for key , value in config_params .items ():
103+ setattr (self .config , key , value )
104+ else :
105+ self .config = self .config_class (** self .config_kwargs )
106+
105107 if preprocessor_params :
106- # Assuming Preprocessor has a set_params method
107108 self .preprocessor .set_params (** preprocessor_params )
108109
109110 return self
@@ -471,3 +472,24 @@ def evaluate(self, X, y_true, metrics=None):
471472 scores [metric_name ] = metric_func (y_true , predictions )
472473
473474 return scores
475+
476+ def score (self , X , y , metric = mean_squared_error ):
477+ """
478+ Calculate the score of the model using the specified metric.
479+
480+ Parameters
481+ ----------
482+ X : array-like or pd.DataFrame of shape (n_samples, n_features)
483+ The input samples to predict.
484+ y : array-like of shape (n_samples,) or (n_samples, n_outputs)
485+ The true target values against which to evaluate the predictions.
486+ metric : callable, default=mean_squared_error
487+ The metric function to use for evaluation. Must be a callable with the signature `metric(y_true, y_pred)`.
488+
489+ Returns
490+ -------
491+ score : float
492+ The score calculated using the specified metric.
493+ """
494+ predictions = self .predict (X )
495+ return metric (y , predictions )
0 commit comments