Files
2026-05-22 10:02:10 +02:00

23 lines
633 B
Python

import torch.nn as nn
class Item2Vec(nn.Module):
def __init__(self, config):
super().__init__()
self.num_items = config['num_items']
self.latent_dim = config['latent_dim']
self.embedding = nn.Embedding(
num_embeddings=self.num_items,
embedding_dim=self.latent_dim)
self.fc = nn.Linear(
in_features=self.latent_dim,
out_features=self.num_items)
def forward(self, input_data):
embedding = self.embedding(input_data)
return self.fc(embedding)
def item_embedding(self):
return self.embedding.weight.detach()