Commit 19b760c

Merge branch 'develop' into layer_improvement
2 parents: 283a10b + cc92798

7 files changed: 145 additions & 78 deletions


.github/workflows/build-publish-pypi.yml

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@ name: Publish Package to PyPi
 on:
   push:
     branches:
-      - master
+      - release
 
 jobs:
   publish:

mambular/__version__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 """Version information."""
 
 # The following line *must* be the last in the module, exactly as formatted:
-__version__ = "0.1.6"
+__version__ = "0.1.7"

mambular/base_models/lightning_wrapper.py

Lines changed: 3 additions & 3 deletions
@@ -37,7 +37,7 @@ def __init__(
         lss=False,
         family=None,
         loss_fct: callable = None,
-        **kwargs
+        **kwargs,
     ):
         super().__init__()
         self.num_classes = num_classes
@@ -126,7 +126,7 @@ def compute_loss(self, predictions, y_true):
             Computed loss.
         """
         if self.lss:
-            return self.family.compute_loss(predictions, y_true)
+            return self.family.compute_loss(predictions, y_true.squeeze(-1))
         else:
             loss = self.loss_fct(predictions, y_true)
             return loss
@@ -300,7 +300,7 @@ def configure_optimizers(self):
             A dictionary containing the optimizer and lr_scheduler configurations.
         """
         optimizer = torch.optim.Adam(
-            self.parameters(),
+            self.model.parameters(),
             lr=self.lr,
             weight_decay=self.weight_decay,
         )
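The `squeeze(-1)` change matters because DataLoaders typically yield targets of shape `(batch, 1)` while distributional loss functions expect `(batch,)`, and the mismatch broadcasts silently instead of raising. Likewise, switching to `self.model.parameters()` scopes Adam to the wrapped model's weights rather than to everything registered on the LightningModule. A minimal standalone sketch of the shape pitfall (illustration only, not code from this repository):

```python
import torch

# Predictions come out as (batch,); targets come out of a DataLoader as (batch, 1).
mu = torch.randn(4)     # predicted means, shape (4,)
y = torch.randn(4, 1)   # targets as loaded, shape (4, 1)

# Without squeeze, (4,) - (4, 1) broadcasts to a (4, 4) matrix: the "loss"
# silently averages 16 cross-pair terms instead of 4 per-sample terms.
wrong = ((mu - y) ** 2).mean()

# With the fix, shapes align and each sample contributes exactly one term.
right = ((mu - y.squeeze(-1)) ** 2).mean()

print(wrong.item(), right.item())
```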

mambular/models/sklearn_base_classifier.py

Lines changed: 58 additions & 22 deletions
@@ -9,6 +9,8 @@
 from ..data_utils.datamodule import MambularDataModule
 from ..preprocessing import Preprocessor
 import numpy as np
+from lightning.pytorch.callbacks import ModelSummary
+from sklearn.metrics import log_loss
 
 
 class SklearnBaseClassifier(BaseEstimator):
@@ -49,23 +51,22 @@ def __init__(self, model, config, **kwargs):
 
     def get_params(self, deep=True):
         """
-        Get parameters for this estimator. Overrides the BaseEstimator method.
+        Get parameters for this estimator.
 
         Parameters
         ----------
         deep : bool, default=True
-            If True, returns the parameters for this estimator and contained sub-objects that are estimators.
+            If True, will return the parameters for this estimator and contained subobjects that are estimators.
 
         Returns
         -------
         params : dict
             Parameter names mapped to their values.
         """
-        params = self.config_kwargs  # Parameters used to initialize DefaultConfig
+        params = {}
+        params.update(self.config_kwargs)
 
-        # If deep=True, include parameters from nested components like preprocessor
         if deep:
-            # Assuming Preprocessor has a get_params method
             preprocessor_params = {
                 "preprocessor__" + key: value
                 for key, value in self.preprocessor.get_params().items()
@@ -76,35 +77,36 @@ def get_params(self, deep=True):
 
     def set_params(self, **parameters):
         """
-        Set the parameters of this estimator. Overrides the BaseEstimator method.
+        Set the parameters of this estimator.
 
         Parameters
        ----------
         **parameters : dict
-            Estimator parameters to be set.
+            Estimator parameters.
 
         Returns
         -------
         self : object
-            The instance with updated parameters.
+            Estimator instance.
         """
-        # Update config_kwargs with provided parameters
-        valid_config_keys = self.config_kwargs.keys()
-        config_updates = {k: v for k, v in parameters.items() if k in valid_config_keys}
-        self.config_kwargs.update(config_updates)
-
-        # Update the config object
-        for key, value in config_updates.items():
-            setattr(self.config, key, value)
-
-        # Handle preprocessor parameters (prefixed with 'preprocessor__')
+        config_params = {
+            k: v for k, v in parameters.items() if not k.startswith("preprocessor__")
+        }
         preprocessor_params = {
             k.split("__")[1]: v
             for k, v in parameters.items()
             if k.startswith("preprocessor__")
         }
+
+        if config_params:
+            self.config_kwargs.update(config_params)
+            if self.config is not None:
+                for key, value in config_params.items():
+                    setattr(self.config, key, value)
+            else:
+                self.config = self.config_class(**self.config_kwargs)
+
         if preprocessor_params:
-            # Assuming Preprocessor has a set_params method
             self.preprocessor.set_params(**preprocessor_params)
 
         return self
@@ -368,12 +370,16 @@ def fit(
         )
 
         # Initialize the trainer and train the model
-        trainer = pl.Trainer(
+        self.trainer = pl.Trainer(
             max_epochs=max_epochs,
-            callbacks=[early_stop_callback, checkpoint_callback],
+            callbacks=[
+                early_stop_callback,
+                checkpoint_callback,
+                ModelSummary(max_depth=2),
+            ],
             **trainer_kwargs
         )
-        trainer.fit(self.model, self.data_module)
+        self.trainer.fit(self.model, self.data_module)
 
         best_model_path = checkpoint_callback.best_model_path
         if best_model_path:
@@ -555,3 +561,33 @@ def evaluate(self, X, y_true, metrics=None):
             scores[metric_name] = metric_func(y_true, predictions)
 
         return scores
+
+    def score(self, X, y, metric=(log_loss, True)):
+        """
+        Calculate the score of the model using the specified metric.
+
+        Parameters
+        ----------
+        X : array-like or pd.DataFrame of shape (n_samples, n_features)
+            The input samples to predict.
+        y : array-like of shape (n_samples,)
+            The true class labels against which to evaluate the predictions.
+        metric : tuple, default=(log_loss, True)
+            A tuple containing the metric function and a boolean indicating
+            whether the metric requires probability scores (True) or class
+            labels (False).
+
+        Returns
+        -------
+        score : float
+            The score calculated using the specified metric.
+        """
+        metric_func, use_proba = metric
+
+        if not isinstance(X, pd.DataFrame):
+            X = pd.DataFrame(X)
+
+        if use_proba:
+            probabilities = self.predict_proba(X)
+            return metric_func(y, probabilities)
+        else:
+            predictions = self.predict(X)
+            return metric_func(y, predictions)
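The reworked `get_params`/`set_params` pair follows the scikit-learn convention of routing double-underscore-prefixed keys to nested components, and the new `score` method makes the estimator usable wherever a single scalar metric is expected. A hedged usage sketch; `MambularClassifier` and the parameter names `lr` and `n_bins` are assumptions for illustration, not taken from this diff:

```python
from mambular.models import MambularClassifier  # assumed public entry point

clf = MambularClassifier()

# get_params() merges the config kwargs with "preprocessor__"-prefixed
# preprocessor parameters, so grid search can see and tune both.
print(sorted(clf.get_params()))

# set_params() splits on the prefix: "preprocessor__" keys go to the
# preprocessor, everything else updates (or rebuilds) the config.
clf.set_params(lr=1e-4, preprocessor__n_bins=64)

# After fitting, score() defaults to log_loss on predict_proba; passing
# (metric_fn, False) evaluates on hard labels instead, e.g. with
# sklearn.metrics.accuracy_score:
# clf.fit(X_train, y_train)
# clf.score(X_test, y_test)                                   # log loss
# clf.score(X_test, y_test, metric=(accuracy_score, False))   # accuracy
```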

mambular/models/sklearn_base_lss.py

Lines changed: 27 additions & 22 deletions
@@ -31,6 +31,7 @@
     PoissonDistribution,
     StudentTDistribution,
 )
+from lightning.pytorch.callbacks import ModelSummary
 
 
 class SklearnBaseLSS(BaseEstimator):
@@ -70,23 +71,22 @@ def __init__(self, model, config, **kwargs):
 
     def get_params(self, deep=True):
         """
-        Get parameters for this estimator. Overrides the BaseEstimator method.
+        Get parameters for this estimator.
 
         Parameters
         ----------
         deep : bool, default=True
-            If True, returns the parameters for this estimator and contained sub-objects that are estimators.
+            If True, will return the parameters for this estimator and contained subobjects that are estimators.
 
         Returns
         -------
         params : dict
             Parameter names mapped to their values.
         """
-        params = self.config_kwargs  # Parameters used to initialize DefaultConfig
+        params = {}
+        params.update(self.config_kwargs)
 
-        # If deep=True, include parameters from nested components like preprocessor
         if deep:
-            # Assuming Preprocessor has a get_params method
             preprocessor_params = {
                 "preprocessor__" + key: value
                 for key, value in self.preprocessor.get_params().items()
@@ -97,35 +97,36 @@ def get_params(self, deep=True):
 
     def set_params(self, **parameters):
         """
-        Set the parameters of this estimator. Overrides the BaseEstimator method.
+        Set the parameters of this estimator.
 
         Parameters
         ----------
         **parameters : dict
-            Estimator parameters to be set.
+            Estimator parameters.
 
         Returns
         -------
         self : object
-            The instance with updated parameters.
+            Estimator instance.
         """
-        # Update config_kwargs with provided parameters
-        valid_config_keys = self.config_kwargs.keys()
-        config_updates = {k: v for k, v in parameters.items() if k in valid_config_keys}
-        self.config_kwargs.update(config_updates)
-
-        # Update the config object
-        for key, value in config_updates.items():
-            setattr(self.config, key, value)
-
-        # Handle preprocessor parameters (prefixed with 'preprocessor__')
+        config_params = {
+            k: v for k, v in parameters.items() if not k.startswith("preprocessor__")
+        }
         preprocessor_params = {
             k.split("__")[1]: v
             for k, v in parameters.items()
             if k.startswith("preprocessor__")
         }
+
+        if config_params:
+            self.config_kwargs.update(config_params)
+            if self.config is not None:
+                for key, value in config_params.items():
+                    setattr(self.config, key, value)
+            else:
+                self.config = self.config_class(**self.config_kwargs)
+
         if preprocessor_params:
-            # Assuming Preprocessor has a set_params method
             self.preprocessor.set_params(**preprocessor_params)
 
         return self
@@ -409,12 +410,16 @@ def fit(
         )
 
         # Initialize the trainer and train the model
-        trainer = pl.Trainer(
+        self.trainer = pl.Trainer(
             max_epochs=max_epochs,
-            callbacks=[early_stop_callback, checkpoint_callback],
+            callbacks=[
+                early_stop_callback,
+                checkpoint_callback,
+                ModelSummary(max_depth=2),
+            ],
             **trainer_kwargs
         )
-        trainer.fit(self.model, self.data_module)
+        self.trainer.fit(self.model, self.data_module)
 
         best_model_path = checkpoint_callback.best_model_path
         if best_model_path:
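The same trainer change lands here: keeping the `Trainer` on the estimator (`self.trainer`) leaves it inspectable after `fit`, and `ModelSummary(max_depth=2)` prints a two-level parameter breakdown at the start of training. A minimal standalone sketch of that callback; `DemoModule` is a hypothetical stand-in, not a class from this repository:

```python
import torch
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelSummary


class DemoModule(pl.LightningModule):
    """Tiny module so ModelSummary has submodules to summarize."""

    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.Linear(16, 1))

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.net(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# max_depth=2 expands one level into submodules (here, the two Linear layers),
# mirroring the callback added to the estimators' trainers above.
trainer = pl.Trainer(max_epochs=1, callbacks=[ModelSummary(max_depth=2)])
```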
