Commit 40fef33

include encoding function in sklearn base classes
1 parent d08af31 commit 40fef33

3 files changed

Lines changed: 157 additions & 2 deletions


mambular/models/sklearn_base_classifier.py

Lines changed: 52 additions & 1 deletion
@@ -18,6 +18,8 @@
     get_search_space,
     round_to_nearest_16,
 )
+from tqdm import tqdm
+from torch.utils.data import DataLoader


 class SklearnBaseClassifier(BaseEstimator):
@@ -176,8 +178,12 @@ def build_model(
             Learning rate for the optimizer.
         lr_patience : int, default=10
             Number of epochs with no improvement on the validation loss to wait before reducing the learning rate.
-        factor : float, default=0.1
+        lr_factor : float, default=0.1
             Factor by which the learning rate will be reduced.
+        train_metrics : dict, default=None
+            torch.metrics dict to be logged during training.
+        val_metrics : dict, default=None
+            torch.metrics dict to be logged during validation.
         weight_decay : float, default=0.025
             Weight decay (L2 penalty) coefficient.
         dataloader_kwargs: dict, default={}
@@ -336,6 +342,10 @@ def fit(
             Weight decay (L2 penalty) coefficient.
         checkpoint_path : str, default="model_checkpoints"
             Path where the checkpoints are being saved.
+        train_metrics : dict, default=None
+            torch.metrics dict to be logged during training.
+        val_metrics : dict, default=None
+            torch.metrics dict to be logged during validation.
         dataloader_kwargs: dict, default={}
             The kwargs for the pytorch dataloader class.
         rebuild: bool, default=True
@@ -578,6 +588,47 @@ def score(self, X, y, metric=(log_loss, True)):
         predictions = self.predict(X)
         return metric_func(y, predictions)

+    def encode(self, X, batch_size=64):
+        """
+        Encodes input data using the trained model's embedding layer.
+
+        Parameters
+        ----------
+        X : array-like or DataFrame
+            Input data to be encoded.
+        batch_size : int, optional, default=64
+            Batch size for encoding.
+
+        Returns
+        -------
+        torch.Tensor
+            Encoded representations of the input data.
+
+        Raises
+        ------
+        ValueError
+            If the model or data module is not fitted.
+        """
+        # Ensure model and data module are initialized
+        if self.task_model is None or self.data_module is None:
+            raise ValueError("The model or data module has not been fitted yet.")
+        encoded_dataset = self.data_module.preprocess_new_data(X)
+
+        data_loader = DataLoader(encoded_dataset, batch_size=batch_size, shuffle=False)
+
+        # Process data in batches
+        encoded_outputs = []
+        for num_features, cat_features in tqdm(data_loader):
+            embeddings = self.task_model.base_model.encode(
+                num_features, cat_features
+            )  # Call your encode function
+            encoded_outputs.append(embeddings)
+
+        # Concatenate all encoded outputs
+        encoded_outputs = torch.cat(encoded_outputs, dim=0)
+
+        return encoded_outputs
+
     def optimize_hparams(
         self,
         X,
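The same `encode` method is added to all three wrappers below. As a usage illustration only, a minimal sketch of calling it on a fitted estimator; `MambularClassifier` is one concrete subclass of `SklearnBaseClassifier`, and the toy data, column names, and `max_epochs` value are assumptions, not part of this commit:

```python
import numpy as np
import pandas as pd
from mambular.models import MambularClassifier  # concrete subclass, used here for illustration

# Toy tabular data with numerical and categorical columns (made up for this sketch)
rng = np.random.default_rng(0)
X = pd.DataFrame(
    {
        "age": rng.integers(18, 80, size=256),
        "income": rng.normal(50_000, 10_000, size=256),
        "city": rng.choice(["a", "b", "c"], size=256),
    }
)
y = rng.integers(0, 2, size=256)

clf = MambularClassifier()
clf.fit(X, y, max_epochs=5)

# New in this commit: batch-wise embeddings from the fitted model
embeddings = clf.encode(X, batch_size=64)
print(embeddings.shape)  # torch.Tensor whose first dimension is n_samples
```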

mambular/models/sklearn_base_lss.py

Lines changed: 52 additions & 1 deletion
@@ -39,6 +39,8 @@
     Quantile,
     StudentTDistribution,
 )
+from tqdm import tqdm
+from torch.utils.data import DataLoader


 class SklearnBaseLSS(BaseEstimator):
@@ -198,8 +200,12 @@ def build_model(
             Learning rate for the optimizer.
         lr_patience : int, default=10
             Number of epochs with no improvement on the validation loss to wait before reducing the learning rate.
-        factor : float, default=0.1
+        lr_factor : float, default=0.1
             Factor by which the learning rate will be reduced.
+        train_metrics : dict, default=None
+            torch.metrics dict to be logged during training.
+        val_metrics : dict, default=None
+            torch.metrics dict to be logged during validation.
         weight_decay : float, default=0.025
             Weight decay (L2 penalty) coefficient.
         dataloader_kwargs: dict, default={}
@@ -361,6 +367,10 @@ def fit(
             Weight decay (L2 penalty) coefficient.
         distributional_kwargs : dict, default=None
             any arguments taht are specific for a certain distribution.
+        train_metrics : dict, default=None
+            torch.metrics dict to be logged during training.
+        val_metrics : dict, default=None
+            torch.metrics dict to be logged during validation.
         checkpoint_path : str, default="model_checkpoints"
             Path where the checkpoints are being saved.
         dataloader_kwargs: dict, default={}
@@ -596,6 +606,47 @@ def score(self, X, y, metric="NLL"):
         score = self.task_model.family.evaluate_nll(y, predictions)  # type: ignore
         return score

+    def encode(self, X, batch_size=64):
+        """
+        Encodes input data using the trained model's embedding layer.
+
+        Parameters
+        ----------
+        X : array-like or DataFrame
+            Input data to be encoded.
+        batch_size : int, optional, default=64
+            Batch size for encoding.
+
+        Returns
+        -------
+        torch.Tensor
+            Encoded representations of the input data.
+
+        Raises
+        ------
+        ValueError
+            If the model or data module is not fitted.
+        """
+        # Ensure model and data module are initialized
+        if self.task_model is None or self.data_module is None:
+            raise ValueError("The model or data module has not been fitted yet.")
+        encoded_dataset = self.data_module.preprocess_new_data(X)
+
+        data_loader = DataLoader(encoded_dataset, batch_size=batch_size, shuffle=False)
+
+        # Process data in batches
+        encoded_outputs = []
+        for num_features, cat_features in tqdm(data_loader):
+            embeddings = self.task_model.base_model.encode(
+                num_features, cat_features
+            )  # Call your encode function
+            encoded_outputs.append(embeddings)
+
+        # Concatenate all encoded outputs
+        encoded_outputs = torch.cat(encoded_outputs, dim=0)
+
+        return encoded_outputs
+
     def optimize_hparams(
         self,
         X,
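Note that `encode` runs the model without an explicit `torch.no_grad()`, so the returned tensor may still carry autograd history. Continuing the classifier sketch above, a hedged example of handing the embeddings to a classical scikit-learn estimator (the choice of `KMeans` is arbitrary):

```python
import torch
from sklearn.cluster import KMeans

# Reuse clf and X from the previous sketch
with torch.no_grad():  # avoid tracking gradients during pure inference
    embeddings = clf.encode(X, batch_size=128)

# Detach, move to CPU, and flatten in case a token dimension remains
features = embeddings.detach().cpu().numpy()
features = features.reshape(len(features), -1)

clusters = KMeans(n_clusters=5, n_init=10, random_state=0).fit_predict(features)
print(clusters[:10])
```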

mambular/models/sklearn_base_regressor.py

Lines changed: 53 additions & 0 deletions
@@ -16,6 +16,8 @@
     get_search_space,
     round_to_nearest_16,
 )
+from torch.utils.data import DataLoader
+from tqdm import tqdm


 class SklearnBaseRegressor(BaseEstimator):
@@ -178,6 +180,10 @@ def build_model(
             Factor by which the learning rate will be reduced.
         weight_decay : float, default=0.025
             Weight decay (L2 penalty) coefficient.
+        train_metrics : dict, default=None
+            torch.metrics dict to be logged during training.
+        val_metrics : dict, default=None
+            torch.metrics dict to be logged during validation.
         dataloader_kwargs: dict, default={}
             The kwargs for the pytorch dataloader class.

@@ -333,6 +339,12 @@ def fit(
             Path where the checkpoints are being saved.
         dataloader_kwargs: dict, default={}
             The kwargs for the pytorch dataloader class.
+        train_metrics : dict, default=None
+            torch.metrics dict to be logged during training.
+        val_metrics : dict, default=None
+            torch.metrics dict to be logged during validation.
+        rebuild: bool, default=True
+            Whether to rebuild the model when it already was built.
         **trainer_kwargs : Additional keyword arguments for PyTorch Lightning's Trainer class.


@@ -492,6 +504,47 @@ def score(self, X, y, metric=mean_squared_error):
         predictions = self.predict(X)
         return metric(y, predictions)

+    def encode(self, X, batch_size=64):
+        """
+        Encodes input data using the trained model's embedding layer.
+
+        Parameters
+        ----------
+        X : array-like or DataFrame
+            Input data to be encoded.
+        batch_size : int, optional, default=64
+            Batch size for encoding.
+
+        Returns
+        -------
+        torch.Tensor
+            Encoded representations of the input data.
+
+        Raises
+        ------
+        ValueError
+            If the model or data module is not fitted.
+        """
+        # Ensure model and data module are initialized
+        if self.task_model is None or self.data_module is None:
+            raise ValueError("The model or data module has not been fitted yet.")
+        encoded_dataset = self.data_module.preprocess_new_data(X)
+
+        data_loader = DataLoader(encoded_dataset, batch_size=batch_size, shuffle=False)
+
+        # Process data in batches
+        encoded_outputs = []
+        for num_features, cat_features in tqdm(data_loader):
+            embeddings = self.task_model.base_model.encode(
+                num_features, cat_features
+            )  # Call your encode function
+            encoded_outputs.append(embeddings)
+
+        # Concatenate all encoded outputs
+        encoded_outputs = torch.cat(encoded_outputs, dim=0)
+
+        return encoded_outputs
+
     def optimize_hparams(
         self,
         X,
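All three wrappers delegate to `self.task_model.base_model.encode(num_features, cat_features)`, so every base model is expected to expose such a hook. A hypothetical, minimal implementation of that contract, purely to illustrate the expected shapes; this is not code from the repository, and the real base models are structured differently:

```python
import torch
import torch.nn as nn


class ToyBaseModel(nn.Module):
    """Illustrative stand-in showing the encode() hook the sklearn wrappers call."""

    def __init__(self, n_num_features, cat_cardinalities, d_model=32):
        super().__init__()
        self.num_proj = nn.Linear(n_num_features, d_model)  # project numerical columns
        self.cat_embeddings = nn.ModuleList(
            nn.Embedding(card, d_model) for card in cat_cardinalities
        )

    def encode(self, num_features, cat_features):
        # Assumed shapes: num_features is a (batch, n_num_features) float tensor,
        # cat_features is a sequence of (batch,) integer-code tensors, one per column.
        tokens = [self.num_proj(num_features)]
        tokens += [emb(col) for emb, col in zip(self.cat_embeddings, cat_features)]
        # Mean-pool the per-feature embeddings into one vector per row: (batch, d_model)
        return torch.stack(tokens, dim=1).mean(dim=1)
```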
