|
18 | 18 | get_search_space, |
19 | 19 | round_to_nearest_16, |
20 | 20 | ) |
| 21 | +from tqdm import tqdm |
| 22 | +from torch.utils.data import DataLoader |
21 | 23 |
|
22 | 24 |
|
23 | 25 | class SklearnBaseClassifier(BaseEstimator): |
@@ -176,8 +178,12 @@ def build_model( |
176 | 178 | Learning rate for the optimizer. |
177 | 179 | lr_patience : int, default=10 |
178 | 180 | Number of epochs with no improvement on the validation loss to wait before reducing the learning rate. |
179 | | - factor : float, default=0.1 |
| 181 | + lr_factor : float, default=0.1 |
180 | 182 | Factor by which the learning rate will be reduced. |
| 183 | + train_metrics : dict, default=None |
| 184 | + Dict of torchmetrics metrics to be logged during training. |
| 185 | + val_metrics : dict, default=None |
| 186 | + Dict of torchmetrics metrics to be logged during validation. |
181 | 187 | weight_decay : float, default=0.025 |
182 | 188 | Weight decay (L2 penalty) coefficient. |
183 | 189 | dataloader_kwargs: dict, default={} |
@@ -336,6 +342,10 @@ def fit( |
336 | 342 | Weight decay (L2 penalty) coefficient. |
337 | 343 | checkpoint_path : str, default="model_checkpoints" |
338 | 344 | Path where the checkpoints are being saved. |
| 345 | + train_metrics : dict, default=None |
| 346 | + Dict of torchmetrics metrics to be logged during training. |
| 347 | + val_metrics : dict, default=None |
| 348 | + Dict of torchmetrics metrics to be logged during validation. |
339 | 349 | dataloader_kwargs: dict, default={} |
340 | 350 | The kwargs for the pytorch dataloader class. |
341 | 351 | rebuild: bool, default=True |
@@ -578,6 +588,47 @@ def score(self, X, y, metric=(log_loss, True)): |
578 | 588 | predictions = self.predict(X) |
579 | 589 | return metric_func(y, predictions) |
580 | 590 |
|
| 591 | + def encode(self, X, batch_size=64): |
| 592 | + """ |
| 593 | + Encodes input data using the trained model's embedding layer. |
| 594 | +
|
| 595 | + Parameters |
| 596 | + ---------- |
| 597 | + X : array-like or DataFrame |
| 598 | + Input data to be encoded. |
| 599 | + batch_size : int, optional, default=64 |
| 600 | + Batch size for encoding. |
| 601 | +
|
| 602 | + Returns |
| 603 | + ------- |
| 604 | + torch.Tensor |
| 605 | + Encoded representations of the input data. |
| 606 | +
|
| 607 | + Raises |
| 608 | + ------ |
| 609 | + ValueError |
| 610 | + If the model or data module is not fitted. |
| 611 | + """ |
| 612 | + # Ensure model and data module are initialized |
| 613 | + if self.task_model is None or self.data_module is None: |
| 614 | + raise ValueError("The model or data module has not been fitted yet.") |
| 615 | + encoded_dataset = self.data_module.preprocess_new_data(X) |
| 616 | + |
| 617 | + data_loader = DataLoader(encoded_dataset, batch_size=batch_size, shuffle=False) |
| 618 | + |
| 619 | + # Process data in batches |
| 620 | + encoded_outputs = [] |
| 621 | + for num_features, cat_features in tqdm(data_loader): |
| 622 | + embeddings = self.task_model.base_model.encode( |
| 623 | + num_features, cat_features |
| 624 | + ) # Call your encode function |
| 625 | + encoded_outputs.append(embeddings) |
| 626 | + |
| 627 | + # Concatenate all encoded outputs |
| 628 | + encoded_outputs = torch.cat(encoded_outputs, dim=0) |
| 629 | + |
| 630 | + return encoded_outputs |
| 631 | + |
581 | 632 | def optimize_hparams( |
582 | 633 | self, |
583 | 634 | X, |
|
0 commit comments