public code v1
This commit is contained in:
@@ -0,0 +1,179 @@
|
||||
import random
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.optim import Optimizer
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from pygrex.data_reader import DataReader, UserItemRatingDataset
|
||||
from pygrex.utils.torch_utils import use_optimizer
|
||||
from .py_torch_model import PyTorchModel
|
||||
|
||||
|
||||
class MLPModel(PyTorchModel):
|
||||
def __init__(
|
||||
self,
|
||||
learning_rate: float,
|
||||
weight_decay: float,
|
||||
latent_dim: int,
|
||||
epochs: int,
|
||||
num_negative: int,
|
||||
batch_size: int,
|
||||
cuda: bool,
|
||||
optimizer_name: str,
|
||||
device_id=None,
|
||||
):
|
||||
super().__init__(
|
||||
learning_rate=learning_rate,
|
||||
latent_dim=latent_dim,
|
||||
epochs=epochs,
|
||||
batch_size=batch_size,
|
||||
cuda=cuda,
|
||||
optimizer_name=optimizer_name,
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
self.negative_sample_size = num_negative
|
||||
self.weight_decay = weight_decay
|
||||
|
||||
# layer dim is 2*self.latent_dim since the embeddings will be concatenated
|
||||
self.affine_output = torch.nn.Linear(
|
||||
in_features=2 * self.latent_dim, out_features=1
|
||||
)
|
||||
self.logistic = torch.nn.Sigmoid()
|
||||
|
||||
self.criterion = nn.BCELoss()
|
||||
self.optimizer: Optimizer | None = None
|
||||
|
||||
def fit(self, data: DataReader):
|
||||
optimizer = use_optimizer(
|
||||
network=self,
|
||||
weight_decay=self.weight_decay,
|
||||
learning_rate=self.learning_rate,
|
||||
optimizer_name=self.optimizer_name,
|
||||
)
|
||||
if not isinstance(optimizer, Optimizer):
|
||||
raise TypeError(f"Expected an Optimizer, but got {type(optimizer)}")
|
||||
self.optimizer = optimizer
|
||||
|
||||
dataset = data.dataset
|
||||
|
||||
num_users = data.num_user
|
||||
num_items = data.num_item
|
||||
|
||||
self.embedding_user = torch.nn.Embedding(
|
||||
num_embeddings=num_users, embedding_dim=self.latent_dim
|
||||
)
|
||||
|
||||
self.embedding_item = torch.nn.Embedding(
|
||||
num_embeddings=num_items, embedding_dim=self.latent_dim
|
||||
)
|
||||
|
||||
self.negatives = self._sample_negative(dataset)
|
||||
|
||||
with tqdm(total=self.epochs) as progress:
|
||||
for epoch in range(self.epochs):
|
||||
train_loader = self.instance_a_train_loader(
|
||||
dataset, self.negative_sample_size, self.batch_size
|
||||
)
|
||||
loss = self.train_an_epoch(train_loader)
|
||||
progress.update(1)
|
||||
progress.set_postfix({"loss": loss})
|
||||
|
||||
def instance_a_train_loader(self, dataset, num_negatives, batch_size):
|
||||
"""instance train loader for one training epoch"""
|
||||
users, items, ratings = [], [], []
|
||||
train_ratings = pd.merge(
|
||||
dataset, self.negatives[["userId", "negative_items"]], on="userId"
|
||||
)
|
||||
train_ratings["negatives"] = train_ratings["negative_items"].apply(
|
||||
lambda x: random.sample(list(x), num_negatives)
|
||||
)
|
||||
user_ids = train_ratings["userId"].tolist()
|
||||
item_ids = train_ratings["itemId"].tolist()
|
||||
rating_values = train_ratings["rating"].tolist()
|
||||
negatives_lists = train_ratings["negatives"].tolist()
|
||||
|
||||
for user, item, rating, negatives in zip(
|
||||
user_ids, item_ids, rating_values, negatives_lists
|
||||
):
|
||||
users.append(user)
|
||||
items.append(item)
|
||||
ratings.append(rating)
|
||||
for neg_item in negatives:
|
||||
users.append(user)
|
||||
items.append(neg_item)
|
||||
ratings.append(float(0)) # negative samples get 0 rating
|
||||
|
||||
dataset = UserItemRatingDataset(
|
||||
user_tensor=torch.LongTensor(users),
|
||||
item_tensor=torch.LongTensor(items),
|
||||
target_tensor=torch.FloatTensor(ratings),
|
||||
)
|
||||
return DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
||||
|
||||
def train_an_epoch(self, train_loader):
|
||||
self.train()
|
||||
cnt = 0
|
||||
total_loss = 0
|
||||
for batch_id, batch in enumerate(train_loader):
|
||||
assert isinstance(batch[0], torch.LongTensor)
|
||||
user, item, rating = batch[0], batch[1], batch[2]
|
||||
rating = rating.float()
|
||||
loss = self.train_single_batch(user, item, rating)
|
||||
total_loss += loss
|
||||
cnt += 1
|
||||
return total_loss / cnt
|
||||
|
||||
def train_single_batch(self, users, items, ratings):
|
||||
if self.cuda is True:
|
||||
users, items, ratings = users.cuda(), items.cuda(), ratings.cuda()
|
||||
if self.optimizer is None:
|
||||
raise RuntimeError(
|
||||
"Optimizer is not initialized. Call fit() before training."
|
||||
)
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
ratings_pred = self(users, items)
|
||||
loss = self.criterion(ratings_pred.view(-1), ratings)
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
loss = loss.item()
|
||||
return loss
|
||||
|
||||
def _sample_negative(self, ratings):
|
||||
"""return all negative items & 100 sampled negative items"""
|
||||
interact_status = (
|
||||
ratings.groupby("userId")["itemId"]
|
||||
.apply(set)
|
||||
.reset_index()
|
||||
.rename(columns={"itemId": "interacted_items"})
|
||||
)
|
||||
self.item_catalogue = set(ratings.itemId)
|
||||
interact_status["negative_items"] = interact_status["interacted_items"].apply(
|
||||
lambda x: self.item_catalogue - x
|
||||
)
|
||||
return interact_status[["userId", "negative_items"]]
|
||||
|
||||
def forward(self, user_indices, item_indices):
|
||||
user_embedding = self.embedding_user(user_indices)
|
||||
item_embedding = self.embedding_item(item_indices)
|
||||
|
||||
# Ensure embeddings are 2D [batch_size, embedding_dim]
|
||||
if user_embedding.dim() == 3:
|
||||
user_embedding = user_embedding.squeeze(1)
|
||||
if item_embedding.dim() == 3:
|
||||
item_embedding = item_embedding.squeeze(1)
|
||||
|
||||
# This is needed because cat does not support broadcasting.
|
||||
if user_embedding.size(0) == 1 and item_embedding.size(0) > 1:
|
||||
user_embedding = user_embedding.repeat(item_embedding.size(0), 1)
|
||||
elif item_embedding.size(0) == 1 and user_embedding.size(0) > 1:
|
||||
item_embedding = item_embedding.repeat(user_embedding.size(0), 1)
|
||||
|
||||
element_concat = torch.cat((user_embedding, item_embedding), 1)
|
||||
concat = self.affine_output(element_concat)
|
||||
rating = self.logistic(concat)
|
||||
return rating
|
||||
Reference in New Issue
Block a user