public code v1
This commit is contained in:
@@ -0,0 +1,22 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Item2Vec(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.num_items = config['num_items']
|
||||
self.latent_dim = config['latent_dim']
|
||||
self.embedding = nn.Embedding(
|
||||
num_embeddings=self.num_items,
|
||||
embedding_dim=self.latent_dim)
|
||||
self.fc = nn.Linear(
|
||||
in_features=self.latent_dim,
|
||||
out_features=self.num_items)
|
||||
|
||||
def forward(self, input_data):
|
||||
embedding = self.embedding(input_data)
|
||||
return self.fc(embedding)
|
||||
|
||||
def item_embedding(self):
|
||||
return self.embedding.weight.detach()
|
||||
Reference in New Issue
Block a user