public code v1
This commit is contained in:
@@ -0,0 +1,52 @@
|
||||
import pytest
|
||||
from pygrex.config import cfg
|
||||
from pygrex.data_reader import DataReader
|
||||
from pygrex.models import ALS
|
||||
from pygrex.recommender import Recommender
|
||||
from pygrex.evaluator import ModelEvaluator, Splitter
|
||||
from pygrex.explain import KNNPostHocExplainer
|
||||
# from threadpoolctl import threadpool_limits
|
||||
|
||||
|
||||
# # Fixture to replace setUp method
|
||||
# @pytest.fixture(autouse=True)
|
||||
# def limit_blas_threads():
|
||||
# threadpool_limits(1, "blas")
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def setup_data():
|
||||
# Initialize ALS model
|
||||
als = ALS(**cfg.model.als)
|
||||
|
||||
# Prepare data
|
||||
data = DataReader(**cfg.data.ml100k)
|
||||
data.make_consecutive_ids_in_dataset()
|
||||
data.binarize()
|
||||
|
||||
return als, data
|
||||
|
||||
|
||||
def test_explain_als(setup_data):
|
||||
als, data = setup_data
|
||||
|
||||
# Split data
|
||||
sp = Splitter()
|
||||
train, test = sp.split_leave_n_out(data, n=1)
|
||||
|
||||
# Test ALS model fitting (fit should complete without raising)
|
||||
als.fit(train)
|
||||
|
||||
# Generate recommendations
|
||||
recommender = Recommender(train, als)
|
||||
# Recommend on a small, valid subset of items
|
||||
max_valid = als.model.item_factors.shape[0]
|
||||
item_pool = list(range(min(max_valid, 50)))
|
||||
_ = recommender.recommend(user_id=0, target_item_id=item_pool)
|
||||
|
||||
# Pipeline ran without exceptions; evaluation moved elsewhere
|
||||
|
||||
# explainer = ALSExplainer(als, recommendations, data)
|
||||
# explainer.explain_recommendations()
|
||||
|
||||
# KNN Post Hoc Explainer (skipped in this smoke test)
|
||||
@@ -0,0 +1,24 @@
|
||||
import pytest
|
||||
|
||||
from pygrex.config import cfg
|
||||
from pygrex.data_reader.data_reader import DataReader
|
||||
from pygrex.models.als_model import ALS
|
||||
from pygrex.recommender import Recommender
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_data():
|
||||
als = ALS(**cfg.model.als)
|
||||
data = DataReader(**cfg.data.testdata)
|
||||
data.make_consecutive_ids_in_dataset()
|
||||
data.binarize()
|
||||
als.fit(data)
|
||||
return als, data
|
||||
|
||||
|
||||
def test_train_recommend_als(setup_data):
|
||||
als, data = setup_data
|
||||
recommender = Recommender(data, als)
|
||||
max_valid = als.model.item_factors.shape[0]
|
||||
item_pool = list(range(min(max_valid, 50)))
|
||||
_ = recommender.recommend(user_id=0, target_item_id=item_pool)
|
||||
@@ -0,0 +1,22 @@
|
||||
import pytest
|
||||
|
||||
from pygrex.config import cfg
|
||||
from pygrex.data_reader.data_reader import DataReader
|
||||
from pygrex.models.autoencoder_model import ExplAutoencoderTorch
|
||||
from pygrex.recommender import Recommender
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_data():
|
||||
autoencoder = ExplAutoencoderTorch(**cfg.model.autoencoder)
|
||||
data = DataReader(**cfg.data.testdata)
|
||||
data.make_consecutive_ids_in_dataset()
|
||||
data.binarize()
|
||||
return autoencoder, data
|
||||
|
||||
|
||||
def test_train_autoencoder(setup_data):
|
||||
autoencoder, data = setup_data
|
||||
autoencoder.fit(data)
|
||||
# Directly test predict on a valid single item after fit
|
||||
_ = autoencoder.predict(0, 0)
|
||||
@@ -0,0 +1,21 @@
|
||||
import pytest
|
||||
|
||||
from pygrex.config import cfg
|
||||
from pygrex.data_reader.data_reader import DataReader
|
||||
from pygrex.models import EMFModel
|
||||
from pygrex.recommender import Recommender
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_data():
|
||||
emf = EMFModel(**cfg.model.emf)
|
||||
data = DataReader(**cfg.data.testdata)
|
||||
data.make_consecutive_ids_in_dataset()
|
||||
return emf, data
|
||||
|
||||
|
||||
def test_train_emf(setup_data):
|
||||
emf, data = setup_data
|
||||
emf.fit(data)
|
||||
recommender = Recommender(data, emf)
|
||||
recommender.recommend_all()
|
||||
@@ -0,0 +1,22 @@
|
||||
import pytest
|
||||
|
||||
from pygrex.config import cfg
|
||||
from pygrex.data_reader.data_reader import DataReader
|
||||
from pygrex.models.gmf_model import GMFModel
|
||||
from pygrex.recommender import Recommender
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_data():
|
||||
gmf = GMFModel(**cfg.model.gmf)
|
||||
data = DataReader(**cfg.data.testdata)
|
||||
data.make_consecutive_ids_in_dataset()
|
||||
data.binarize()
|
||||
return gmf, data
|
||||
|
||||
|
||||
def test_train_gmf(setup_data):
|
||||
gmf, data = setup_data
|
||||
gmf.fit(data)
|
||||
recommender = Recommender(data, gmf)
|
||||
recommender.recommend_all()
|
||||
@@ -0,0 +1,35 @@
|
||||
import pytest
|
||||
|
||||
from pygrex.config import cfg
|
||||
from pygrex.data_reader.data_reader import DataReader
|
||||
from pygrex.models.als_model import ALS
|
||||
from pygrex.models.bpr_model import BPR
|
||||
from pygrex.recommender import Recommender
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_data():
|
||||
als = ALS(**cfg.model.als)
|
||||
bpr = BPR(**cfg.model.bpr)
|
||||
data = DataReader(**cfg.data.testdata)
|
||||
data.make_consecutive_ids_in_dataset()
|
||||
data.binarize()
|
||||
return als, bpr, data
|
||||
|
||||
|
||||
def test_train_als(setup_data):
|
||||
als, _, data = setup_data
|
||||
als.fit(data)
|
||||
recommender = Recommender(data, als)
|
||||
max_valid = als.model.item_factors.shape[0]
|
||||
item_pool = list(range(min(max_valid, 50)))
|
||||
_ = recommender.recommend(user_id=0, target_item_id=item_pool)
|
||||
|
||||
|
||||
def test_train_bpr(setup_data):
|
||||
_, bpr, data = setup_data
|
||||
bpr.fit(data)
|
||||
recommender = Recommender(data, bpr)
|
||||
max_valid = bpr.model.item_factors.shape[0]
|
||||
item_pool = list(range(min(max_valid, 50)))
|
||||
_ = recommender.recommend(user_id=0, target_item_id=item_pool)
|
||||
@@ -0,0 +1,22 @@
|
||||
import pytest
|
||||
|
||||
from pygrex.config import cfg
|
||||
from pygrex.data_reader.data_reader import DataReader
|
||||
from pygrex.models.mlp_model import MLPModel
|
||||
from pygrex.recommender import Recommender
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_data():
|
||||
mlp = MLPModel(**cfg.model.mlp)
|
||||
data = DataReader(**cfg.data.testdata)
|
||||
data.make_consecutive_ids_in_dataset()
|
||||
data.binarize()
|
||||
return mlp, data
|
||||
|
||||
|
||||
def test_train_mlp(setup_data):
|
||||
mlp, data = setup_data
|
||||
mlp.fit(data)
|
||||
# Directly predict for a single valid item after fit
|
||||
_ = mlp.predict(0, 0)
|
||||
Reference in New Issue
Block a user