Skip to content

Commit 6fb616e

Browse files
committed
rename base_model to self.estimator
1 parent 180f126 commit 6fb616e

1 file changed

Lines changed: 11 additions & 11 deletions

File tree

mambular/base_models/utils/lightning_wrapper.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def __init__(
8989
else:
9090
output_dim = num_classes
9191

92-
self.base_model = model_class(
92+
self.estimator = model_class(
9393
config=config,
9494
feature_information=feature_information,
9595
num_classes=output_dim,
@@ -112,7 +112,7 @@ def forward(self, num_features, cat_features, embeddings):
112112
Model output.
113113
"""
114114

115-
return self.base_model.forward(num_features, cat_features, embeddings)
115+
return self.estimator.forward(num_features, cat_features, embeddings)
116116

117117
def compute_loss(self, predictions, y_true):
118118
"""Compute the loss for the given predictions and true labels.
@@ -130,7 +130,7 @@ def compute_loss(self, predictions, y_true):
130130
Computed loss.
131131
"""
132132
if self.lss:
133-
if getattr(self.base_model, "returns_ensemble", False):
133+
if getattr(self.estimator, "returns_ensemble", False):
134134
loss = 0.0
135135
for ensemble_member in range(predictions.shape[1]):
136136
loss += self.family.compute_loss( # type: ignore
@@ -143,7 +143,7 @@ def compute_loss(self, predictions, y_true):
143143
y_true.squeeze(-1),
144144
)
145145

146-
if getattr(self.base_model, "returns_ensemble", False): # Ensemble case
146+
if getattr(self.estimator, "returns_ensemble", False): # Ensemble case
147147
if (
148148
self.loss_fct.__class__.__name__ == "CrossEntropyLoss"
149149
and predictions.dim() == 3
@@ -191,8 +191,8 @@ def training_step(self, batch, batch_idx): # type: ignore
191191
data, labels = batch
192192

193193
# Check if the model has a `penalty_forward` method
194-
if hasattr(self.base_model, "penalty_forward"):
195-
preds, penalty = self.base_model.penalty_forward(*data)
194+
if hasattr(self.estimator, "penalty_forward"):
195+
preds, penalty = self.estimator.penalty_forward(*data)
196196
loss = self.compute_loss(preds, labels) + penalty
197197
else:
198198
preds = self(*data)
@@ -396,7 +396,7 @@ def configure_optimizers(self): # type: ignore
396396

397397
# Initialize the optimizer with the chosen class and parameters
398398
optimizer = optimizer_class(
399-
self.base_model.parameters(),
399+
self.estimator.parameters(),
400400
lr=self.lr,
401401
weight_decay=self.weight_decay,
402402
**self.optimizer_params, # Pass any additional optimizer-specific parameters
@@ -443,9 +443,9 @@ def pretrain_embeddings(
443443
Path to save the pretrained embeddings.
444444
"""
445445
print("🚀 Pretraining embeddings...")
446-
self.base_model.train()
446+
self.estimator.train()
447447

448-
optimizer = torch.optim.Adam(self.base_model.embedding_parameters(), lr=lr)
448+
optimizer = torch.optim.Adam(self.estimator.embedding_parameters(), lr=lr)
449449

450450
# 🔥 Single tqdm progress bar across all epochs and batches
451451
total_batches = pretrain_epochs * len(train_dataloader)
@@ -459,7 +459,7 @@ def pretrain_embeddings(
459459
optimizer.zero_grad()
460460

461461
# Forward pass through embeddings only
462-
embeddings = self.base_model.encode(data, grad=True)
462+
embeddings = self.estimator.encode(data, grad=True)
463463

464464
# Compute nearest neighbors based on task type
465465
knn_indices = self.get_knn(labels, k_neighbors, regression)
@@ -481,7 +481,7 @@ def pretrain_embeddings(
481481
progress_bar.close()
482482

483483
# Save pretrained embeddings
484-
torch.save(self.base_model.get_embedding_state_dict(), save_path)
484+
torch.save(self.estimator.get_embedding_state_dict(), save_path)
485485
print(f"✅ Embeddings saved to {save_path}")
486486

487487
def get_knn(self, labels, k_neighbors=5, regression=True, device=""):

0 commit comments

Comments
 (0)