23 lines
633 B
Python
23 lines
633 B
Python
import torch.nn as nn
|
|
|
|
|
|
class Item2Vec(nn.Module):
    """Embedding table over items followed by a linear layer of item-count width.

    The ``config`` mapping must provide two keys:

    * ``'num_items'``  -- number of rows in the embedding table and the
      output width of the linear layer.
    * ``'latent_dim'`` -- width of each embedding vector and the input
      width of the linear layer.
    """

    def __init__(self, config):
        """Read the sizes from ``config`` and build the two layers.

        Raises:
            KeyError: if ``'num_items'`` or ``'latent_dim'`` is absent
                from ``config``.
        """
        super().__init__()

        # Look the keys up in this order so a missing 'num_items' is
        # reported before a missing 'latent_dim'.
        n_items = config['num_items']
        dim = config['latent_dim']
        self.num_items = n_items
        self.latent_dim = dim

        # The embedding is created before the linear layer; this order fixes
        # how the random initialisation is consumed under a given seed.
        self.embedding = nn.Embedding(n_items, dim)
        self.fc = nn.Linear(dim, n_items)

    def forward(self, input_data):
        """Look up embeddings for ``input_data`` and project them with ``fc``.

        ``input_data`` is passed straight to ``nn.Embedding``, so it is
        presumably a tensor of integer item indices -- confirm with callers.
        The result has a trailing dimension of size ``num_items``.
        """
        return self.fc(self.embedding(input_data))

    def item_embedding(self):
        """Return the embedding weight matrix detached from autograd.

        The returned tensor is produced by ``detach()``; it is not a copy.
        """
        table = self.embedding.weight
        return table.detach()
|