
Commit 8cc3e83

committed: adapt first only regressor and classifier to handle embeddings
1 parent b84aa50

2 files changed: 235 additions & 77 deletions

File tree

mambular/models/sklearn_base_classifier.py
Lines changed: 122 additions & 39 deletions
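The diff below threads an optional embeddings input (plus embeddings_val for validation data) through build_model, fit, predict, predict_proba, evaluate, score, encode, and optimize_hparams. As a rough usage sketch of the new interface — assuming a concrete subclass such as MambularClassifier and illustrative data shapes, not code taken from this commit:

    import numpy as np
    import pandas as pd
    from mambular.models import MambularClassifier  # concrete subclass; assumed here

    X = pd.DataFrame({"age": np.random.rand(256), "color": ["red", "blue"] * 128})
    y = np.random.randint(0, 2, size=256)
    # One fixed-size embedding vector per row, e.g. from a text or image encoder
    embeddings = np.random.rand(256, 768).astype(np.float32)

    clf = MambularClassifier()
    clf.fit(X, y, embeddings=embeddings, max_epochs=5)
    proba = clf.predict_proba(X, embeddings)  # pass the matching embeddings at inference too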
@@ -8,15 +8,19 @@
 import torch
 from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, ModelSummary
 from sklearn.base import BaseEstimator
-from sklearn.metrics import accuracy_score, log_loss, mean_squared_error
+from sklearn.metrics import accuracy_score, log_loss
 from skopt import gp_minimize
 from torch.utils.data import DataLoader
 from tqdm import tqdm

 from ..base_models.lightning_wrapper import TaskModel
 from ..data_utils.datamodule import MambularDataModule
 from ..preprocessing import Preprocessor
-from ..utils.config_mapper import activation_mapper, get_search_space, round_to_nearest_16
+from ..utils.config_mapper import (
+    activation_mapper,
+    get_search_space,
+    round_to_nearest_16,
+)


 class SklearnBaseClassifier(BaseEstimator):
@@ -39,11 +43,15 @@ def __init__(self, model, config, **kwargs):
         ]

         self.config_kwargs = {
-            k: v for k, v in kwargs.items() if k not in self.preprocessor_arg_names and not k.startswith("optimizer")
+            k: v
+            for k, v in kwargs.items()
+            if k not in self.preprocessor_arg_names and not k.startswith("optimizer")
         }
         self.config = config(**self.config_kwargs)

-        preprocessor_kwargs = {k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names}
+        preprocessor_kwargs = {
+            k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names
+        }

         self.preprocessor = Preprocessor(**preprocessor_kwargs)
         self.task_model = None
@@ -63,7 +71,8 @@ def __init__(self, model, config, **kwargs):
         self.optimizer_kwargs = {
             k: v
             for k, v in kwargs.items()
-            if k not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"]
+            if k
+            not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"]
             and k.startswith("optimizer_")
         }

@@ -84,7 +93,10 @@ def get_params(self, deep=True):
         params.update(self.config_kwargs)

         if deep:
-            preprocessor_params = {"prepro__" + key: value for key, value in self.preprocessor.get_params().items()}
+            preprocessor_params = {
+                "prepro__" + key: value
+                for key, value in self.preprocessor.get_params().items()
+            }
             params.update(preprocessor_params)

         return params
@@ -102,8 +114,14 @@ def set_params(self, **parameters):
         self : object
             Estimator instance.
         """
-        config_params = {k: v for k, v in parameters.items() if not k.startswith("prepro__")}
-        preprocessor_params = {k.split("__")[1]: v for k, v in parameters.items() if k.startswith("prepro__")}
+        config_params = {
+            k: v for k, v in parameters.items() if not k.startswith("prepro__")
+        }
+        preprocessor_params = {
+            k.split("__")[1]: v
+            for k, v in parameters.items()
+            if k.startswith("prepro__")
+        }

         if config_params:
             self.config_kwargs.update(config_params)
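get_params/set_params route any key prefixed with "prepro__" to the Preprocessor and everything else to the model config. A quick sketch of that convention (the concrete parameter names are assumptions, not taken from this commit):

    clf = MambularClassifier()
    clf.set_params(lr=1e-4)            # no prefix: routed to the model config
    clf.set_params(prepro__n_bins=64)  # "prepro__" prefix: routed to the Preprocessor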
@@ -125,6 +143,8 @@ def build_model(
         val_size: float = 0.2,
         X_val=None,
         y_val=None,
+        embeddings=None,
+        embeddings_val=None,
         random_state: int = 101,
         batch_size: int = 128,
         shuffle: bool = True,
@@ -201,20 +221,36 @@ def build_model(
             **dataloader_kwargs,
         )

-        self.data_module.preprocess_data(X, y, X_val, y_val, val_size=val_size, random_state=random_state)
+        self.data_module.preprocess_data(
+            X,
+            y,
+            X_val=X_val,
+            y_val=y_val,
+            embeddings_train=embeddings,
+            embeddings_val=embeddings_val,
+            val_size=val_size,
+            random_state=random_state,
+        )

         num_classes = len(np.unique(np.array(y)))

         self.task_model = TaskModel(
             model_class=self.base_model,  # type: ignore
             num_classes=num_classes,
             config=self.config,
-            cat_feature_info=self.data_module.cat_feature_info,
-            num_feature_info=self.data_module.num_feature_info,
-            lr_patience=(lr_patience if lr_patience is not None else self.config.lr_patience),
+            feature_information=(
+                self.data_module.num_feature_info,
+                self.data_module.cat_feature_info,
+                self.data_module.embedding_feature_info,
+            ),
+            lr_patience=(
+                lr_patience if lr_patience is not None else self.config.lr_patience
+            ),
             lr=lr if lr is not None else self.config.lr,
             lr_factor=lr_factor if lr_factor is not None else self.config.lr_factor,
-            weight_decay=(weight_decay if weight_decay is not None else self.config.weight_decay),
+            weight_decay=(
+                weight_decay if weight_decay is not None else self.config.weight_decay
+            ),
             train_metrics=train_metrics,
             val_metrics=val_metrics,
             optimizer_type=self.optimizer_type,
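Note the interface change to TaskModel here: the separate num_feature_info/cat_feature_info keyword arguments are replaced by a single feature_information tuple ordered (numerical, categorical, embedding). A sketch of the shape of that tuple — the dict layouts are illustrative assumptions, not taken from this commit:

    # Ordered (numerical, categorical, embedding), per the call above
    feature_information = (
        {"age": {"dimension": 1}},            # num_feature_info (layout assumed)
        {"color": {"categories": 2}},         # cat_feature_info (layout assumed)
        {"embedding_0": {"dimension": 768}},  # embedding_feature_info (layout assumed)
    )
    num_feature_info, cat_feature_info, embedding_feature_info = feature_information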
@@ -245,7 +281,9 @@ def get_number_of_params(self, requires_grad=True):
             If the model has not been built prior to calling this method.
         """
         if not self.built:
-            raise ValueError("The model must be built before the number of parameters can be estimated")
+            raise ValueError(
+                "The model must be built before the number of parameters can be estimated"
+            )
         else:
             if requires_grad:
                 return sum(p.numel() for p in self.task_model.parameters() if p.requires_grad)  # type: ignore
@@ -259,6 +297,8 @@ def fit(
         val_size: float = 0.2,
         X_val=None,
         y_val=None,
+        embeddings=None,
+        embeddings_val=None,
         max_epochs: int = 100,
         random_state: int = 101,
         batch_size: int = 128,
@@ -340,6 +380,8 @@ def fit(
             val_size=val_size,
             X_val=X_val,
             y_val=y_val,
+            embeddings=embeddings,
+            embeddings_val=embeddings_val,
             random_state=random_state,
             batch_size=batch_size,
             shuffle=shuffle,
@@ -390,7 +432,7 @@ def fit(

         return self

-    def predict(self, X, device=None):
+    def predict(self, X, embeddings=None, device=None):
         """Predicts target labels for the given input samples.

         Parameters
@@ -408,7 +450,7 @@ def predict(self, X, device=None):
             raise ValueError("The model or data module has not been fitted yet.")

         # Preprocess the data using the data module
-        self.data_module.assign_predict_dataset(X)
+        self.data_module.assign_predict_dataset(X, embeddings)

         # Set model to evaluation mode
         self.task_model.eval()
@@ -438,7 +480,7 @@ def predict(self, X, device=None):
         # Convert predictions to NumPy array and return
         return predictions.cpu().numpy()

-    def predict_proba(self, X, device=None):
+    def predict_proba(self, X, embeddings=None, device=None):
         """Predicts class probabilities for the given input samples.

         Parameters
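Both prediction paths now take the same optional embeddings argument, forwarded to the data module; rows of embeddings must align one-to-one with rows of X. Continuing the sketch above:

    X_new = X.iloc[:16]
    emb_new = embeddings[:16]

    labels = clf.predict(X_new, emb_new)
    proba = clf.predict_proba(X_new, emb_new)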
@@ -482,7 +524,7 @@ def predict_proba(self, X, device=None):
         # Convert probabilities to NumPy array and return
         return probabilities.cpu().numpy()

-    def evaluate(self, X, y_true, metrics=None):
+    def evaluate(self, X, y_true, embeddings=None, metrics=None):
         """Evaluate the model on the given data using specified metrics.

         Parameters
@@ -491,6 +533,8 @@ def evaluate(self, X, y_true, metrics=None):
             The input samples to predict.
         y_true : array-like of shape (n_samples,)
             The true class labels against which to evaluate the predictions.
+        embeddings : array-like or list of shape (n_samples, dimension)
+            List or array with embeddings for unstructured data inputs.
         metrics : dict
             A dictionary where keys are metric names and values are tuples containing the metric function
             and a boolean indicating whether the metric requires probability scores (True) or class labels (False).
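The metrics dict maps a display name to a (metric_func, use_proba) tuple, so label-based and probability-based metrics can be mixed in one evaluate call. For example, continuing the sketch above:

    from sklearn.metrics import accuracy_score, log_loss

    scores = clf.evaluate(
        X_new,
        y[:16],
        embeddings=emb_new,
        metrics={
            "Accuracy": (accuracy_score, False),  # False: computed on class labels
            "LogLoss": (log_loss, True),          # True: computed on probabilities
        },
    )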
@@ -518,11 +562,11 @@ def evaluate(self, X, y_true, metrics=None):

         # Generate class probabilities if any metric requires them
         if any(use_proba for _, use_proba in metrics.values()):
-            probabilities = self.predict_proba(X)
+            probabilities = self.predict_proba(X, embeddings)

         # Generate class labels if any metric requires them
         if any(not use_proba for _, use_proba in metrics.values()):
-            predictions = self.predict(X)
+            predictions = self.predict(X, embeddings)

         # Compute each metric
         for metric_name, (metric_func, use_proba) in metrics.items():
@@ -533,7 +577,7 @@ def evaluate(self, X, y_true, metrics=None):

         return scores

-    def score(self, X, y, metric=(log_loss, True)):
+    def score(self, X, y, embeddings=None, metric=(log_loss, True)):
         """Calculate the score of the model using the specified metric.

         Parameters
@@ -557,13 +601,13 @@ def score(self, X, y, metric=(log_loss, True)):
             X = pd.DataFrame(X)

         if use_proba:
-            probabilities = self.predict_proba(X)
+            probabilities = self.predict_proba(X, embeddings)
             return metric_func(y, probabilities)
         else:
-            predictions = self.predict(X)
+            predictions = self.predict(X, embeddings)
             return metric_func(y, predictions)

-    def encode(self, X, batch_size=64):
+    def encode(self, X, embeddings=None, batch_size=64):
         """
         Encodes input data using the trained model's embedding layer.

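score takes a single (metric_func, use_proba) pair instead of a dict and defaults to (log_loss, True). A one-line sketch, continuing from above:

    from sklearn.metrics import accuracy_score

    acc = clf.score(X_new, y[:16], embeddings=emb_new, metric=(accuracy_score, False))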
@@ -587,14 +631,16 @@ def encode(self, X, batch_size=64):
         # Ensure model and data module are initialized
         if self.task_model is None or self.data_module is None:
             raise ValueError("The model or data module has not been fitted yet.")
-        encoded_dataset = self.data_module.preprocess_new_data(X)
+        encoded_dataset = self.data_module.preprocess_new_data(X, embeddings)

         data_loader = DataLoader(encoded_dataset, batch_size=batch_size, shuffle=False)

         # Process data in batches
         encoded_outputs = []
-        for num_features, cat_features in tqdm(data_loader):
-            embeddings = self.task_model.base_model.encode(num_features, cat_features)  # Call your encode function
+        for batch in tqdm(data_loader):
+            embeddings = self.task_model.base_model.encode(
+                batch
+            )  # Call your encode function
             encoded_outputs.append(embeddings)

         # Concatenate all encoded outputs
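encode now iterates generic batches and hands each whole batch to base_model.encode, so embedding features travel with the numerical and categorical ones. Usage sketch; the output shape depends on the base model and is an assumption here:

    latent = clf.encode(X_new, embeddings=emb_new, batch_size=8)
    # e.g. a tensor of shape (16, d_model), concatenated across batches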
@@ -608,6 +654,8 @@ def optimize_hparams(
         y,
         X_val=None,
         y_val=None,
+        embeddings=None,
+        embeddings_val=None,
         time=100,
         max_epochs=200,
         prune_by_epoch=True,
@@ -658,13 +706,28 @@
         )

         # Initial model fitting to get the baseline validation loss
-        self.fit(X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs)
+        self.fit(
+            X,
+            y,
+            X_val=X_val,
+            y_val=y_val,
+            embeddings=embeddings,
+            embeddings_val=embeddings_val,
+            max_epochs=max_epochs,
+        )
         best_val_loss = float("inf")

         if X_val is not None and y_val is not None:
-            val_loss = self.evaluate(X_val, y_val, metrics={"Accuracy": (accuracy_score, False)})["Accuracy"]
+            val_loss = self.evaluate(
+                X_val,
+                y_val,
+                embeddings=embeddings_val,
+                metrics={"Accuracy": (accuracy_score, False)},
+            )["Accuracy"]
         else:
-            val_loss = self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"]
+            val_loss = self.trainer.validate(self.task_model, self.data_module)[0][
+                "val_loss"
+            ]

         best_val_loss = val_loss
         best_epoch_val_loss = self.task_model.epoch_val_loss_at(  # type: ignore
@@ -690,7 +750,9 @@ def _objective(hyperparams):
                     if param_value in activation_mapper:
                         setattr(self.config, key, activation_mapper[param_value])
                     else:
-                        raise ValueError(f"Unknown activation function: {param_value}")
+                        raise ValueError(
+                            f"Unknown activation function: {param_value}"
+                        )
                 else:
                     setattr(self.config, key, param_value)

@@ -699,11 +761,22 @@ def _objective(hyperparams):
                 self.config.head_layer_sizes = head_layer_sizes[:head_layer_size_length]

             # Build the model with updated hyperparameters
-            self.build_model(X, y, X_val=X_val, y_val=y_val, lr=self.config.lr, **optimize_kwargs)
+            self.build_model(
+                X,
+                y,
+                X_val=X_val,
+                y_val=y_val,
+                embeddings=embeddings,
+                embeddings_val=embeddings_val,
+                lr=self.config.lr,
+                **optimize_kwargs,
+            )

             # Dynamically set the early pruning threshold
             if prune_by_epoch:
-                early_pruning_threshold = best_epoch_val_loss * 1.5  # Prune based on specific epoch loss
+                early_pruning_threshold = (
+                    best_epoch_val_loss * 1.5
+                )  # Prune based on specific epoch loss
             else:
                 # Prune based on the best overall validation loss
                 early_pruning_threshold = best_val_loss * 1.5
@@ -715,15 +781,29 @@ def _objective(hyperparams):
             # Fit the model (limit epochs for faster optimization)
             try:
                 # Wrap the risky operation (model fitting) in a try-except block
-                self.fit(X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False)
+                self.fit(
+                    X,
+                    y,
+                    X_val=X_val,
+                    y_val=y_val,
+                    embeddings=embeddings,
+                    embeddings_val=embeddings_val,
+                    max_epochs=max_epochs,
+                    rebuild=False,
+                )

                 # Evaluate validation loss
                 if X_val is not None and y_val is not None:
-                    val_loss = self.evaluate(X_val, y_val, metrics={"Mean Squared Error": mean_squared_error})[  # type: ignore
-                        "Mean Squared Error"
-                    ]
+                    val_loss = self.evaluate(  # type: ignore
+                        X_val,
+                        y_val,
+                        embeddings=embeddings_val,
+                        metrics={"Accuracy": (accuracy_score, False)},
+                    )["Accuracy"]
                 else:
-                    val_loss = self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"]
+                    val_loss = self.trainer.validate(self.task_model, self.data_module)[
+                        0
+                    ]["val_loss"]

                 # Pruning based on validation loss at specific epoch
                 epoch_val_loss = self.task_model.epoch_val_loss_at(  # type: ignore
@@ -740,15 +817,21 @@ def _objective(hyperparams):

             except Exception as e:
                 # Penalize the hyperparameter configuration with a large value
-                print(f"Error encountered during fit with hyperparameters {hyperparams}: {e}")
-                return best_val_loss * 100  # Large value to discourage this configuration
+                print(
+                    f"Error encountered during fit with hyperparameters {hyperparams}: {e}"
+                )
+                return (
+                    best_val_loss * 100
+                )  # Large value to discourage this configuration

         # Perform Bayesian optimization using scikit-optimize
         result = gp_minimize(_objective, param_space, n_calls=time, random_state=42)

         # Update the model with the best-found hyperparameters
         best_hparams = result.x  # type: ignore
-        head_layer_sizes = [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None
+        head_layer_sizes = (
+            [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None
+        )
         layer_sizes = [] if "layer_sizes" in self.config.__dataclass_fields__ else None

         # Iterate over the best hyperparameters found by optimization
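optimize_hparams now forwards embeddings through both the baseline fit and every trial fit inside _objective, so the hyperparameter search sees the same inputs as a normal fit. A sketch with the arguments named in this diff (time is passed to gp_minimize as n_calls):

    clf.optimize_hparams(
        X,
        y,
        embeddings=embeddings,
        time=20,          # number of gp_minimize evaluations
        max_epochs=50,
        prune_by_epoch=True,
    )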
