Skip to content

Commit 16e18f4

Browse files
committed
adding score function to lss models
1 parent a996e6e commit 16e18f4

1 file changed

Lines changed: 28 additions & 2 deletions

File tree

mambular/models/sklearn_base_lss.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,9 @@ def predict(self, X, raw=False):
467467

468468
# Perform inference
469469
with torch.no_grad():
470-
predictions = self.task_model(num_features=num_tensors, cat_features=cat_tensors)
470+
predictions = self.task_model(
471+
num_features=num_tensors, cat_features=cat_tensors
472+
)
471473

472474
if not raw:
473475
return self.task_model.family(predictions).cpu().numpy()
@@ -506,7 +508,9 @@ def evaluate(self, X, y_true, metrics=None, distribution_family=None):
506508
"""
507509
# Infer distribution family from model settings if not provided
508510
if distribution_family is None:
509-
distribution_family = getattr(self.task_model, "distribution_family", "normal")
511+
distribution_family = getattr(
512+
self.task_model, "distribution_family", "normal"
513+
)
510514

511515
# Setup default metrics if none are provided
512516
if metrics is None:
@@ -559,3 +563,25 @@ def get_default_metrics(self, distribution_family):
559563
"categorical": {"Accuracy": accuracy_score},
560564
}
561565
return default_metrics.get(distribution_family, {})
566+
567+
def score(self, X, y, metric="NLL"):
    """
    Calculate the score of the model using the specified metric.

    Parameters
    ----------
    X : array-like or pd.DataFrame of shape (n_samples, n_features)
        The input samples to predict.
    y : array-like of shape (n_samples,) or (n_samples, n_outputs)
        The true target values against which to evaluate the predictions.
    metric : str, default="NLL"
        The scoring metric. Only "NLL" (negative log-likelihood) is
        currently supported.

    Returns
    -------
    score : float
        The score calculated using the specified metric.

    Raises
    ------
    ValueError
        If `metric` is not a supported scoring metric.
    """
    # Fail loudly instead of silently ignoring an unsupported metric:
    # previously any value of `metric` fell through to NLL.
    if metric != "NLL":
        raise ValueError(
            f"Unsupported metric '{metric}'; only 'NLL' is supported."
        )

    # Predictions from the fitted model; the NLL itself is computed by the
    # distribution-family object attached to the task model.
    predictions = self.predict(X)
    score = self.task_model.family.evaluate_nll(y, predictions)
    return score

0 commit comments

Comments
 (0)