
Commit 4bbf174

committed: reformatting
1 parent fa2c978 commit 4bbf174

1 file changed: 26 additions & 68 deletions

File tree

mambular/models/sklearn_base_regressor.py

@@ -14,11 +14,7 @@
 from ..base_models.lightning_wrapper import TaskModel
 from ..data_utils.datamodule import MambularDataModule
 from ..preprocessing import Preprocessor
-from ..utils.config_mapper import (
-    activation_mapper,
-    get_search_space,
-    round_to_nearest_16,
-)
+from ..utils.config_mapper import activation_mapper, get_search_space, round_to_nearest_16


 class SklearnBaseRegressor(BaseEstimator):
@@ -42,15 +38,11 @@ def __init__(self, model, config, **kwargs):
         ]

         self.config_kwargs = {
-            k: v
-            for k, v in kwargs.items()
-            if k not in self.preprocessor_arg_names and not k.startswith("optimizer")
+            k: v for k, v in kwargs.items() if k not in self.preprocessor_arg_names and not k.startswith("optimizer")
         }
         self.config = config(**self.config_kwargs)

-        preprocessor_kwargs = {
-            k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names
-        }
+        preprocessor_kwargs = {k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names}

         self.preprocessor = Preprocessor(**preprocessor_kwargs)
         self.base_model = model
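
Note: the two comprehensions compressed above route the incoming keyword arguments either to the model config or to the preprocessor. A minimal sketch of that routing with hypothetical argument names (the real preprocessor_arg_names list is derived from the Preprocessor signature):

    # Hypothetical names, for illustration only
    preprocessor_arg_names = ["n_bins", "numerical_preprocessing"]
    kwargs = {"d_model": 64, "n_bins": 32, "optimizer_betas": (0.9, 0.999)}

    config_kwargs = {
        k: v for k, v in kwargs.items() if k not in preprocessor_arg_names and not k.startswith("optimizer")
    }
    preprocessor_kwargs = {k: v for k, v in kwargs.items() if k in preprocessor_arg_names}

    print(config_kwargs)        # {'d_model': 64}
    print(preprocessor_kwargs)  # {'n_bins': 32}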
@@ -70,8 +62,7 @@ def __init__(self, model, config, **kwargs):
         self.optimizer_kwargs = {
             k: v
             for k, v in kwargs.items()
-            if k
-            not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"]
+            if k not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"]
             and k.startswith("optimizer_")
         }

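The reformatted filter keeps only keys that start with "optimizer_" and are not one of the reserved training arguments. A quick sketch with made-up keyword arguments:

    kwargs = {
        "optimizer_betas": (0.9, 0.999),  # kept: has the prefix, not reserved
        "optimizer_type": "adamw",        # dropped: in the reserved list
        "lr": 1e-4,                       # dropped: reserved and no prefix
    }
    optimizer_kwargs = {
        k: v
        for k, v in kwargs.items()
        if k not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"]
        and k.startswith("optimizer_")
    }
    print(optimizer_kwargs)  # {'optimizer_betas': (0.9, 0.999)}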
@@ -92,10 +83,7 @@ def get_params(self, deep=True):
         params.update(self.config_kwargs)

         if deep:
-            preprocessor_params = {
-                "prepro__" + key: value
-                for key, value in self.preprocessor.get_params().items()
-            }
+            preprocessor_params = {"prepro__" + key: value for key, value in self.preprocessor.get_params().items()}
             params.update(preprocessor_params)

         return params
@@ -113,14 +101,8 @@ def set_params(self, **parameters):
         self : object
             Estimator instance.
         """
-        config_params = {
-            k: v for k, v in parameters.items() if not k.startswith("prepro__")
-        }
-        preprocessor_params = {
-            k.split("__")[1]: v
-            for k, v in parameters.items()
-            if k.startswith("prepro__")
-        }
+        config_params = {k: v for k, v in parameters.items() if not k.startswith("prepro__")}
+        preprocessor_params = {k.split("__")[1]: v for k, v in parameters.items() if k.startswith("prepro__")}

         if config_params:
             self.config_kwargs.update(config_params)
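
Both hunks follow scikit-learn's double-underscore convention for nested parameters: get_params prefixes preprocessor keys with "prepro__", and set_params strips the prefix again. A round-trip sketch with a hypothetical preprocessor key:

    # "n_bins" is an assumed preprocessor parameter; only the prefix handling mirrors the diff
    parameters = {"lr": 1e-4, "prepro__n_bins": 64}

    config_params = {k: v for k, v in parameters.items() if not k.startswith("prepro__")}
    preprocessor_params = {k.split("__")[1]: v for k, v in parameters.items() if k.startswith("prepro__")}

    print(config_params)        # {'lr': 0.0001}
    print(preprocessor_params)  # {'n_bins': 64}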
@@ -240,13 +222,9 @@ def build_model(
                 self.data_module.embedding_feature_info,
             ),
             lr=lr if lr is not None else self.config.lr,
-            lr_patience=(
-                lr_patience if lr_patience is not None else self.config.lr_patience
-            ),
+            lr_patience=(lr_patience if lr_patience is not None else self.config.lr_patience),
             lr_factor=lr_factor if lr_factor is not None else self.config.lr_factor,
-            weight_decay=(
-                weight_decay if weight_decay is not None else self.config.weight_decay
-            ),
+            weight_decay=(weight_decay if weight_decay is not None else self.config.weight_decay),
             train_metrics=train_metrics,
             val_metrics=val_metrics,
             optimizer_type=self.optimizer_type,
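
The parenthesized conditionals collapsed here all follow the same pattern: an explicitly passed argument wins, otherwise the value stored on the config is used. A small self-contained sketch with a stand-in config:

    from dataclasses import dataclass

    @dataclass
    class DummyConfig:  # stand-in for the real model config
        lr: float = 1e-4
        lr_patience: int = 10
        weight_decay: float = 1e-6

    config = DummyConfig()
    lr, lr_patience, weight_decay = None, 5, None  # only lr_patience passed explicitly

    effective_lr = lr if lr is not None else config.lr                                    # falls back to 1e-4
    effective_patience = lr_patience if lr_patience is not None else config.lr_patience   # keeps 5
    effective_wd = weight_decay if weight_decay is not None else config.weight_decay      # falls back to 1e-6
    print(effective_lr, effective_patience, effective_wd)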
@@ -277,9 +255,7 @@ def get_number_of_params(self, requires_grad=True):
             If the model has not been built prior to calling this method.
         """
         if not self.built:
-            raise ValueError(
-                "The model must be built before the number of parameters can be estimated"
-            )
+            raise ValueError("The model must be built before the number of parameters can be estimated")
         else:
             if requires_grad:
                 return sum(p.numel() for p in self.task_model.parameters() if p.requires_grad)  # type: ignore
@@ -456,7 +432,7 @@ def predict(self, X, embeddings=None, device=None):
         predictions_list = self.trainer.predict(self.task_model, self.data_module)

         # Concatenate predictions from all batches
-        predictions = torch.cat(predictions_list, dim=0)
+        predictions = torch.cat(predictions_list, dim=0)  # type: ignore

         # Check if ensemble is used
         if getattr(self.base_model, "returns_ensemble", False):  # If using ensemble
@@ -553,9 +529,7 @@ def encode(self, X, embeddings=None, batch_size=64):
         # Process data in batches
         encoded_outputs = []
         for batch in tqdm(data_loader):
-            embeddings = self.task_model.base_model.encode(
-                batch
-            )  # Call your encode function
+            embeddings = self.task_model.base_model.encode(batch)  # Call your encode function
             encoded_outputs.append(embeddings)

         # Concatenate all encoded outputs
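
encode collects one tensor of embeddings per batch and concatenates the list afterwards, the same collect-then-torch.cat pattern predict uses on the output of trainer.predict. A toy version with a dummy encoder:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    def dummy_encode(batch):  # stand-in for task_model.base_model.encode
        (x,) = batch
        return x * 2.0        # pretend these are embeddings

    data_loader = DataLoader(TensorDataset(torch.randn(10, 4)), batch_size=4)

    encoded_outputs = []
    for batch in data_loader:
        encoded_outputs.append(dummy_encode(batch))

    encodings = torch.cat(encoded_outputs, dim=0)
    print(encodings.shape)  # torch.Size([10, 4])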
@@ -633,13 +607,11 @@ def optimize_hparams(
         best_val_loss = float("inf")

         if X_val is not None and y_val is not None:
-            val_loss = self.evaluate(
-                X_val, y_val, metrics={"Mean Squared Error": mean_squared_error}
-            )["Mean Squared Error"]
-        else:
-            val_loss = self.trainer.validate(self.task_model, self.data_module)[0][
-                "val_loss"
+            val_loss = self.evaluate(X_val, y_val, metrics={"Mean Squared Error": mean_squared_error})[
+                "Mean Squared Error"
             ]
+        else:
+            val_loss = self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"]

         best_val_loss = val_loss
         best_epoch_val_loss = self.task_model.epoch_val_loss_at(  # type: ignore
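
The reflowed branch chooses the validation signal: mean squared error from evaluate when an explicit validation split is supplied, otherwise the val_loss logged by the Lightning trainer. The MSE half is plain scikit-learn:

    from sklearn.metrics import mean_squared_error

    y_val = [3.0, 2.5, 4.0]
    preds = [2.8, 2.7, 3.9]  # hypothetical model predictions
    print(mean_squared_error(y_val, preds))  # ~0.03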
@@ -665,9 +637,7 @@ def _objective(hyperparams):
                     if param_value in activation_mapper:
                         setattr(self.config, key, activation_mapper[param_value])
                     else:
-                        raise ValueError(
-                            f"Unknown activation function: {param_value}"
-                        )
+                        raise ValueError(f"Unknown activation function: {param_value}")
                 else:
                     setattr(self.config, key, param_value)

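activation_mapper translates a string drawn from the search space into an activation module and raises on unknown names. A sketch of such a mapping; the real one lives in ..utils.config_mapper and may contain different entries:

    import torch.nn as nn

    activation_mapper = {"relu": nn.ReLU(), "gelu": nn.GELU(), "tanh": nn.Tanh()}  # assumed contents

    param_value = "gelu"
    if param_value in activation_mapper:
        activation = activation_mapper[param_value]
    else:
        raise ValueError(f"Unknown activation function: {param_value}")
    print(activation)  # GELU(...)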
@@ -689,9 +659,7 @@ def _objective(hyperparams):

             # Dynamically set the early pruning threshold
             if prune_by_epoch:
-                early_pruning_threshold = (
-                    best_epoch_val_loss * 1.5
-                )  # Prune based on specific epoch loss
+                early_pruning_threshold = best_epoch_val_loss * 1.5  # Prune based on specific epoch loss
             else:
                 # Prune based on the best overall validation loss
                 early_pruning_threshold = best_val_loss * 1.5
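
The collapsed assignment sets the pruning threshold to 1.5x a reference loss, either the best loss recorded at the same epoch or the best overall validation loss:

    best_epoch_val_loss, best_val_loss = 0.42, 0.40
    prune_by_epoch = True

    early_pruning_threshold = best_epoch_val_loss * 1.5 if prune_by_epoch else best_val_loss * 1.5
    print(early_pruning_threshold)  # ~0.63; a trial whose loss exceeds this is abandoned early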
@@ -702,19 +670,15 @@ def _objective(hyperparams):

             try:
                 # Wrap the risky operation (model fitting) in a try-except block
-                self.fit(
-                    X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False
-                )
+                self.fit(X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False)

                 # Evaluate validation loss
                 if X_val is not None and y_val is not None:
-                    val_loss = self.evaluate(
-                        X_val, y_val, metrics={"Mean Squared Error": mean_squared_error}
-                    )["Mean Squared Error"]
+                    val_loss = self.evaluate(X_val, y_val, metrics={"Mean Squared Error": mean_squared_error})[
+                        "Mean Squared Error"
+                    ]
                 else:
-                    val_loss = self.trainer.validate(self.task_model, self.data_module)[
-                        0
-                    ]["val_loss"]
+                    val_loss = self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"]

                 # Pruning based on validation loss at specific epoch
                 epoch_val_loss = self.task_model.epoch_val_loss_at(  # type: ignore
@@ -731,21 +695,15 @@ def _objective(hyperparams):

             except Exception as e:
                 # Penalize the hyperparameter configuration with a large value
-                print(
-                    f"Error encountered during fit with hyperparameters {hyperparams}: {e}"
-                )
-                return (
-                    best_val_loss * 100
-                )  # Large value to discourage this configuration
+                print(f"Error encountered during fit with hyperparameters {hyperparams}: {e}")
+                return best_val_loss * 100  # Large value to discourage this configuration

         # Perform Bayesian optimization using scikit-optimize
         result = gp_minimize(_objective, param_space, n_calls=time, random_state=42)

         # Update the model with the best-found hyperparameters
         best_hparams = result.x  # type: ignore
-        head_layer_sizes = (
-            [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None
-        )
+        head_layer_sizes = [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None
         layer_sizes = [] if "layer_sizes" in self.config.__dataclass_fields__ else None

         # Iterate over the best hyperparameters found by optimization
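
Outside the hunk, _objective is handed to scikit-optimize's gp_minimize, which runs Bayesian optimization over param_space for a fixed number of calls. A self-contained toy run with the same call pattern, using an assumed two-dimensional search space and a cheap quadratic in place of model fitting:

    from skopt import gp_minimize
    from skopt.space import Integer, Real

    # Assumed search space; in the class it comes from get_search_space(self.config)
    param_space = [
        Real(1e-5, 1e-1, prior="log-uniform", name="lr"),
        Integer(16, 128, name="d_model"),
    ]

    def _objective(hyperparams):
        lr, d_model = hyperparams
        # Stand-in for fit + evaluate: minimised near lr=1e-3, d_model=64
        return (lr - 1e-3) ** 2 + ((d_model - 64) / 64) ** 2

    result = gp_minimize(_objective, param_space, n_calls=15, random_state=42)
    print(result.x, result.fun)  # best hyperparameters and their loss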
