Skip to content

Commit 71cc68e

Browse files
committed
renaming sklearn class attributes
1 parent 07164f5 commit 71cc68e

3 files changed

Lines changed: 40 additions & 38 deletions

File tree

mambular/models/sklearn_base_classifier.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(self, model, config, **kwargs):
3737
}
3838

3939
self.preprocessor = Preprocessor(**preprocessor_kwargs)
40-
self.model = None
40+
self.task_model = None
4141

4242
# Raise a warning if task is set to 'classification'
4343
if preprocessor_kwargs.get("task") == "regression":
@@ -194,7 +194,7 @@ def build_model(
194194

195195
num_classes = len(np.unique(y))
196196

197-
self.model = TaskModel(
197+
self.task_model = TaskModel(
198198
model_class=self.base_model,
199199
num_classes=num_classes,
200200
config=self.config,
@@ -237,10 +237,10 @@ def get_number_of_params(self, requires_grad=True):
237237
else:
238238
if requires_grad:
239239
return sum(
240-
p.numel() for p in self.model.parameters() if p.requires_grad
240+
p.numel() for p in self.task_model.parameters() if p.requires_grad
241241
)
242242
else:
243-
return sum(p.numel() for p in self.model.parameters())
243+
return sum(p.numel() for p in self.task_model.parameters())
244244

245245
def fit(
246246
self,
@@ -345,7 +345,7 @@ def fit(
345345

346346
num_classes = len(np.unique(y))
347347

348-
self.model = TaskModel(
348+
self.task_model = TaskModel(
349349
model_class=self.base_model,
350350
num_classes=num_classes,
351351
config=self.config,
@@ -379,12 +379,12 @@ def fit(
379379
],
380380
**trainer_kwargs
381381
)
382-
self.trainer.fit(self.model, self.data_module)
382+
self.trainer.fit(self.task_model, self.data_module)
383383

384384
best_model_path = checkpoint_callback.best_model_path
385385
if best_model_path:
386386
checkpoint = torch.load(best_model_path)
387-
self.model.load_state_dict(checkpoint["state_dict"])
387+
self.task_model.load_state_dict(checkpoint["state_dict"])
388388

389389
return self
390390

@@ -404,14 +404,14 @@ def predict(self, X):
404404
The predicted target values.
405405
"""
406406
# Ensure model and data module are initialized
407-
if self.model is None or self.data_module is None:
407+
if self.task_model is None or self.data_module is None:
408408
raise ValueError("The model or data module has not been fitted yet.")
409409

410410
# Preprocess the data using the data module
411411
cat_tensors, num_tensors = self.data_module.preprocess_test_data(X)
412412

413413
# Move tensors to appropriate device
414-
device = next(self.model.parameters()).device
414+
device = next(self.task_model.parameters()).device
415415
if isinstance(cat_tensors, list):
416416
cat_tensors = [tensor.to(device) for tensor in cat_tensors]
417417
else:
@@ -423,11 +423,11 @@ def predict(self, X):
423423
num_tensors = num_tensors.to(device)
424424

425425
# Set model to evaluation mode
426-
self.model.eval()
426+
self.task_model.eval()
427427

428428
# Perform inference
429429
with torch.no_grad():
430-
logits = self.model(num_features=num_tensors, cat_features=cat_tensors)
430+
logits = self.task_model(num_features=num_tensors, cat_features=cat_tensors)
431431

432432
# Check the shape of the logits to determine binary or multi-class classification
433433
if logits.shape[1] == 1:
@@ -484,7 +484,7 @@ def predict_proba(self, X):
484484
# Preprocess the data
485485
if not isinstance(X, pd.DataFrame):
486486
X = pd.DataFrame(X)
487-
device = next(self.model.parameters()).device
487+
device = next(self.task_model.parameters()).device
488488
cat_tensors, num_tensors = self.data_module.preprocess_test_data(X)
489489
if isinstance(cat_tensors, list):
490490
cat_tensors = [tensor.to(device) for tensor in cat_tensors]
@@ -497,11 +497,11 @@ def predict_proba(self, X):
497497
num_tensors = num_tensors.to(device)
498498

499499
# Set the model to evaluation mode
500-
self.model.eval()
500+
self.task_model.eval()
501501

502502
# Perform inference
503503
with torch.no_grad():
504-
logits = self.model(num_features=num_tensors, cat_features=cat_tensors)
504+
logits = self.task_model(num_features=num_tensors, cat_features=cat_tensors)
505505
if logits.shape[1] > 1:
506506
probabilities = torch.softmax(logits, dim=1)
507507
else:

mambular/models/sklearn_base_lss.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __init__(self, model, config, **kwargs):
5858
}
5959

6060
self.preprocessor = Preprocessor(**preprocessor_kwargs)
61-
self.model = None
61+
self.task_model = None
6262

6363
# Raise a warning if task is set to 'classification'
6464
if preprocessor_kwargs.get("task") == "classification":
@@ -212,7 +212,7 @@ def build_model(
212212

213213
num_classes = len(np.unique(y))
214214

215-
self.model = TaskModel(
215+
self.task_model = TaskModel(
216216
model_class=self.base_model,
217217
num_classes=num_classes,
218218
config=self.config,
@@ -255,10 +255,10 @@ def get_number_of_params(self, requires_grad=True):
255255
else:
256256
if requires_grad:
257257
return sum(
258-
p.numel() for p in self.model.parameters() if p.requires_grad
258+
p.numel() for p in self.task_model.parameters() if p.requires_grad
259259
)
260260
else:
261-
return sum(p.numel() for p in self.model.parameters())
261+
return sum(p.numel() for p in self.task_model.parameters())
262262

263263
def fit(
264264
self,
@@ -383,7 +383,7 @@ def fit(
383383
X, y, X_val, y_val, val_size=val_size, random_state=random_state
384384
)
385385

386-
self.model = TaskModel(
386+
self.task_model = TaskModel(
387387
model_class=self.base_model,
388388
num_classes=self.family.param_count,
389389
family=self.family,
@@ -419,12 +419,12 @@ def fit(
419419
],
420420
**trainer_kwargs
421421
)
422-
self.trainer.fit(self.model, self.data_module)
422+
self.trainer.fit(self.task_model, self.data_module)
423423

424424
best_model_path = checkpoint_callback.best_model_path
425425
if best_model_path:
426426
checkpoint = torch.load(best_model_path)
427-
self.model.load_state_dict(checkpoint["state_dict"])
427+
self.task_model.load_state_dict(checkpoint["state_dict"])
428428

429429
return self
430430

@@ -444,14 +444,14 @@ def predict(self, X, raw=False):
444444
The predicted target values.
445445
"""
446446
# Ensure model and data module are initialized
447-
if self.model is None or self.data_module is None:
447+
if self.task_model is None or self.data_module is None:
448448
raise ValueError("The model or data module has not been fitted yet.")
449449

450450
# Preprocess the data using the data module
451451
cat_tensors, num_tensors = self.data_module.preprocess_test_data(X)
452452

453453
# Move tensors to appropriate device
454-
device = next(self.model.parameters()).device
454+
device = next(self.task_model.parameters()).device
455455
if isinstance(cat_tensors, list):
456456
cat_tensors = [tensor.to(device) for tensor in cat_tensors]
457457
else:
@@ -463,14 +463,14 @@ def predict(self, X, raw=False):
463463
num_tensors = num_tensors.to(device)
464464

465465
# Set model to evaluation mode
466-
self.model.eval()
466+
self.task_model.eval()
467467

468468
# Perform inference
469469
with torch.no_grad():
470-
predictions = self.model(num_features=num_tensors, cat_features=cat_tensors)
470+
predictions = self.task_model(num_features=num_tensors, cat_features=cat_tensors)
471471

472472
if not raw:
473-
return self.model.family(predictions).cpu().numpy()
473+
return self.task_model.family(predictions).cpu().numpy()
474474

475475
# Convert predictions to NumPy array and return
476476
else:
@@ -506,7 +506,7 @@ def evaluate(self, X, y_true, metrics=None, distribution_family=None):
506506
"""
507507
# Infer distribution family from model settings if not provided
508508
if distribution_family is None:
509-
distribution_family = getattr(self.model, "distribution_family", "normal")
509+
distribution_family = getattr(self.task_model, "distribution_family", "normal")
510510

511511
# Setup default metrics if none are provided
512512
if metrics is None:

mambular/models/sklearn_base_regressor.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(self, model, config, **kwargs):
3737

3838
self.preprocessor = Preprocessor(**preprocessor_kwargs)
3939
self.base_model = model
40-
self.model = None
40+
self.task_model = None
4141
self.built = False
4242

4343
# Raise a warning if task is set to 'classification'
@@ -190,7 +190,7 @@ def build_model(
190190
X, y, X_val, y_val, val_size=val_size, random_state=random_state
191191
)
192192

193-
self.model = TaskModel(
193+
self.task_model = TaskModel(
194194
model_class=self.base_model,
195195
config=self.config,
196196
cat_feature_info=self.data_module.cat_feature_info,
@@ -232,10 +232,10 @@ def get_number_of_params(self, requires_grad=True):
232232
else:
233233
if requires_grad:
234234
return sum(
235-
p.numel() for p in self.model.parameters() if p.requires_grad
235+
p.numel() for p in self.task_model.parameters() if p.requires_grad
236236
)
237237
else:
238-
return sum(p.numel() for p in self.model.parameters())
238+
return sum(p.numel() for p in self.task_model.parameters())
239239

240240
def fit(
241241
self,
@@ -336,7 +336,7 @@ def fit(
336336
X, y, X_val, y_val, val_size=val_size, random_state=random_state
337337
)
338338

339-
self.model = TaskModel(
339+
self.task_model = TaskModel(
340340
model_class=self.base_model,
341341
config=self.config,
342342
cat_feature_info=self.data_module.cat_feature_info,
@@ -372,12 +372,12 @@ def fit(
372372
],
373373
**trainer_kwargs
374374
)
375-
self.trainer.fit(self.model, self.data_module)
375+
self.trainer.fit(self.task_model, self.data_module)
376376

377377
best_model_path = checkpoint_callback.best_model_path
378378
if best_model_path:
379379
checkpoint = torch.load(best_model_path)
380-
self.model.load_state_dict(checkpoint["state_dict"])
380+
self.task_model.load_state_dict(checkpoint["state_dict"])
381381

382382
return self
383383

@@ -397,14 +397,14 @@ def predict(self, X):
397397
The predicted target values.
398398
"""
399399
# Ensure model and data module are initialized
400-
if self.model is None or self.data_module is None:
400+
if self.task_model is None or self.data_module is None:
401401
raise ValueError("The model or data module has not been fitted yet.")
402402

403403
# Preprocess the data using the data module
404404
cat_tensors, num_tensors = self.data_module.preprocess_test_data(X)
405405

406406
# Move tensors to appropriate device
407-
device = next(self.model.parameters()).device
407+
device = next(self.task_model.parameters()).device
408408
if isinstance(cat_tensors, list):
409409
cat_tensors = [tensor.to(device) for tensor in cat_tensors]
410410
else:
@@ -416,11 +416,13 @@ def predict(self, X):
416416
num_tensors = num_tensors.to(device)
417417

418418
# Set model to evaluation mode
419-
self.model.eval()
419+
self.task_model.eval()
420420

421421
# Perform inference
422422
with torch.no_grad():
423-
predictions = self.model(num_features=num_tensors, cat_features=cat_tensors)
423+
predictions = self.task_model(
424+
num_features=num_tensors, cat_features=cat_tensors
425+
)
424426

425427
# Convert predictions to NumPy array and return
426428
return predictions.cpu().numpy()

0 commit comments

Comments
 (0)