@@ -89,7 +89,7 @@ def __init__(
8989 else :
9090 output_dim = num_classes
9191
92- self .base_model = model_class (
92+ self .estimator = model_class (
9393 config = config ,
9494 feature_information = feature_information ,
9595 num_classes = output_dim ,
@@ -112,7 +112,7 @@ def forward(self, num_features, cat_features, embeddings):
112112 Model output.
113113 """
114114
115- return self .base_model .forward (num_features , cat_features , embeddings )
115+ return self .estimator .forward (num_features , cat_features , embeddings )
116116
117117 def compute_loss (self , predictions , y_true ):
118118 """Compute the loss for the given predictions and true labels.
@@ -130,7 +130,7 @@ def compute_loss(self, predictions, y_true):
130130 Computed loss.
131131 """
132132 if self .lss :
133- if getattr (self .base_model , "returns_ensemble" , False ):
133+ if getattr (self .estimator , "returns_ensemble" , False ):
134134 loss = 0.0
135135 for ensemble_member in range (predictions .shape [1 ]):
136136 loss += self .family .compute_loss ( # type: ignore
@@ -143,7 +143,7 @@ def compute_loss(self, predictions, y_true):
143143 y_true .squeeze (- 1 ),
144144 )
145145
146- if getattr (self .base_model , "returns_ensemble" , False ): # Ensemble case
146+ if getattr (self .estimator , "returns_ensemble" , False ): # Ensemble case
147147 if (
148148 self .loss_fct .__class__ .__name__ == "CrossEntropyLoss"
149149 and predictions .dim () == 3
@@ -191,8 +191,8 @@ def training_step(self, batch, batch_idx): # type: ignore
191191 data , labels = batch
192192
193193 # Check if the model has a `penalty_forward` method
194- if hasattr (self .base_model , "penalty_forward" ):
195- preds , penalty = self .base_model .penalty_forward (* data )
194+ if hasattr (self .estimator , "penalty_forward" ):
195+ preds , penalty = self .estimator .penalty_forward (* data )
196196 loss = self .compute_loss (preds , labels ) + penalty
197197 else :
198198 preds = self (* data )
@@ -396,7 +396,7 @@ def configure_optimizers(self): # type: ignore
396396
397397 # Initialize the optimizer with the chosen class and parameters
398398 optimizer = optimizer_class (
399- self .base_model .parameters (),
399+ self .estimator .parameters (),
400400 lr = self .lr ,
401401 weight_decay = self .weight_decay ,
402402 ** self .optimizer_params , # Pass any additional optimizer-specific parameters
@@ -443,9 +443,9 @@ def pretrain_embeddings(
443443 Path to save the pretrained embeddings.
444444 """
445445 print ("🚀 Pretraining embeddings..." )
446- self .base_model .train ()
446+ self .estimator .train ()
447447
448- optimizer = torch .optim .Adam (self .base_model .embedding_parameters (), lr = lr )
448+ optimizer = torch .optim .Adam (self .estimator .embedding_parameters (), lr = lr )
449449
450450 # 🔥 Single tqdm progress bar across all epochs and batches
451451 total_batches = pretrain_epochs * len (train_dataloader )
@@ -459,7 +459,7 @@ def pretrain_embeddings(
459459 optimizer .zero_grad ()
460460
461461 # Forward pass through embeddings only
462- embeddings = self .base_model .encode (data , grad = True )
462+ embeddings = self .estimator .encode (data , grad = True )
463463
464464 # Compute nearest neighbors based on task type
465465 knn_indices = self .get_knn (labels , k_neighbors , regression )
@@ -481,7 +481,7 @@ def pretrain_embeddings(
481481 progress_bar .close ()
482482
483483 # Save pretrained embeddings
484- torch .save (self .base_model .get_embedding_state_dict (), save_path )
484+ torch .save (self .estimator .get_embedding_state_dict (), save_path )
485485 print (f"✅ Embeddings saved to { save_path } " )
486486
487487 def get_knn (self , labels , k_neighbors = 5 , regression = True , device = "" ):
0 commit comments