Skip to content

Commit bd941d2

Browse files
committed
fixed set and get_params functionality
1 parent b16af74 commit bd941d2

3 files changed

Lines changed: 115 additions & 62 deletions

File tree

mambular/models/sklearn_base_classifier.py

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from ..preprocessing import Preprocessor
1111
import numpy as np
1212
from lightning.pytorch.callbacks import ModelSummary
13+
from sklearn.metrics import log_loss
1314

1415

1516
class SklearnBaseClassifier(BaseEstimator):
@@ -49,23 +50,22 @@ def __init__(self, model, config, **kwargs):
4950

5051
def get_params(self, deep=True):
5152
"""
52-
Get parameters for this estimator. Overrides the BaseEstimator method.
53+
Get parameters for this estimator.
5354
5455
Parameters
5556
----------
5657
deep : bool, default=True
57-
If True, returns the parameters for this estimator and contained sub-objects that are estimators.
58+
If True, will return the parameters for this estimator and contained subobjects that are estimators.
5859
5960
Returns
6061
-------
6162
params : dict
6263
Parameter names mapped to their values.
6364
"""
64-
params = self.config_kwargs # Parameters used to initialize DefaultConfig
65+
params = {}
66+
params.update(self.config_kwargs)
6567

66-
# If deep=True, include parameters from nested components like preprocessor
6768
if deep:
68-
# Assuming Preprocessor has a get_params method
6969
preprocessor_params = {
7070
"preprocessor__" + key: value
7171
for key, value in self.preprocessor.get_params().items()
@@ -76,35 +76,36 @@ def get_params(self, deep=True):
7676

7777
def set_params(self, **parameters):
7878
"""
79-
Set the parameters of this estimator. Overrides the BaseEstimator method.
79+
Set the parameters of this estimator.
8080
8181
Parameters
8282
----------
8383
**parameters : dict
84-
Estimator parameters to be set.
84+
Estimator parameters.
8585
8686
Returns
8787
-------
8888
self : object
89-
The instance with updated parameters.
89+
Estimator instance.
9090
"""
91-
# Update config_kwargs with provided parameters
92-
valid_config_keys = self.config_kwargs.keys()
93-
config_updates = {k: v for k, v in parameters.items() if k in valid_config_keys}
94-
self.config_kwargs.update(config_updates)
95-
96-
# Update the config object
97-
for key, value in config_updates.items():
98-
setattr(self.config, key, value)
99-
100-
# Handle preprocessor parameters (prefixed with 'preprocessor__')
91+
config_params = {
92+
k: v for k, v in parameters.items() if not k.startswith("preprocessor__")
93+
}
10194
preprocessor_params = {
10295
k.split("__")[1]: v
10396
for k, v in parameters.items()
10497
if k.startswith("preprocessor__")
10598
}
99+
100+
if config_params:
101+
self.config_kwargs.update(config_params)
102+
if self.config is not None:
103+
for key, value in config_params.items():
104+
setattr(self.config, key, value)
105+
else:
106+
self.config = self.config_class(**self.config_kwargs)
107+
106108
if preprocessor_params:
107-
# Assuming Preprocessor has a set_params method
108109
self.preprocessor.set_params(**preprocessor_params)
109110

110111
return self
@@ -559,3 +560,33 @@ def evaluate(self, X, y_true, metrics=None):
559560
scores[metric_name] = metric_func(y_true, predictions)
560561

561562
return scores
563+
564+
def score(self, X, y, metric=(log_loss, True)):
565+
"""
566+
Calculate the score of the model using the specified metric.
567+
568+
Parameters
569+
----------
570+
X : array-like or pd.DataFrame of shape (n_samples, n_features)
571+
The input samples to predict.
572+
y : array-like of shape (n_samples,)
573+
The true class labels against which to evaluate the predictions.
574+
metric : tuple, default=(log_loss, True)
575+
A tuple containing the metric function and a boolean indicating whether the metric requires probability scores (True) or class labels (False).
576+
577+
Returns
578+
-------
579+
score : float
580+
The score calculated using the specified metric.
581+
"""
582+
metric_func, use_proba = metric
583+
584+
if not isinstance(X, pd.DataFrame):
585+
X = pd.DataFrame(X)
586+
587+
if use_proba:
588+
probabilities = self.predict_proba(X)
589+
return metric_func(y, probabilities)
590+
else:
591+
predictions = self.predict(X)
592+
return metric_func(y, predictions)

mambular/models/sklearn_base_lss.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -71,23 +71,22 @@ def __init__(self, model, config, **kwargs):
7171

7272
def get_params(self, deep=True):
7373
"""
74-
Get parameters for this estimator. Overrides the BaseEstimator method.
74+
Get parameters for this estimator.
7575
7676
Parameters
7777
----------
7878
deep : bool, default=True
79-
If True, returns the parameters for this estimator and contained sub-objects that are estimators.
79+
If True, will return the parameters for this estimator and contained subobjects that are estimators.
8080
8181
Returns
8282
-------
8383
params : dict
8484
Parameter names mapped to their values.
8585
"""
86-
params = self.config_kwargs # Parameters used to initialize DefaultConfig
86+
params = {}
87+
params.update(self.config_kwargs)
8788

88-
# If deep=True, include parameters from nested components like preprocessor
8989
if deep:
90-
# Assuming Preprocessor has a get_params method
9190
preprocessor_params = {
9291
"preprocessor__" + key: value
9392
for key, value in self.preprocessor.get_params().items()
@@ -98,35 +97,36 @@ def get_params(self, deep=True):
9897

9998
def set_params(self, **parameters):
10099
"""
101-
Set the parameters of this estimator. Overrides the BaseEstimator method.
100+
Set the parameters of this estimator.
102101
103102
Parameters
104103
----------
105104
**parameters : dict
106-
Estimator parameters to be set.
105+
Estimator parameters.
107106
108107
Returns
109108
-------
110109
self : object
111-
The instance with updated parameters.
110+
Estimator instance.
112111
"""
113-
# Update config_kwargs with provided parameters
114-
valid_config_keys = self.config_kwargs.keys()
115-
config_updates = {k: v for k, v in parameters.items() if k in valid_config_keys}
116-
self.config_kwargs.update(config_updates)
117-
118-
# Update the config object
119-
for key, value in config_updates.items():
120-
setattr(self.config, key, value)
121-
122-
# Handle preprocessor parameters (prefixed with 'preprocessor__')
112+
config_params = {
113+
k: v for k, v in parameters.items() if not k.startswith("preprocessor__")
114+
}
123115
preprocessor_params = {
124116
k.split("__")[1]: v
125117
for k, v in parameters.items()
126118
if k.startswith("preprocessor__")
127119
}
120+
121+
if config_params:
122+
self.config_kwargs.update(config_params)
123+
if self.config is not None:
124+
for key, value in config_params.items():
125+
setattr(self.config, key, value)
126+
else:
127+
self.config = self.config_class(**self.config_kwargs)
128+
128129
if preprocessor_params:
129-
# Assuming Preprocessor has a set_params method
130130
self.preprocessor.set_params(**preprocessor_params)
131131

132132
return self

mambular/models/sklearn_base_regressor.py

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99
from ..data_utils.datamodule import MambularDataModule
1010
from ..preprocessing import Preprocessor
1111
from lightning.pytorch.callbacks import ModelSummary
12+
from dataclasses import asdict, is_dataclass
1213

1314

1415
class SklearnBaseRegressor(BaseEstimator):
1516
def __init__(self, model, config, **kwargs):
16-
preprocessor_arg_names = [
17+
self.preprocessor_arg_names = [
1718
"n_bins",
1819
"numerical_preprocessing",
1920
"use_decision_tree_bins",
@@ -26,16 +27,18 @@ def __init__(self, model, config, **kwargs):
2627
]
2728

2829
self.config_kwargs = {
29-
k: v for k, v in kwargs.items() if k not in preprocessor_arg_names
30+
k: v for k, v in kwargs.items() if k not in self.preprocessor_arg_names
3031
}
3132
self.config = config(**self.config_kwargs)
3233

3334
preprocessor_kwargs = {
34-
k: v for k, v in kwargs.items() if k in preprocessor_arg_names
35+
k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names
3536
}
3637

3738
self.preprocessor = Preprocessor(**preprocessor_kwargs)
39+
self.base_model = model
3840
self.model = None
41+
self.built = False
3942

4043
# Raise a warning if task is set to 'classification'
4144
if preprocessor_kwargs.get("task") == "classification":
@@ -44,27 +47,24 @@ def __init__(self, model, config, **kwargs):
4447
UserWarning,
4548
)
4649

47-
self.base_model = model
48-
4950
def get_params(self, deep=True):
5051
"""
51-
Get parameters for this estimator. Overrides the BaseEstimator method.
52+
Get parameters for this estimator.
5253
5354
Parameters
5455
----------
5556
deep : bool, default=True
56-
If True, returns the parameters for this estimator and contained sub-objects that are estimators.
57+
If True, will return the parameters for this estimator and contained subobjects that are estimators.
5758
5859
Returns
5960
-------
6061
params : dict
6162
Parameter names mapped to their values.
6263
"""
63-
params = self.config_kwargs # Parameters used to initialize DefaultConfig
64+
params = {}
65+
params.update(self.config_kwargs)
6466

65-
# If deep=True, include parameters from nested components like preprocessor
6667
if deep:
67-
# Assuming Preprocessor has a get_params method
6868
preprocessor_params = {
6969
"preprocessor__" + key: value
7070
for key, value in self.preprocessor.get_params().items()
@@ -75,35 +75,36 @@ def get_params(self, deep=True):
7575

7676
def set_params(self, **parameters):
7777
"""
78-
Set the parameters of this estimator. Overrides the BaseEstimator method.
78+
Set the parameters of this estimator.
7979
8080
Parameters
8181
----------
8282
**parameters : dict
83-
Estimator parameters to be set.
83+
Estimator parameters.
8484
8585
Returns
8686
-------
8787
self : object
88-
The instance with updated parameters.
88+
Estimator instance.
8989
"""
90-
# Update config_kwargs with provided parameters
91-
valid_config_keys = self.config_kwargs.keys()
92-
config_updates = {k: v for k, v in parameters.items() if k in valid_config_keys}
93-
self.config_kwargs.update(config_updates)
94-
95-
# Update the config object
96-
for key, value in config_updates.items():
97-
setattr(self.config, key, value)
98-
99-
# Handle preprocessor parameters (prefixed with 'preprocessor__')
90+
config_params = {
91+
k: v for k, v in parameters.items() if not k.startswith("preprocessor__")
92+
}
10093
preprocessor_params = {
10194
k.split("__")[1]: v
10295
for k, v in parameters.items()
10396
if k.startswith("preprocessor__")
10497
}
98+
99+
if config_params:
100+
self.config_kwargs.update(config_params)
101+
if self.config is not None:
102+
for key, value in config_params.items():
103+
setattr(self.config, key, value)
104+
else:
105+
self.config = self.config_class(**self.config_kwargs)
106+
105107
if preprocessor_params:
106-
# Assuming Preprocessor has a set_params method
107108
self.preprocessor.set_params(**preprocessor_params)
108109

109110
return self
@@ -471,3 +472,24 @@ def evaluate(self, X, y_true, metrics=None):
471472
scores[metric_name] = metric_func(y_true, predictions)
472473

473474
return scores
475+
476+
def score(self, X, y, metric=mean_squared_error):
477+
"""
478+
Calculate the score of the model using the specified metric.
479+
480+
Parameters
481+
----------
482+
X : array-like or pd.DataFrame of shape (n_samples, n_features)
483+
The input samples to predict.
484+
y : array-like of shape (n_samples,) or (n_samples, n_outputs)
485+
The true target values against which to evaluate the predictions.
486+
metric : callable, default=mean_squared_error
487+
The metric function to use for evaluation. Must be a callable with the signature `metric(y_true, y_pred)`.
488+
489+
Returns
490+
-------
491+
score : float
492+
The score calculated using the specified metric.
493+
"""
494+
predictions = self.predict(X)
495+
return metric(y, predictions)

0 commit comments

Comments (0)