public code v1
This commit is contained in:
@@ -0,0 +1,69 @@
|
||||
import itertools
|
||||
from typing import Union
|
||||
import torch
|
||||
|
||||
from pygrex.utils.torch_utils import use_cuda
|
||||
from .recommender_model import RecommenderModel
|
||||
from pygrex.data_reader import DataReader
|
||||
|
||||
|
||||
class PyTorchModel(RecommenderModel, torch.nn.Module):
|
||||
"""Meta Learner
|
||||
|
||||
Note: Subclass should implement self.model !
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
learning_rate: float,
|
||||
latent_dim: int,
|
||||
epochs: int,
|
||||
batch_size: int,
|
||||
cuda: bool,
|
||||
optimizer_name: str,
|
||||
device_id: Union[int, None] = None,
|
||||
):
|
||||
if optimizer_name not in ["sgd", "adam", "rmsprop"]:
|
||||
raise Exception("Wrong optimizer.")
|
||||
|
||||
if cuda is True and device_id is not None:
|
||||
use_cuda(True, device_id)
|
||||
|
||||
self.latent_dim = latent_dim
|
||||
self.learning_rate = learning_rate
|
||||
self.epochs = epochs
|
||||
self.batch_size = batch_size
|
||||
self._cuda = cuda
|
||||
self.optimizer_name = optimizer_name
|
||||
|
||||
self.dataset = None
|
||||
self.dataset_metadata = None
|
||||
self.embedding_user = None
|
||||
self.embedding_item = None
|
||||
self.optimizer = None
|
||||
|
||||
super().__init__()
|
||||
|
||||
def fit(self, data: DataReader):
|
||||
pass
|
||||
|
||||
def predict(self, user_id, item_id) -> list:
|
||||
if isinstance(user_id, int):
|
||||
user_id = [user_id]
|
||||
if isinstance(item_id, int):
|
||||
item_id = [item_id]
|
||||
user_id = torch.LongTensor(user_id)
|
||||
item_id = torch.LongTensor(item_id)
|
||||
with torch.no_grad():
|
||||
if self._cuda:
|
||||
user_id = user_id.cuda()
|
||||
item_id = item_id.cuda()
|
||||
pred = self.forward(user_id, item_id).cpu().tolist()
|
||||
pred = list(itertools.chain.from_iterable(pred))
|
||||
return pred
|
||||
|
||||
def user_embedding(self):
|
||||
return self.state_dict()["embedding_user.weight"].cpu().numpy()
|
||||
|
||||
def item_embedding(self):
|
||||
return self.state_dict()["embedding_item.weight"].cpu().numpy()
|
||||
Reference in New Issue
Block a user