@@ -37,7 +37,7 @@ def __init__(self, model, config, **kwargs):
3737 }
3838
3939 self .preprocessor = Preprocessor (** preprocessor_kwargs )
40- self .model = None
40+ self .task_model = None
4141
4242 # Raise a warning if task is set to 'classification'
4343 if preprocessor_kwargs .get ("task" ) == "regression" :
@@ -194,7 +194,7 @@ def build_model(
194194
195195 num_classes = len (np .unique (y ))
196196
197- self .model = TaskModel (
197+ self .task_model = TaskModel (
198198 model_class = self .base_model ,
199199 num_classes = num_classes ,
200200 config = self .config ,
@@ -237,10 +237,10 @@ def get_number_of_params(self, requires_grad=True):
237237 else :
238238 if requires_grad :
239239 return sum (
240- p .numel () for p in self .model .parameters () if p .requires_grad
240+ p .numel () for p in self .task_model .parameters () if p .requires_grad
241241 )
242242 else :
243- return sum (p .numel () for p in self .model .parameters ())
243+ return sum (p .numel () for p in self .task_model .parameters ())
244244
245245 def fit (
246246 self ,
@@ -345,7 +345,7 @@ def fit(
345345
346346 num_classes = len (np .unique (y ))
347347
348- self .model = TaskModel (
348+ self .task_model = TaskModel (
349349 model_class = self .base_model ,
350350 num_classes = num_classes ,
351351 config = self .config ,
@@ -379,12 +379,12 @@ def fit(
379379 ],
380380 ** trainer_kwargs
381381 )
382- self .trainer .fit (self .model , self .data_module )
382+ self .trainer .fit (self .task_model , self .data_module )
383383
384384 best_model_path = checkpoint_callback .best_model_path
385385 if best_model_path :
386386 checkpoint = torch .load (best_model_path )
387- self .model .load_state_dict (checkpoint ["state_dict" ])
387+ self .task_model .load_state_dict (checkpoint ["state_dict" ])
388388
389389 return self
390390
@@ -404,14 +404,14 @@ def predict(self, X):
404404 The predicted target values.
405405 """
406406 # Ensure model and data module are initialized
407- if self .model is None or self .data_module is None :
407+ if self .task_model is None or self .data_module is None :
408408 raise ValueError ("The model or data module has not been fitted yet." )
409409
410410 # Preprocess the data using the data module
411411 cat_tensors , num_tensors = self .data_module .preprocess_test_data (X )
412412
413413 # Move tensors to appropriate device
414- device = next (self .model .parameters ()).device
414+ device = next (self .task_model .parameters ()).device
415415 if isinstance (cat_tensors , list ):
416416 cat_tensors = [tensor .to (device ) for tensor in cat_tensors ]
417417 else :
@@ -423,11 +423,11 @@ def predict(self, X):
423423 num_tensors = num_tensors .to (device )
424424
425425 # Set model to evaluation mode
426- self .model .eval ()
426+ self .task_model .eval ()
427427
428428 # Perform inference
429429 with torch .no_grad ():
430- logits = self .model (num_features = num_tensors , cat_features = cat_tensors )
430+ logits = self .task_model (num_features = num_tensors , cat_features = cat_tensors )
431431
432432 # Check the shape of the logits to determine binary or multi-class classification
433433 if logits .shape [1 ] == 1 :
@@ -484,7 +484,7 @@ def predict_proba(self, X):
484484 # Preprocess the data
485485 if not isinstance (X , pd .DataFrame ):
486486 X = pd .DataFrame (X )
487- device = next (self .model .parameters ()).device
487+ device = next (self .task_model .parameters ()).device
488488 cat_tensors , num_tensors = self .data_module .preprocess_test_data (X )
489489 if isinstance (cat_tensors , list ):
490490 cat_tensors = [tensor .to (device ) for tensor in cat_tensors ]
@@ -497,11 +497,11 @@ def predict_proba(self, X):
497497 num_tensors = num_tensors .to (device )
498498
499499 # Set the model to evaluation mode
500- self .model .eval ()
500+ self .task_model .eval ()
501501
502502 # Perform inference
503503 with torch .no_grad ():
504- logits = self .model (num_features = num_tensors , cat_features = cat_tensors )
504+ logits = self .task_model (num_features = num_tensors , cat_features = cat_tensors )
505505 if logits .shape [1 ] > 1 :
506506 probabilities = torch .softmax (logits , dim = 1 )
507507 else :
0 commit comments