
Commit 2473e5c

chore: auto formatting
1 parent c4df541 commit 2473e5c

8 files changed

Lines changed: 117 additions & 318 deletions


mambular/base_models/lightning_wrapper.py

Lines changed: 6 additions & 17 deletions
@@ -1,4 +1,5 @@
 from collections.abc import Callable
+
 import lightning as pl
 import torch
 import torch.nn as nn
@@ -144,10 +145,7 @@ def compute_loss(self, predictions, y_true):
            )

         if getattr(self.base_model, "returns_ensemble", False):  # Ensemble case
-            if (
-                self.loss_fct.__class__.__name__ == "CrossEntropyLoss"
-                and predictions.dim() == 3
-            ):
+            if self.loss_fct.__class__.__name__ == "CrossEntropyLoss" and predictions.dim() == 3:
                 # Classification case with ensemble: predictions (N, E, k), y_true (N,)
                 N, E, k = predictions.shape
                 loss = 0.0
@@ -192,18 +190,14 @@ def training_step(self, batch, batch_idx):  # type: ignore

         # Check if the model has a `penalty_forward` method
         if hasattr(self.base_model, "penalty_forward"):
-            preds, penalty = self.base_model.penalty_forward(
-                num_features=num_features, cat_features=cat_features
-            )
+            preds, penalty = self.base_model.penalty_forward(num_features=num_features, cat_features=cat_features)
             loss = self.compute_loss(preds, labels) + penalty
         else:
             preds = self(num_features=num_features, cat_features=cat_features)
             loss = self.compute_loss(preds, labels)

         # Log the training loss
-        self.log(
-            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
-        )
+        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

         # Log custom training metrics
         for metric_name, metric_fn in self.train_metrics.items():
@@ -352,13 +346,8 @@ def on_validation_epoch_end(self):

         # Apply pruning logic if needed
         if self.current_epoch >= self.pruning_epoch:
-            if (
-                self.early_pruning_threshold is not None
-                and val_loss_value > self.early_pruning_threshold
-            ):
-                print(
-                    f"Pruned at epoch {self.current_epoch}, val_loss {val_loss_value}"
-                )
+            if self.early_pruning_threshold is not None and val_loss_value > self.early_pruning_threshold:
+                print(f"Pruned at epoch {self.current_epoch}, val_loss {val_loss_value}")
                 self.trainer.should_stop = True  # Stop training early

     def epoch_val_loss_at(self, epoch):
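
For context on the `compute_loss` hunk above: in the ensemble case the wrapper receives predictions shaped (N, E, k) together with a single label vector of shape (N,), and accumulates CrossEntropyLoss over the E ensemble members. A minimal stand-alone sketch of that pattern (the final averaging over E is an assumption; the exact reduction used by TaskModel is outside this hunk):

import torch
import torch.nn as nn

loss_fct = nn.CrossEntropyLoss()

N, E, k = 32, 4, 10                  # batch size, ensemble members, classes
predictions = torch.randn(N, E, k)   # per-member logits
y_true = torch.randint(0, k, (N,))   # one label per sample

# Accumulate the loss member by member, as in the ensemble branch above.
loss = 0.0
for e in range(E):
    loss = loss + loss_fct(predictions[:, e, :], y_true)
loss = loss / E                      # assumed mean over members
print(loss)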

mambular/data_utils/dataset.py

Lines changed: 2 additions & 11 deletions
@@ -3,11 +3,6 @@
 from torch.utils.data import Dataset


-import numpy as np
-import torch
-from torch.utils.data import Dataset
-
-
 class MambularDataset(Dataset):
     """Custom dataset for handling structured data with separate categorical and
     numerical features, tailored for both regression and classification tasks.
@@ -20,9 +15,7 @@ class MambularDataset(Dataset):
        regression (bool, optional): A flag indicating if the dataset is for a regression task. Defaults to True.
     """

-    def __init__(
-        self, cat_features_list, num_features_list, labels=None, regression=True
-    ):
+    def __init__(self, cat_features_list, num_features_list, labels=None, regression=True):
         self.cat_features_list = cat_features_list  # Categorical features tensors
         self.num_features_list = num_features_list  # Numerical features tensors
         self.regression = regression
@@ -56,9 +49,7 @@ def __getitem__(self, idx):
            tuple: A tuple containing two lists of tensors (one for categorical features and one for numerical features)
            and a single label (if available).
         """
-        cat_features = [
-            feature_tensor[idx] for feature_tensor in self.cat_features_list
-        ]
+        cat_features = [feature_tensor[idx] for feature_tensor in self.cat_features_list]
         num_features = [
             torch.as_tensor(feature_tensor[idx]).clone().detach().to(torch.float32)
             for feature_tensor in self.num_features_list
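
A rough usage sketch for the MambularDataset touched above, built on synthetic data; the feature shapes, dtypes, and label handling are illustrative assumptions, since this diff only shows the reformatted lines and not the full constructor:

import numpy as np
import torch

from mambular.data_utils.dataset import MambularDataset

n_samples = 100
cat_features_list = [torch.randint(0, 5, (n_samples,))]             # one categorical column (assumed encoding)
num_features_list = [np.random.randn(n_samples).astype("float32")]  # one numerical column
labels = torch.randint(0, 2, (n_samples,))                          # binary targets (assumed dtype)

dataset = MambularDataset(cat_features_list, num_features_list, labels=labels, regression=False)
print(len(dataset))
sample = dataset[0]  # per the docstring: categorical features, numerical features, and a label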

mambular/models/sklearn_base_classifier.py

Lines changed: 27 additions & 74 deletions
@@ -1,4 +1,5 @@
 import warnings
+from collections.abc import Callable
 from typing import Optional

 import lightning as pl
@@ -9,17 +10,13 @@
 from sklearn.base import BaseEstimator
 from sklearn.metrics import accuracy_score, log_loss, mean_squared_error
 from skopt import gp_minimize
-from collections.abc import Callable
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
 from ..base_models.lightning_wrapper import TaskModel
 from ..data_utils.datamodule import MambularDataModule
 from ..preprocessing import Preprocessor
-from ..utils.config_mapper import (
-    activation_mapper,
-    get_search_space,
-    round_to_nearest_16,
-)
-from tqdm import tqdm
-from torch.utils.data import DataLoader
+from ..utils.config_mapper import activation_mapper, get_search_space, round_to_nearest_16


 class SklearnBaseClassifier(BaseEstimator):
@@ -42,15 +39,11 @@ def __init__(self, model, config, **kwargs):
         ]

         self.config_kwargs = {
-            k: v
-            for k, v in kwargs.items()
-            if k not in self.preprocessor_arg_names and not k.startswith("optimizer")
+            k: v for k, v in kwargs.items() if k not in self.preprocessor_arg_names and not k.startswith("optimizer")
         }
         self.config = config(**self.config_kwargs)

-        preprocessor_kwargs = {
-            k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names
-        }
+        preprocessor_kwargs = {k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names}

         self.preprocessor = Preprocessor(**preprocessor_kwargs)
         self.task_model = None
@@ -70,8 +63,7 @@ def __init__(self, model, config, **kwargs):
         self.optimizer_kwargs = {
             k: v
             for k, v in kwargs.items()
-            if k
-            not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"]
+            if k not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"]
             and k.startswith("optimizer_")
         }

@@ -92,10 +84,7 @@ def get_params(self, deep=True):
         params.update(self.config_kwargs)

         if deep:
-            preprocessor_params = {
-                "prepro__" + key: value
-                for key, value in self.preprocessor.get_params().items()
-            }
+            preprocessor_params = {"prepro__" + key: value for key, value in self.preprocessor.get_params().items()}
             params.update(preprocessor_params)

         return params
@@ -113,14 +102,8 @@ def set_params(self, **parameters):
        self : object
            Estimator instance.
         """
-        config_params = {
-            k: v for k, v in parameters.items() if not k.startswith("prepro__")
-        }
-        preprocessor_params = {
-            k.split("__")[1]: v
-            for k, v in parameters.items()
-            if k.startswith("prepro__")
-        }
+        config_params = {k: v for k, v in parameters.items() if not k.startswith("prepro__")}
+        preprocessor_params = {k.split("__")[1]: v for k, v in parameters.items() if k.startswith("prepro__")}

         if config_params:
             self.config_kwargs.update(config_params)
@@ -218,9 +201,7 @@ def build_model(
            **dataloader_kwargs,
        )

-        self.data_module.preprocess_data(
-            X, y, X_val, y_val, val_size=val_size, random_state=random_state
-        )
+        self.data_module.preprocess_data(X, y, X_val, y_val, val_size=val_size, random_state=random_state)

         num_classes = len(np.unique(np.array(y)))

@@ -230,14 +211,10 @@ def build_model(
            config=self.config,
            cat_feature_info=self.data_module.cat_feature_info,
            num_feature_info=self.data_module.num_feature_info,
-            lr_patience=(
-                lr_patience if lr_patience is not None else self.config.lr_patience
-            ),
+            lr_patience=(lr_patience if lr_patience is not None else self.config.lr_patience),
            lr=lr if lr is not None else self.config.lr,
            lr_factor=lr_factor if lr_factor is not None else self.config.lr_factor,
-            weight_decay=(
-                weight_decay if weight_decay is not None else self.config.weight_decay
-            ),
+            weight_decay=(weight_decay if weight_decay is not None else self.config.weight_decay),
            train_metrics=train_metrics,
            val_metrics=val_metrics,
            optimizer_type=self.optimizer_type,
@@ -268,9 +245,7 @@ def get_number_of_params(self, requires_grad=True):
            If the model has not been built prior to calling this method.
         """
         if not self.built:
-            raise ValueError(
-                "The model must be built before the number of parameters can be estimated"
-            )
+            raise ValueError("The model must be built before the number of parameters can be estimated")
         else:
             if requires_grad:
                 return sum(p.numel() for p in self.task_model.parameters() if p.requires_grad)  # type: ignore
@@ -442,7 +417,7 @@ def predict(self, X, device=None):
         logits_list = self.trainer.predict(self.task_model, self.data_module)

         # Concatenate predictions from all batches
-        logits = torch.cat(logits_list, dim=0)
+        logits = torch.cat(logits_list, dim=0)  # type: ignore

         # Check if ensemble is used
         if getattr(self.base_model, "returns_ensemble", False):  # If using ensemble
@@ -619,9 +594,7 @@ def encode(self, X, batch_size=64):
         # Process data in batches
         encoded_outputs = []
         for num_features, cat_features in tqdm(data_loader):
-            embeddings = self.task_model.base_model.encode(
-                num_features, cat_features
-            )  # Call your encode function
+            embeddings = self.task_model.base_model.encode(num_features, cat_features)  # Call your encode function
             encoded_outputs.append(embeddings)

         # Concatenate all encoded outputs
@@ -689,13 +662,9 @@ def optimize_hparams(
         best_val_loss = float("inf")

         if X_val is not None and y_val is not None:
-            val_loss = self.evaluate(
-                X_val, y_val, metrics={"Accuracy": (accuracy_score, False)}
-            )["Accuracy"]
+            val_loss = self.evaluate(X_val, y_val, metrics={"Accuracy": (accuracy_score, False)})["Accuracy"]
         else:
-            val_loss = self.trainer.validate(self.task_model, self.data_module)[0][
-                "val_loss"
-            ]
+            val_loss = self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"]

         best_val_loss = val_loss
         best_epoch_val_loss = self.task_model.epoch_val_loss_at(  # type: ignore
@@ -721,9 +690,7 @@ def _objective(hyperparams):
                    if param_value in activation_mapper:
                        setattr(self.config, key, activation_mapper[param_value])
                    else:
-                        raise ValueError(
-                            f"Unknown activation function: {param_value}"
-                        )
+                        raise ValueError(f"Unknown activation function: {param_value}")
                else:
                    setattr(self.config, key, param_value)

@@ -732,15 +699,11 @@ def _objective(hyperparams):
                self.config.head_layer_sizes = head_layer_sizes[:head_layer_size_length]

            # Build the model with updated hyperparameters
-            self.build_model(
-                X, y, X_val=X_val, y_val=y_val, lr=self.config.lr, **optimize_kwargs
-            )
+            self.build_model(X, y, X_val=X_val, y_val=y_val, lr=self.config.lr, **optimize_kwargs)

            # Dynamically set the early pruning threshold
            if prune_by_epoch:
-                early_pruning_threshold = (
-                    best_epoch_val_loss * 1.5
-                )  # Prune based on specific epoch loss
+                early_pruning_threshold = best_epoch_val_loss * 1.5  # Prune based on specific epoch loss
            else:
                # Prune based on the best overall validation loss
                early_pruning_threshold = best_val_loss * 1.5
@@ -752,19 +715,15 @@ def _objective(hyperparams):
            # Fit the model (limit epochs for faster optimization)
            try:
                # Wrap the risky operation (model fitting) in a try-except block
-                self.fit(
-                    X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False
-                )
+                self.fit(X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False)

                # Evaluate validation loss
                if X_val is not None and y_val is not None:
                    val_loss = self.evaluate(X_val, y_val, metrics={"Mean Squared Error": mean_squared_error})[  # type: ignore
                        "Mean Squared Error"
                    ]
                else:
-                    val_loss = self.trainer.validate(self.task_model, self.data_module)[
-                        0
-                    ]["val_loss"]
+                    val_loss = self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"]

                # Pruning based on validation loss at specific epoch
                epoch_val_loss = self.task_model.epoch_val_loss_at(  # type: ignore
@@ -781,21 +740,15 @@ def _objective(hyperparams):

            except Exception as e:
                # Penalize the hyperparameter configuration with a large value
-                print(
-                    f"Error encountered during fit with hyperparameters {hyperparams}: {e}"
-                )
-                return (
-                    best_val_loss * 100
-                )  # Large value to discourage this configuration
+                print(f"Error encountered during fit with hyperparameters {hyperparams}: {e}")
+                return best_val_loss * 100  # Large value to discourage this configuration

         # Perform Bayesian optimization using scikit-optimize
         result = gp_minimize(_objective, param_space, n_calls=time, random_state=42)

         # Update the model with the best-found hyperparameters
         best_hparams = result.x  # type: ignore
-        head_layer_sizes = (
-            [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None
-        )
+        head_layer_sizes = [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None
         layer_sizes = [] if "layer_sizes" in self.config.__dataclass_fields__ else None

         # Iterate over the best hyperparameters found by optimization
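
The get_params/set_params hunks above follow scikit-learn's nested-parameter convention: keys prefixed with "prepro__" are routed to the preprocessor, everything else to the model config. A small self-contained illustration of that routing (the parameter names in the dict are hypothetical, not taken from mambular):

# Hypothetical parameter dict; only the "prepro__" routing logic is taken from the diff above.
parameters = {"lr": 1e-4, "prepro__n_bins": 64}

config_params = {k: v for k, v in parameters.items() if not k.startswith("prepro__")}
preprocessor_params = {k.split("__")[1]: v for k, v in parameters.items() if k.startswith("prepro__")}

print(config_params)        # {'lr': 0.0001}
print(preprocessor_params)  # {'n_bins': 64}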
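
The optimize_hparams hunks reformat an objective that gp_minimize calls repeatedly and that returns a large penalty when fitting fails, so a bad configuration is discouraged rather than aborting the search. A minimal sketch of that pattern with scikit-optimize, using a toy objective in place of model fitting (the search space and penalty value are placeholders, not mambular's):

from skopt import gp_minimize
from skopt.space import Real

def objective(params):
    (lr,) = params
    try:
        # Stand-in for "fit the model and return the validation loss".
        val_loss = (lr - 1e-3) ** 2
        return val_loss
    except Exception as e:
        print(f"Error encountered for {params}: {e}")
        return 1e6  # large value to discourage this configuration

result = gp_minimize(objective, [Real(1e-5, 1e-1, prior="log-uniform", name="lr")], n_calls=10, random_state=42)
print(result.x)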
