-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathsurprise_wrapper.py
More file actions
33 lines (28 loc) · 982 Bytes
/
surprise_wrapper.py
File metadata and controls
33 lines (28 loc) · 982 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import pandas as pd
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from surprise import Dataset
from surprise import Reader
class SurpriseRecommender(BaseEstimator, RegressorMixin):
    """scikit-learn regressor adapter around a Surprise prediction algorithm.

    ``X`` is a two-column table of (user id, item id) pairs -- anything that
    ``pd.DataFrame`` accepts (DataFrame, 2-D array, list of pairs).  ``y``
    holds the matching ratings.  ``Dataset.load_from_df`` expects exactly
    three columns (user, item, rating), so ``X`` must have exactly two.

    Parameters
    ----------
    rating_scale : tuple
        ``(lowest, highest)`` rating, forwarded to ``surprise.Reader``.
    model
        A Surprise algorithm instance.  It is fitted in place by :meth:`fit`
        (it is not cloned).
    """

    def __init__(self, rating_scale, model):
        # scikit-learn convention: __init__ only stores constructor params, so
        # get_params/set_params/clone always see a consistent state.
        self.rating_scale = rating_scale
        self.model = model

    @property
    def reader(self):
        """``surprise.Reader`` for the current ``rating_scale``.

        Built on access instead of being cached in ``__init__``, so that
        ``set_params(rating_scale=...)`` actually takes effect.
        """
        return Reader(rating_scale=self.rating_scale)

    def fit(self, X, y):
        """Fit ``self.model`` on the full set of (user, item, rating) triples.

        Parameters
        ----------
        X : array-like, two columns
            (user id, item id) pairs.
        y : array-like
            Ratings, matched to the rows of ``X`` by position.

        Returns
        -------
        self
        """
        # ``assign`` returns a new frame, so the caller's X is never mutated
        # (``pd.DataFrame(df)`` may share storage with ``df`` on older pandas).
        # ``np.asarray`` makes the rating column positional: assigning a Series
        # ``y`` directly would align it on its index and could scramble or NaN
        # the ratings when that index differs from X's.
        ratings = pd.DataFrame(X).assign(rating=np.asarray(y))
        trainset = Dataset.load_from_df(ratings, self.reader).build_full_trainset()
        self.model.fit(trainset)
        return self

    def predict(self, X):
        """Predict a rating for each (user id, item id) row of ``X``.

        Parameters
        ----------
        X : array-like, two columns
            (user id, item id) pairs.

        Returns
        -------
        numpy.ndarray of float
            One estimated rating per row of ``X``, in the same order as ``X``.
            Empty when ``X`` has no rows.
        """
        pairs = pd.DataFrame(X)
        # Build the Surprise testset straight from X's rows.  Going through
        # ``build_full_trainset().build_testset()`` would yield the pairs
        # grouped by user (Trainset.all_ratings iterates user by user), so the
        # predictions would not line up with X's rows.  The third field is the
        # "true" rating; it is unused because only ``est`` is returned.
        testset = [(uid, iid, 0) for uid, iid in pairs.itertuples(index=False)]
        predictions = self.model.test(testset)
        return np.array([pred.est for pred in predictions], dtype=float)