from torch.utils.data import Dataset


class UserItemRatingDataset(Dataset):
    """Wrap aligned user / item / target tensors as a PyTorch ``Dataset``.

    Sample ``i`` is the triple ``(user_tensor[i], item_tensor[i],
    target_tensor[i])``, so the three containers must be aligned by position
    and have the same length along their first dimension.
    """

    def __init__(self, user_tensor, item_tensor, target_tensor):
        """
        Args:
            user_tensor: torch.Tensor (or any indexable sized sequence), the
                user of each (user, item) pair.
            item_tensor: torch.Tensor (or any indexable sized sequence), the
                item of each (user, item) pair.
            target_tensor: torch.Tensor (or any indexable sized sequence), the
                corresponding rating for each pair.

        Raises:
            ValueError: if the three inputs do not have the same length.
        """
        n_users = len(user_tensor)
        n_items = len(item_tensor)
        n_targets = len(target_tensor)
        # Validate eagerly: a mismatch would otherwise surface as an
        # IndexError mid-epoch (shorter input) or silently drop rows (longer).
        if not (n_users == n_items == n_targets):
            raise ValueError(
                "user_tensor, item_tensor and target_tensor must have the same "
                f"length, got {n_users}, {n_items} and {n_targets}"
            )
        self.user_tensor = user_tensor
        self.item_tensor = item_tensor
        self.target_tensor = target_tensor

    def __getitem__(self, index):
        """Return the ``(user, item, target)`` triple at position ``index``."""
        return self.user_tensor[index], self.item_tensor[index], self.target_tensor[index]

    def __len__(self):
        """Return the number of (user, item, target) samples."""
        # len() equals .size(0) for tensors and also works for arrays/sequences.
        return len(self.user_tensor)