
Commit ec989f1: rename to self.estimator

Parent: 6fb616e

5 files changed: 26 additions & 31 deletions


mambular/base_models/utils/pretraining.py

Lines changed: 6 additions & 6 deletions
@@ -20,8 +20,8 @@ def __init__(
         pool_sequence=True,
     ):
         super().__init__()
-        self.base_model = base_model
-        self.base_model.eval()
+        self.estimator = base_model
+        self.estimator.eval()
         self.k_neighbors = k_neighbors
         self.temperature = temperature
         self.lr = lr
@@ -33,9 +33,9 @@ def __init__(
         self.loss_fn = nn.CosineEmbeddingLoss(margin=margin, reduction="mean")
 
     def forward(self, x):
-        x = self.base_model.encode(x, grad=True)
+        x = self.estimator.encode(x, grad=True)
         if self.pool_sequence:
-            return self.base_model.pool_sequence(x)
+            return self.estimator.pool_sequence(x)
         return x  # Return unpooled sequence embeddings (N, S, D)
 
     def get_knn(self, labels):
@@ -140,7 +140,7 @@ def contrastive_loss(self, embeddings, knn_indices, neg_indices):
 
     def training_step(self, batch, batch_idx):
 
-        self.base_model.embedding_layer.train()
+        self.estimator.embedding_layer.train()
 
         data, labels = batch
         embeddings = self(data)
@@ -173,7 +173,7 @@ def validation_step(self, batch, batch_idx):
         return loss
 
     def configure_optimizers(self):
-        params = chain(self.base_model.parameters())
+        params = chain(self.estimator.parameters())
         return torch.optim.Adam(params, lr=self.lr)
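
For context on the train/eval juggling in this file: `eval()` only switches dropout/norm layers to inference behaviour, it does not stop gradients, which is why `training_step` can still fine-tune the embedding layer while `configure_optimizers` hands every `self.estimator` parameter to Adam. A minimal, self-contained sketch of that pattern (the toy module below is assumed for illustration, not from the repo):

import torch
import torch.nn as nn

class ToyBackbone(nn.Module):  # hypothetical stand-in for a mambular base model
    def __init__(self):
        super().__init__()
        self.embedding_layer = nn.Sequential(nn.Linear(8, 16), nn.Dropout(0.1))

estimator = ToyBackbone()
estimator.eval()                        # as in __init__: whole backbone in eval mode
estimator.embedding_layer.train()       # as in training_step: dropout active again
out = estimator.embedding_layer(torch.randn(2, 8))
out.sum().backward()                    # gradients still flow despite eval()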

mambular/models/utils/sklearn_base_classifier.py

Lines changed: 5 additions & 5 deletions
@@ -248,7 +248,7 @@ def predict(self, X, embeddings=None, device=None):
         logits = torch.cat(logits_list, dim=0)  # type: ignore
 
         # Check if ensemble is used
-        if getattr(self.base_model, "returns_ensemble", False):  # If using ensemble
+        if getattr(self.estimator, "returns_ensemble", False):  # If using ensemble
             logits = logits.mean(dim=1)  # Average over ensemble dimension
         if logits.dim() == 1:  # Ensure correct shape
             logits = logits.unsqueeze(1)
@@ -296,7 +296,7 @@ def predict_proba(self, X, embeddings=None, device=None):
         logits = torch.cat(logits_list, dim=0)
 
         # Check if ensemble is used
-        if getattr(self.base_model, "returns_ensemble", False):  # If using ensemble
+        if getattr(self.estimator, "returns_ensemble", False):  # If using ensemble
             logits = logits.mean(dim=1)  # Average over ensemble dimension
         if logits.dim() == 1:  # Ensure correct shape
             logits = logits.unsqueeze(1)
@@ -439,7 +439,7 @@ def pretrain(
         Notes
         -----
         - This function requires that `self.build_model()` has been called beforehand.
-        - The pretraining method uses `self.task_model.base_model.embedding_layer`.
+        - The pretraining method uses `self.task_model.estimator.embedding_layer`.
         - The method invokes `super()._pretrain()` with regression mode enabled.
 
         """
@@ -448,13 +448,13 @@ def pretrain(
                 "The model has not been built yet. Call model.build_model(**args) first."
             )
 
-        if not hasattr(self.task_model.base_model, "embedding_layer"):
+        if not hasattr(self.task_model.estimator, "embedding_layer"):
             raise ValueError("The model does not have an embedding layer")
 
         self.data_module.setup("fit")
 
         super()._pretrain(
-            self.task_model.base_model,
+            self.task_model.estimator,
             self.data_module,
             pretrain_epochs=pretrain_epochs,
             k_neighbors=k_neighbors,
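
The `returns_ensemble` branch touched in both `predict` and `predict_proba` averages per-member logits before any thresholding. A toy illustration of the shapes involved (the dummy class and sizes are assumed for the example, not part of the library):

import torch

class DummyEstimator:              # hypothetical: only carries the flag checked above
    returns_ensemble = True

estimator = DummyEstimator()
logits = torch.randn(4, 8, 3)      # (batch, ensemble members, classes)
if getattr(estimator, "returns_ensemble", False):  # If using ensemble
    logits = logits.mean(dim=1)    # Average over ensemble dimension -> (4, 3)
if logits.dim() == 1:              # Ensure correct shape
    logits = logits.unsqueeze(1)
print(logits.shape)                # torch.Size([4, 3])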

mambular/models/utils/sklearn_base_lss.py

Lines changed: 6 additions & 11 deletions
@@ -9,18 +9,13 @@
 from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, ModelSummary
 from sklearn.base import BaseEstimator
 from sklearn.metrics import accuracy_score, mean_squared_error
-from skopt import gp_minimize
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
 from ...base_models.utils.lightning_wrapper import TaskModel
 from ...data_utils.datamodule import MambularDataModule
 from ...preprocessing import Preprocessor
-from ...utils.config_mapper import (
-    activation_mapper,
-    get_search_space,
-    round_to_nearest_16,
-)
+
 from ...utils.distributional_metrics import (
     beta_brier_score,
     dirichlet_error,
@@ -78,7 +73,7 @@ def __init__(self, model, config, **kwargs):
 
         self.preprocessor = Preprocessor(**preprocessor_kwargs)
         self.task_model = None
-        self.base_model = model
+        self.estimator = model
         self.built = False
 
         # Raise a warning if task is set to 'classification'
@@ -246,7 +241,7 @@ def build_model(
         )
 
         self.task_model = TaskModel(
-            model_class=self.base_model,  # type: ignore
+            model_class=self.estimator,  # type: ignore
             num_classes=self.family.param_count,
             family=self.family,
             config=self.config,
@@ -268,7 +263,7 @@ def build_model(
         )
 
         self.built = True
-        self.base_model = self.task_model.base_model
+        self.estimator = self.task_model.estimator
 
         return self
 
@@ -497,7 +492,7 @@ def predict(self, X, raw=False, device=None):
         predictions = torch.cat(predictions_list, dim=0)
 
         # Check if ensemble is used
-        if getattr(self.base_model, "returns_ensemble", False):  # If using ensemble
+        if getattr(self.estimator, "returns_ensemble", False):  # If using ensemble
             predictions = predictions.mean(dim=1)  # Average over ensemble dimension
 
         if not raw:
@@ -642,7 +637,7 @@ def encode(self, X, batch_size=64):
         # Process data in batches
         encoded_outputs = []
         for num_features, cat_features in tqdm(data_loader):
-            embeddings = self.task_model.base_model.encode(
+            embeddings = self.task_model.estimator.encode(
                 num_features, cat_features
             )  # Call your encode function
             encoded_outputs.append(embeddings)
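
Besides the rename, this file also drops the unused `skopt` and `config_mapper` imports. The `encode` loop it touches collects per-batch embeddings and concatenates them only at the end; illustrative shapes only (the tensors below are made up):

import torch

encoded_outputs = [torch.randn(64, 32) for _ in range(3)]  # three batches of (64, 32) embeddings
embeddings = torch.cat(encoded_outputs, dim=0)             # stitched back to (192, 32)
print(embeddings.shape)                                    # torch.Size([192, 32])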

mambular/models/utils/sklearn_base_regressor.py

Lines changed: 4 additions & 4 deletions
@@ -249,7 +249,7 @@ def predict(self, X, embeddings=None, device=None):
 
         # Check if ensemble is used
         if getattr(
-            self.task_model.base_model, "returns_ensemble", False
+            self.task_model.estimator, "returns_ensemble", False
         ):  # If using ensemble
             predictions = predictions.mean(dim=1)  # Average over ensemble dimension
 
@@ -360,7 +360,7 @@ def pretrain(
         Notes
         -----
         - This function requires that `self.build_model()` has been called beforehand.
-        - The pretraining method uses `self.task_model.base_model.embedding_layer`.
+        - The pretraining method uses `self.task_model.estimator.embedding_layer`.
         - The method invokes `super()._pretrain()` with regression mode enabled.
 
         """
@@ -369,13 +369,13 @@ def pretrain(
                 "The model has not been built yet. Call model.build_model(**args) first."
             )
 
-        if not hasattr(self.task_model.base_model, "embedding_layer"):
+        if not hasattr(self.task_model.estimator, "embedding_layer"):
             raise ValueError("The model does not have an embedding layer")
 
         self.data_module.setup("fit")
 
         super()._pretrain(
-            self.task_model.base_model,
+            self.task_model.estimator,
             self.data_module,
             pretrain_epochs=pretrain_epochs,
             k_neighbors=k_neighbors,
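
The classifier and regressor `pretrain` overrides repeat the same two guards. Sketched here as a free function to make the contract explicit (the helper name is ours, not part of the library):

def check_pretrain_ready(wrapper):
    """Raise early if pretraining cannot run on this sklearn-style wrapper."""
    if not wrapper.built:
        raise ValueError(
            "The model has not been built yet. Call model.build_model(**args) first."
        )
    if not hasattr(wrapper.task_model.estimator, "embedding_layer"):
        raise ValueError("The model does not have an embedding layer")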

mambular/models/utils/sklearn_parent.py

Lines changed: 5 additions & 5 deletions
@@ -54,7 +54,7 @@ def __init__(self, model, config, **kwargs):
         }
 
         self.preprocessor = Preprocessor(**self.preprocessor_kwargs)
-        self.base_model = model
+        self.estimator = model
         self.task_model = None
         self.built = False
 
@@ -208,7 +208,7 @@ def _build_model(
         )
 
         self.task_model = TaskModel(
-            model_class=self.base_model,  # type: ignore
+            model_class=self.estimator,  # type: ignore
             config=self.config,
             feature_information=(
                 self.data_module.num_feature_info,
@@ -230,7 +230,7 @@ def _build_model(
         )
 
         self.built = True
-        self.base_model = self.task_model.base_model
+        self.estimator = self.task_model.estimator
 
         return self
 
@@ -399,7 +399,7 @@ def fit(
             **trainer_kwargs,
         )
         self.task_model.train()
-        self.task_model.base_model.train()
+        self.task_model.estimator.train()
         self.trainer.fit(self.task_model, self.data_module)  # type: ignore
 
         self.best_model_path = checkpoint_callback.best_model_path
@@ -458,7 +458,7 @@ def encode(self, X, embeddings=None, batch_size=64):
         # Process data in batches
         encoded_outputs = []
         for batch in tqdm(data_loader):
-            embeddings = self.task_model.base_model.encode(
+            embeddings = self.task_model.estimator.encode(
                 batch
             )  # Call your encode function
            encoded_outputs.append(embeddings)
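
Renaming a public attribute can break downstream code that still reads `wrapper.base_model`. One way to soften that, sketched as a suggestion rather than anything this commit does, is a deprecated read-only alias on the parent class (the class name below is a stand-in):

import warnings

class SklearnBaseModule:           # hypothetical stand-in for the parent class here
    def __init__(self, model):
        self.estimator = model

    @property
    def base_model(self):
        warnings.warn(
            "`base_model` has been renamed to `estimator`.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.estimator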
