70 lines
1.9 KiB
Python
70 lines
1.9 KiB
Python
import itertools
|
|
from typing import Union
|
|
import torch
|
|
|
|
from pygrex.utils.torch_utils import use_cuda
|
|
from .recommender_model import RecommenderModel
|
|
from pygrex.data_reader import DataReader
|
|
|
|
|
|
class PyTorchModel(RecommenderModel, torch.nn.Module):
    """Base class for recommenders implemented as PyTorch modules.

    Stores the shared training hyper-parameters and provides ``predict`` and
    embedding-export helpers on top of a subclass-defined ``forward``.

    Subclasses are expected to:

    * implement ``forward(user_ids, item_ids)``; ``predict`` calls it with two
      ``torch.long`` index tensors;
    * register submodules named ``embedding_user`` and ``embedding_item`` that
      expose a ``weight`` entry (e.g. ``torch.nn.Embedding``), which
      ``user_embedding`` / ``item_embedding`` read from the state dict;
    * override ``fit``, which is a no-op here.
    """

    # Optimizer names accepted by ``__init__``.
    _SUPPORTED_OPTIMIZERS = ("sgd", "adam", "rmsprop")

    def __init__(
        self,
        learning_rate: float,
        latent_dim: int,
        epochs: int,
        batch_size: int,
        cuda: bool,
        optimizer_name: str,
        device_id: Union[int, None] = None,
    ):
        """Validate the configuration and store the hyper-parameters.

        Args:
            learning_rate: Stored as ``self.learning_rate``.
            latent_dim: Stored as ``self.latent_dim``.
            epochs: Stored as ``self.epochs``.
            batch_size: Stored as ``self.batch_size``.
            cuda: If truthy, ``predict`` moves its index tensors to the GPU.
            optimizer_name: One of ``"sgd"``, ``"adam"`` or ``"rmsprop"``;
                stored as ``self.optimizer_name``.
            device_id: When given together with a truthy ``cuda``, passed to
                ``use_cuda(True, device_id)`` (presumably selects the CUDA
                device — confirm in ``pygrex.utils.torch_utils``).

        Raises:
            ValueError: If ``optimizer_name`` is not supported. This is a
                subclass of ``Exception``, so broad handlers still catch it.
        """
        if optimizer_name not in self._SUPPORTED_OPTIMIZERS:
            raise ValueError(
                "Wrong optimizer. Expected one of "
                f"{', '.join(self._SUPPORTED_OPTIMIZERS)}; got {optimizer_name!r}."
            )

        if cuda and device_id is not None:
            use_cuda(True, device_id)

        # torch.nn.Module must be initialised before attributes are assigned on
        # the instance; plain values happen to work either way, but assigning a
        # Parameter or submodule before this call raises AttributeError.
        super().__init__()

        self.latent_dim = latent_dim
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.batch_size = batch_size
        self._cuda = cuda
        self.optimizer_name = optimizer_name

        # Placeholders that subclasses are expected to fill in.
        self.dataset = None
        self.dataset_metadata = None
        self.embedding_user = None
        self.embedding_item = None
        self.optimizer = None

    def fit(self, data: DataReader):
        """Train the model on ``data``.

        No-op in this base class; subclasses are expected to override it.
        """
        pass

    @staticmethod
    def _as_index_tensor(ids) -> torch.Tensor:
        """Convert ids to a ``torch.long`` tensor with at least one dimension.

        Accepts any integer scalar (Python ``int``, NumPy integer, 0-d tensor)
        as well as sequences, arrays and tensors. A bare scalar is wrapped into
        a 1-element tensor. Passing a lone integer to ``torch.LongTensor``
        would instead be read as a size, and NumPy integers are not ``int``
        instances, so an ``isinstance(ids, int)`` guard alone misses them.
        """
        tensor = torch.as_tensor(ids, dtype=torch.long)
        if tensor.dim() == 0:
            tensor = tensor.unsqueeze(0)
        return tensor

    def predict(self, user_id, item_id) -> list:
        """Score user/item ids with the subclass ``forward``.

        Args:
            user_id: A single integer id or a sequence/array/tensor of ids.
            item_id: Same accepted forms as ``user_id``.

        Returns:
            The ``forward`` output as a list of Python numbers. Outputs with
            fewer than two dimensions are returned flat; otherwise one level of
            nesting is removed from ``tolist()`` (an ``(n, 1)`` output becomes
            a list of ``n`` numbers).
        """
        users = self._as_index_tensor(user_id)
        items = self._as_index_tensor(item_id)
        if self._cuda:
            users = users.cuda()
            items = items.cuda()

        # Run inference in eval mode so dropout / batch-norm layers (if a
        # subclass has any) behave deterministically, then restore each
        # submodule's own mode so the caller sees no lasting change.
        saved_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                scores = self.forward(users, items).cpu()
        finally:
            for module, mode in saved_modes:
                module.training = mode

        if scores.dim() < 2:
            # chain.from_iterable would fail on a list of plain numbers.
            return scores.reshape(-1).tolist()
        return list(itertools.chain.from_iterable(scores.tolist()))

    def user_embedding(self):
        """Return the user embedding matrix as a NumPy array (copied to CPU).

        Reads ``"embedding_user.weight"`` from ``state_dict()``; raises
        ``KeyError`` if the subclass did not register such an entry.
        """
        return self.state_dict()["embedding_user.weight"].cpu().numpy()

    def item_embedding(self):
        """Return the item embedding matrix as a NumPy array (copied to CPU).

        Reads ``"embedding_item.weight"`` from ``state_dict()``; raises
        ``KeyError`` if the subclass did not register such an entry.
        """
        return self.state_dict()["embedding_item.weight"].cpu().numpy()
|