Files
py-grex/pygrex/recommender/generic_recommender.py
T
2026-05-22 10:02:10 +02:00

73 lines
2.4 KiB
Python

import numpy as np
import pandas as pd
from tqdm.autonotebook import tqdm
class GenericRecommender:
    """Base recommender that produces ranked item lists per user.

    ``recommend_all`` relies on a ``recommend_user(user_id, user_ratings)``
    method that this class does not define; presumably subclasses implement it
    (e.g. by combining ``get_unrated`` and ``rank_prediction``).
    """

    def __init__(self, dataset_metadata, model, top_n: int = 10):
        """
        :param dataset_metadata: object exposing a ``dataset`` attribute, used
            as a DataFrame with at least ``userId`` and ``itemId`` columns.
        :param model: underlying model, stored as ``self.model``; the methods
            defined here do not use it themselves.
        :param top_n: maximum number of items kept per user by
            ``rank_prediction``.
        """
        self.top_n = top_n
        self.dataset = dataset_metadata.dataset
        self.model = model
        # Every distinct item id appearing anywhere in the dataset.
        self.catalogue = set(self.dataset["itemId"])

    def recommend_all(self):
        """
        Get recommendations for every user in the dataset.

        Calls ``recommend_user(user_id, user_ratings)`` once per user, where
        ``user_ratings`` is that user's slice of the dataset.

        :return: DataFrame holding all per-user results stacked with a fresh
            0..n-1 index; an empty frame with columns ``userId``, ``itemId``
            and ``rank`` when there is nothing to return.
        """
        frames = []
        with tqdm(
            total=self.dataset["userId"].nunique(), desc="Recommending for users: "
        ) as pbar:
            for user_id, user_ratings in self.dataset.groupby("userId"):
                result = self.recommend_user(user_id, user_ratings)  # type: ignore
                # pd.concat ignores None entries, so skipping them here keeps
                # the previous behaviour for users without a result.
                if result is not None:
                    frames.append(result)
                pbar.update()

        if not frames:
            return pd.DataFrame({"userId": [], "itemId": [], "rank": []})

        # One concat is linear in the total size. Concatenating inside the loop
        # re-copied the accumulated frame every iteration (quadratic). Seeding
        # with an empty frame also triggers pandas' FutureWarning about
        # empty/all-NA concat entries.
        return pd.concat(frames, ignore_index=True)

    def rank_prediction(self, user_id, target_item_id, predictions):
        """
        Turn per-item scores into the user's top-N ranked items.

        :param user_id: user the scores belong to (repeated on every row).
        :param target_item_id: item ids, aligned element-wise with
            ``predictions``.
        :param predictions: one score per item; higher is better. A NumPy
            array with more than one dimension is flattened first.
        :return: DataFrame with columns ``userId``, ``itemId`` and ``rank``
            (1-based float ranks from ``Series.rank``), sorted by rank and
            restricted to ranks ``<= self.top_n``. Ties are broken by the order
            in which items appear (``method="first"``).
        """
        # Models often emit a column vector of scores; flatten so it lines up
        # with the 1-D list of item ids.
        if isinstance(predictions, np.ndarray) and predictions.ndim > 1:
            predictions = predictions.flatten()
        recommendations = pd.DataFrame(
            {"userId": user_id, "itemId": target_item_id, "prediction": predictions}
        )
        recommendations["rank"] = recommendations["prediction"].rank(
            method="first", ascending=False
        )
        recommendations.sort_values(["userId", "rank"], inplace=True)
        recommendations = recommendations[recommendations["rank"] <= self.top_n]
        return recommendations[["userId", "itemId", "rank"]]

    def get_unrated(self, user_ratings):
        """
        Extract the catalogue items a user has not rated.

        :param user_ratings: iterable of item ids the user has rated. NOTE:
            ``set()`` of a DataFrame yields its column labels, so pass the item
            ids themselves (e.g. the ``itemId`` column), not the whole frame.
        :return: list of catalogue item ids not in ``user_ratings``, in
            arbitrary (set iteration) order.
        """
        unrated_item_id = self.catalogue - set(user_ratings)
        unrated_item_id = list(unrated_item_id)
        return unrated_item_id

    def get_rated(self, user_id):
        """
        Extract the dataset rows belonging to a user.

        :param user_id: id to match against the ``userId`` column.
        :return: DataFrame slice of ``self.dataset`` containing that user's
            rows (all columns, original index preserved).
        """
        rated = self.dataset[self.dataset["userId"] == user_id]
        return rated