180 lines
6.3 KiB
Python
180 lines
6.3 KiB
Python
import random
|
|
|
|
import pandas as pd
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.utils.data import DataLoader
|
|
from torch.optim import Optimizer
|
|
from tqdm.auto import tqdm
|
|
|
|
from pygrex.data_reader import DataReader, UserItemRatingDataset
|
|
from pygrex.utils.torch_utils import use_optimizer
|
|
from .py_torch_model import PyTorchModel
|
|
|
|
|
|
class MLPModel(PyTorchModel):
    """Score (user, item) pairs from concatenated user/item embeddings.

    The two embeddings are concatenated, passed through a single linear layer
    and a sigmoid. Training uses binary cross-entropy on the observed
    interactions plus, every epoch, freshly sampled items the user has not
    interacted with (target 0). The embedding tables are sized from the
    training data and are therefore only created in ``fit``.
    """

    def __init__(
        self,
        learning_rate: float,
        weight_decay: float,
        latent_dim: int,
        epochs: int,
        num_negative: int,
        batch_size: int,
        cuda: bool,
        optimizer_name: str,
        device_id=None,
    ):
        """Store hyper-parameters and build the scoring head.

        Args:
            learning_rate, latent_dim, epochs, batch_size, cuda, optimizer_name,
            device_id: forwarded unchanged to ``PyTorchModel.__init__``.
            weight_decay: passed to ``use_optimizer`` when ``fit`` builds the
                optimizer.
            num_negative: number of non-interacted items sampled for each
                observed interaction in every training epoch.
        """
        super().__init__(
            learning_rate=learning_rate,
            latent_dim=latent_dim,
            epochs=epochs,
            batch_size=batch_size,
            cuda=cuda,
            optimizer_name=optimizer_name,
            device_id=device_id,
        )

        self.negative_sample_size = num_negative
        self.weight_decay = weight_decay

        # Input width is 2 * latent_dim because the user and item embeddings
        # are concatenated before this layer.
        self.affine_output = torch.nn.Linear(
            in_features=2 * self.latent_dim, out_features=1
        )
        self.logistic = torch.nn.Sigmoid()

        self.criterion = nn.BCELoss()
        # Built in fit(), once every parameter (including the embeddings) exists.
        self.optimizer: Optimizer | None = None

    def fit(self, data: DataReader):
        """Train the model on ``data`` for ``self.epochs`` epochs.

        Creates the user/item embedding tables (sized by ``data.num_user`` and
        ``data.num_item``), builds the optimizer, computes every user's pool of
        non-interacted items, then trains epoch by epoch while a tqdm progress
        bar reports the mean batch loss.

        Raises:
            TypeError: if ``use_optimizer`` does not return a
                ``torch.optim.Optimizer``.
        """
        dataset = data.dataset

        num_users = data.num_user
        num_items = data.num_item

        # Create the tables on the same device as the existing layer so that a
        # module moved with ``.to(device)`` before fit() stays on one device.
        # NOTE(review): with cuda=True, train_single_batch moves the inputs to
        # the GPU, but nothing visible here moves the module itself; confirm
        # the caller (or PyTorchModel) does so before calling fit().
        device = self.affine_output.weight.device
        self.embedding_user = torch.nn.Embedding(
            num_embeddings=num_users, embedding_dim=self.latent_dim, device=device
        )
        self.embedding_item = torch.nn.Embedding(
            num_embeddings=num_items, embedding_dim=self.latent_dim, device=device
        )

        # Build the optimizer only AFTER the embeddings are registered.
        # use_optimizer receives the network itself, so it can only pick up the
        # parameters that exist at call time; building it earlier left both
        # embedding tables out of the optimizer and therefore never trained.
        optimizer = use_optimizer(
            network=self,
            weight_decay=self.weight_decay,
            learning_rate=self.learning_rate,
            optimizer_name=self.optimizer_name,
        )
        if not isinstance(optimizer, Optimizer):
            raise TypeError(f"Expected an Optimizer, but got {type(optimizer)}")
        self.optimizer = optimizer

        self.negatives = self._sample_negative(dataset)

        with tqdm(total=self.epochs) as progress:
            for _ in range(self.epochs):
                # Negatives are re-sampled for every epoch.
                train_loader = self.instance_a_train_loader(
                    dataset, self.negative_sample_size, self.batch_size
                )
                loss = self.train_an_epoch(train_loader)
                progress.update(1)
                progress.set_postfix({"loss": loss})

    def instance_a_train_loader(self, dataset, num_negatives, batch_size):
        """Build a shuffled DataLoader for one training epoch.

        Every row of ``dataset`` (columns ``userId``, ``itemId``, ``rating``)
        whose user appears in ``self.negatives`` contributes its own
        (user, item, rating) triple. It is followed by up to ``num_negatives``
        (user, negative_item, 0.0) triples, drawn without replacement from
        that user's non-interacted items. A user with fewer candidates than
        ``num_negatives`` gets all of them rather than raising ``ValueError``.

        Requires ``self.negatives`` (set by ``fit``).
        """
        # random.sample needs a sequence; convert each user's negative set to a
        # list once per epoch instead of once per interaction row.
        negative_pools = {
            user: list(candidates)
            for user, candidates in zip(
                self.negatives["userId"], self.negatives["negative_items"]
            )
        }

        # Inner join: keep only rows whose user has an entry in self.negatives.
        train_ratings = pd.merge(dataset, self.negatives[["userId"]], on="userId")

        users, items, ratings = [], [], []
        for user, item, rating in zip(
            train_ratings["userId"].tolist(),
            train_ratings["itemId"].tolist(),
            train_ratings["rating"].tolist(),
        ):
            # NOTE(review): BCELoss requires targets in [0, 1]; this assumes the
            # dataset's rating column is already binarised. Confirm in DataReader.
            users.append(user)
            items.append(item)
            ratings.append(rating)

            pool = negative_pools[user]
            for neg_item in random.sample(pool, min(num_negatives, len(pool))):
                users.append(user)
                items.append(neg_item)
                ratings.append(0.0)  # negative samples get 0 rating

        epoch_dataset = UserItemRatingDataset(
            user_tensor=torch.LongTensor(users),
            item_tensor=torch.LongTensor(items),
            target_tensor=torch.FloatTensor(ratings),
        )
        return DataLoader(epoch_dataset, batch_size=batch_size, shuffle=True)

    def train_an_epoch(self, train_loader):
        """Run one optimisation pass over ``train_loader``.

        Each batch is indexed as ``(users, items, ratings)``.

        Returns:
            The mean of the per-batch losses, or 0.0 if the loader yields no
            batches.

        Raises:
            TypeError: if a batch's user indices are not a ``torch.LongTensor``.
        """
        self.train()
        total_loss = 0.0
        num_batches = 0
        for batch in train_loader:
            user, item, rating = batch[0], batch[1], batch[2]
            # Explicit check rather than ``assert``, which ``python -O`` strips.
            if not isinstance(user, torch.LongTensor):
                raise TypeError(
                    f"Expected a torch.LongTensor of user indices, got {type(user)}"
                )
            total_loss += self.train_single_batch(user, item, rating.float())
            num_batches += 1
        # Guard against an empty loader (e.g. an empty dataset).
        return total_loss / num_batches if num_batches else 0.0

    def train_single_batch(self, users, items, ratings):
        """Take one optimizer step on a batch and return its loss as a float.

        Raises:
            RuntimeError: if called before ``fit`` has created the optimizer.
        """
        if self.cuda is True:
            users, items, ratings = users.cuda(), items.cuda(), ratings.cuda()
        if self.optimizer is None:
            raise RuntimeError(
                "Optimizer is not initialized. Call fit() before training."
            )

        self.optimizer.zero_grad()
        ratings_pred = self(users, items)
        # forward() returns shape (batch, 1); flatten to match the 1-D targets.
        loss = self.criterion(ratings_pred.view(-1), ratings)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def _sample_negative(self, ratings):
        """Return every user's full set of non-interacted items.

        The item catalogue is the set of ``itemId`` values that appear in
        ``ratings``; it is also stored on ``self.item_catalogue``. No sampling
        happens here; per-epoch sampling is done in
        ``instance_a_train_loader``.

        Returns:
            DataFrame with columns ``userId`` and ``negative_items`` (a ``set``
            of item ids per user).
        """
        interact_status = (
            ratings.groupby("userId")["itemId"]
            .apply(set)
            .reset_index()
            .rename(columns={"itemId": "interacted_items"})
        )
        self.item_catalogue = set(ratings.itemId)
        interact_status["negative_items"] = interact_status["interacted_items"].apply(
            lambda x: self.item_catalogue - x
        )
        return interact_status[["userId", "negative_items"]]

    def forward(self, user_indices, item_indices):
        """Return the predicted interaction probability for each (user, item) pair.

        Index tensors whose embeddings come out 3-D (i.e. indices shaped
        ``(batch, 1)``) are squeezed to 2-D. If one side has a batch size of 1
        and the other does not, the single row is repeated to match.

        Returns:
            Tensor of shape ``(batch, 1)`` with sigmoid outputs in (0, 1).
        """
        user_embedding = self.embedding_user(user_indices)
        item_embedding = self.embedding_item(item_indices)

        # Ensure embeddings are 2D [batch_size, embedding_dim].
        if user_embedding.dim() == 3:
            user_embedding = user_embedding.squeeze(1)
        if item_embedding.dim() == 3:
            item_embedding = item_embedding.squeeze(1)

        # torch.cat does not broadcast, so expand a single row explicitly.
        if user_embedding.size(0) == 1 and item_embedding.size(0) > 1:
            user_embedding = user_embedding.repeat(item_embedding.size(0), 1)
        elif item_embedding.size(0) == 1 and user_embedding.size(0) > 1:
            item_embedding = item_embedding.repeat(user_embedding.size(0), 1)

        element_concat = torch.cat((user_embedding, item_embedding), 1)
        concat = self.affine_output(element_concat)
        return self.logistic(concat)
|