public code v1
This commit is contained in:
@@ -0,0 +1,202 @@
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import tempfile
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
from pygrex.data_reader.data_reader import DataReader
|
||||
|
||||
|
||||
@pytest.fixture
def valid_df():
    """Return a small, well-formed ratings frame shared by the tests.

    The frame holds four interactions from three distinct users over four
    distinct items, with numeric ratings and timestamps.
    """
    columns = dict(
        userId=[1, 2, 3, 1],
        itemId=[100, 200, 300, 400],
        rating=[4.5, 3.0, 5.0, 2.0],
        timestamp=[1000, 2000, 3000, 4000],
    )
    return pd.DataFrame(columns)
|
||||
|
||||
|
||||
@pytest.fixture
def test_csv_path(valid_df):
    """Yield the path of a throwaway CSV file containing ``valid_df``.

    The file is written (with a header row, without the index) into a
    temporary directory that is cleaned up when the fixture is torn down.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        csv_file = os.path.join(scratch_dir, "test_data.csv")
        valid_df.to_csv(csv_file, sep=",", index=False)
        yield csv_file
|
||||
|
||||
|
||||
def test_init_with_dataframe(valid_df):
    """A reader built from a DataFrame exposes that frame and its counts."""
    loaded = DataReader(dataframe=valid_df)

    # The stored dataset must match the input frame exactly.
    pd.testing.assert_frame_equal(loaded.dataset, valid_df)

    # valid_df contains three distinct users and four distinct items.
    assert loaded.num_user == 3
    assert loaded.num_item == 4
|
||||
|
||||
|
||||
def test_init_with_filepath(test_csv_path):
    """A reader built from a CSV path reports the expected user/item counts."""
    read_options = {
        "filepath_or_buffer": test_csv_path,
        "sep": ",",
        "names": ["userId", "itemId", "rating", "timestamp"],
        "skiprows": 1,  # the fixture file starts with a header row
    }
    loaded = DataReader(**read_options)

    # The file holds the same rows as valid_df: 3 users, 4 items.
    assert loaded.num_user == 3
    assert loaded.num_item == 4
|
||||
|
||||
|
||||
def test_dataset_loading_file_not_found():
    """Pointing the reader at a missing file raises ``FileNotFoundError``."""
    column_names = ["userId", "itemId", "rating", "timestamp"]
    with pytest.raises(FileNotFoundError):
        DataReader(
            filepath_or_buffer="nonexistent.csv",
            sep=",",
            names=column_names,
        )
|
||||
|
||||
|
||||
def test_dataset_validation_invalid_columns():
    """Non-numeric user IDs make the reader raise ``ValueError``."""
    # userId holds text labels; every other column is numeric.
    bad_frame = pd.DataFrame(
        dict(
            userId=["user1", "user2", "user3"],
            itemId=[100, 200, 300],
            rating=[4.5, 3.0, 5.0],
            timestamp=[1000, 2000, 3000],
        )
    )

    with pytest.raises(ValueError) as raised:
        DataReader(dataframe=bad_frame)

    message = str(raised.value)
    assert "cannot be converted to a numeric type" in message
|
||||
|
||||
|
||||
def test_dataset_validation_empty():
    """A frame with the expected columns but no rows raises ``ValueError``."""
    column_names = ["userId", "itemId", "rating", "timestamp"]
    no_rows = pd.DataFrame(columns=column_names)

    with pytest.raises(ValueError) as raised:
        DataReader(dataframe=no_rows)

    message = str(raised.value)
    assert "DataFrame cannot be empty" in message
|
||||
|
||||
|
||||
def test_dataset_validation_invalid_types():
    """Textual ratings make the reader raise ``ValueError``."""
    # Only the rating column is wrong here: it holds words, not numbers.
    bad_frame = pd.DataFrame(
        dict(
            userId=[1, 2, 3],
            itemId=[100, 200, 300],
            rating=["high", "medium", "low"],
            timestamp=[1000, 2000, 3000],
        )
    )

    with pytest.raises(ValueError) as raised:
        DataReader(dataframe=bad_frame)

    message = str(raised.value)
    assert "cannot be converted to a numeric type" in message
|
||||
|
||||
|
||||
def test_dataset_validation_missing_values():
    """A NaN in the user ID column makes the reader raise ``ValueError``."""
    # The third userId is missing; the remaining columns are complete.
    incomplete_frame = pd.DataFrame(
        dict(
            userId=[1, 2, np.nan],
            itemId=[100, 200, 300],
            rating=[4.5, 3.0, 5.0],
            timestamp=[1000, 2000, 3000],
        )
    )

    with pytest.raises(ValueError) as raised:
        DataReader(dataframe=incomplete_frame)

    message = str(raised.value)
    assert "contains missing values" in message
|
||||
|
||||
|
||||
def test_get_raw_dataset_not_set():
    """Constructing a reader fails when the underlying CSV cannot be read.

    ``pandas.read_csv`` is patched to raise ``FileNotFoundError``; the test
    passes when that error surfaces from the ``DataReader`` constructor.
    No reader instance is ever created, so no method is exercised here.
    """
    read_failure = FileNotFoundError("File not found: dummy.csv")
    with patch("pandas.read_csv", side_effect=read_failure):
        with pytest.raises(FileNotFoundError):
            # The result is deliberately not bound: construction must raise.
            DataReader(
                filepath_or_buffer="dummy.csv",
                sep=",",
                names=["userId", "itemId", "rating", "timestamp"],
            )
|
||||
|
||||
|
||||
def test_make_consecutive_ids_in_dataset(valid_df):
    """Sparse user/item IDs are remapped to dense, zero-based IDs."""
    # IDs deliberately start far from zero and leave gaps.
    sparse_frame = pd.DataFrame(
        dict(
            userId=[100, 200, 300, 100],
            itemId=[500, 600, 700, 800],
            rating=[4.5, 3.0, 5.0, 2.0],
            timestamp=[1000, 2000, 3000, 4000],
        )
    )

    reader = DataReader(dataframe=sparse_frame)
    reader.make_consecutive_ids_in_dataset()

    # After remapping, the IDs cover 0..n-1 for users and items alike.
    remapped = reader.dataset
    assert set(remapped["userId"]) == {0, 1, 2}
    assert set(remapped["itemId"]) == {0, 1, 2, 3}

    # New ID 0 must translate back to the original IDs 100 and 500.
    assert reader.get_original_user_id(0) == 100
    assert reader.get_original_item_id(0) == 500
|
||||
|
||||
|
||||
def test_make_consecutive_ids_not_set():
    """No reader (and hence no ID remapping) exists when loading fails.

    ``pandas.read_csv`` is patched to raise ``FileNotFoundError``; the test
    passes when that error surfaces from the ``DataReader`` constructor.
    ``make_consecutive_ids_in_dataset`` itself is never reached.
    """
    read_failure = FileNotFoundError("File not found: dummy.csv")
    with patch("pandas.read_csv", side_effect=read_failure):
        with pytest.raises(FileNotFoundError):
            # The result is deliberately not bound: construction must raise.
            DataReader(
                filepath_or_buffer="dummy.csv",
                sep=",",
                names=["userId", "itemId", "rating", "timestamp"],
            )
|
||||
|
||||
|
||||
def test_binarize(valid_df):
    """Ratings above the threshold become 1, all others become 0."""
    reader = DataReader(dataframe=valid_df)
    reader.binarize(binary_threshold=3.5)

    # Source ratings are 4.5, 3.0, 5.0 and 2.0 against a 3.5 threshold.
    binarized = list(reader.dataset["rating"])
    assert binarized == [1, 0, 1, 0]
|
||||
|
||||
|
||||
def test_binarize_not_set():
    """No reader (and hence nothing to binarize) exists when loading fails.

    ``pandas.read_csv`` is patched to raise ``FileNotFoundError``; the test
    passes when that error surfaces from the ``DataReader`` constructor.
    ``binarize`` itself is never reached.
    """
    read_failure = FileNotFoundError("File not found: dummy.csv")
    with patch("pandas.read_csv", side_effect=read_failure):
        with pytest.raises(FileNotFoundError):
            # The result is deliberately not bound: construction must raise.
            DataReader(
                filepath_or_buffer="dummy.csv",
                sep=",",
                names=["userId", "itemId", "rating", "timestamp"],
            )
|
||||
|
||||
|
||||
def test_get_original_user_id_not_set(valid_df):
    """Looking up an original user ID before any remapping raises ``ValueError``."""
    # make_consecutive_ids_in_dataset is intentionally never called here.
    untouched_reader = DataReader(dataframe=valid_df)

    with pytest.raises(ValueError) as raised:
        untouched_reader.get_original_user_id(0)

    message = str(raised.value)
    assert "ID mapping not set" in message
|
||||
@@ -0,0 +1,16 @@
|
||||
import pytest
|
||||
|
||||
from pygrex.config import cfg
|
||||
from pygrex.data_reader.data_reader import DataReader
|
||||
|
||||
|
||||
@pytest.fixture
def data_reader():
    """Build a ``DataReader`` from the ``cfg.data.testdata`` settings."""
    settings = cfg.data.testdata
    return DataReader(**settings)
|
||||
|
||||
|
||||
def test_import(data_reader):
    """The configured test dataset loads with the expected dimensions."""
    assert data_reader.num_user == 249
    assert data_reader.num_item == 551

    # 1000 interactions, four columns each.
    row_count, column_count = data_reader.dataset.shape
    assert row_count == 1000
    assert column_count == 4
|
||||
@@ -0,0 +1,266 @@
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import shutil
|
||||
from unittest.mock import patch
|
||||
|
||||
from pygrex.data_reader import DataReader, GroupInteractionHandler
|
||||
|
||||
|
||||
@pytest.fixture
def test_environment():
    """Provide a scratch directory, two group files, sample data and a handler.

    Yields a dict with the keys ``test_dir``, ``group1_file``, ``group2_file``,
    ``test_data`` and ``handler``. The scratch directory is removed on
    teardown, and also when any setup step fails before the yield, so a
    broken setup no longer leaves temporary directories behind.
    """
    test_dir = Path(tempfile.mkdtemp())
    try:
        # Group files: one group per line, member IDs joined by underscores.
        group1_file = test_dir / "group1.txt"
        group2_file = test_dir / "group2.txt"
        group1_file.write_text("1_2_3\n4_5_6\n7_8_9")
        group2_file.write_text("10_11_12\n13_14_15")

        # Interaction data that tests attach to a mocked DataReader.
        test_data = pd.DataFrame(
            {
                "userId": [1, 1, 2, 2, 3, 3, 4, 4, 5],
                "itemId": [101, 102, 102, 103, 101, 104, 105, 106, 107],
                "rating": [5.0, 4.0, 3.0, 5.0, 4.0, 3.0, 5.0, 4.0, 3.0],
            }
        )

        # Handler initialised from the directory containing both group files.
        handler = GroupInteractionHandler(test_dir)

        yield {
            "test_dir": test_dir,
            "group1_file": group1_file,
            "group2_file": group2_file,
            "test_data": test_data,
            "handler": handler,
        }
    finally:
        # Runs on normal teardown and on setup failure alike.
        shutil.rmtree(test_dir)
|
||||
|
||||
|
||||
def test_init_with_directory(test_environment):
    """A directory argument yields both group files it contains."""
    handler = GroupInteractionHandler(test_environment["test_dir"])
    discovered = handler.filepath_or_buffer
    assert len(discovered) == 2

    # Order is not asserted, only membership.
    discovered_set = set(discovered)
    for key in ("group1_file", "group2_file"):
        assert str(test_environment[key]) in discovered_set
|
||||
|
||||
|
||||
def test_init_with_file_list(test_environment):
    """A list of paths is stored as the matching string paths."""
    first_file = test_environment["group1_file"]
    second_file = test_environment["group2_file"]

    handler = GroupInteractionHandler([first_file, second_file])

    stored = handler.filepath_or_buffer
    assert len(stored) == 2
    assert str(first_file) in stored
    assert str(second_file) in stored
|
||||
|
||||
|
||||
def test_init_with_single_file(test_environment):
    """A single path is stored as a one-element list of string paths."""
    only_file = test_environment["group1_file"]
    handler = GroupInteractionHandler(only_file)

    stored = handler.filepath_or_buffer
    assert len(stored) == 1
    assert stored[0] == str(only_file)
|
||||
|
||||
|
||||
def test_get_group_filepath(test_environment):
    """A known group name resolves to its file; an unknown one raises."""
    handler = test_environment["handler"]

    resolved = handler._get_group_filepath("group1")
    assert resolved == str(test_environment["group1_file"].resolve())

    # No file name contains "nonexistent", so the lookup must fail.
    with pytest.raises(ValueError) as raised:
        handler._get_group_filepath("nonexistent")
    expected_message = "Error: No file found containing 'nonexistent' in its name."
    assert str(raised.value) == expected_message
|
||||
|
||||
|
||||
def test_get_group_filepath_non_existent_file(test_environment):
    """A registered path that is missing on disk raises ``ValueError``."""
    # The path matches the requested name but was never created.
    non_existent_file = test_environment["test_dir"] / "missing.txt"
    handler = GroupInteractionHandler([non_existent_file])

    with pytest.raises(ValueError) as raised:
        handler._get_group_filepath("missing")

    expected_message = f"Error: File does not exist: {non_existent_file.resolve()}"
    assert str(raised.value) == expected_message
|
||||
|
||||
|
||||
def test_read_groups(test_environment):
    """Group IDs are read line by line; an empty group name raises."""
    handler = test_environment["handler"]

    group_ids = handler.read_groups("group1")
    assert len(group_ids) == 3
    assert group_ids == ["1_2_3", "4_5_6", "7_8_9"]

    with pytest.raises(ValueError) as raised:
        handler.read_groups("")
    message = str(raised.value)
    assert message == "Groups path not specified in configuration"
|
||||
|
||||
|
||||
def test_parse_group_members(test_environment):
    """Underscore-joined group IDs parse to integer member lists."""
    handler = test_environment["handler"]

    assert handler.parse_group_members("1_2_3") == [1, 2, 3]

    # Surrounding whitespace must not affect the result.
    assert handler.parse_group_members(" 4_5_6 ") == [4, 5, 6]
|
||||
|
||||
|
||||
def test_get_group_members(test_environment):
    """Members are accepted as a group-ID string or as a list."""
    handler = test_environment["handler"]

    # A group-ID string is split into integer members.
    from_string = handler.get_group_members("1_2_3")
    assert from_string == [1, 2, 3]

    # A list comes back with the same members.
    from_list = handler.get_group_members([1, 2, 3])
    assert from_list == [1, 2, 3]

    # An empty string means an empty group.
    from_empty = handler.get_group_members("")
    assert from_empty == []

    # A non-numeric member inside the string is rejected.
    with pytest.raises(ValueError):
        handler.get_group_members("1_a_3")

    # Anything that is neither string nor list is rejected.
    with pytest.raises(TypeError):
        handler.get_group_members(123)
|
||||
|
||||
|
||||
@patch("pygrex.data_reader.DataReader", spec=DataReader)
def test_create_modified_dataset(mock_data_reader, test_environment):
    """Selected user/item interactions are dropped from the dataset."""
    handler = test_environment["handler"]
    source_frame = test_environment["test_data"]

    # Mocked reader carrying a copy of the sample data and identity ID maps.
    reader_stub = mock_data_reader()
    reader_stub.dataset = source_frame.copy()
    reader_stub.get_new_user_id.side_effect = lambda x: x
    reader_stub.get_new_item_id.side_effect = lambda x: x

    def has_removed_pairs(frame):
        """Report whether any (user in 1-2, item in 101-102) row survives."""
        user_hit = frame.userId.isin([1, 2])
        item_hit = frame.itemId.isin([101, 102])
        return (user_hit & item_hit).any()

    # Case 1: the DataReader itself is passed as original_data.
    via_reader = handler.create_modified_dataset(reader_stub, [1, 2], [101, 102])
    assert len(via_reader) == 6
    assert not has_removed_pairs(via_reader)

    # Case 2: a raw DataFrame plus the DataReader as fourth argument.
    via_frame = handler.create_modified_dataset(
        source_frame, [1, 2], [101, 102], reader_stub
    )
    assert len(via_frame) == 6
    assert not has_removed_pairs(via_frame)

    # Case 3: a raw DataFrame without a DataReader is invalid input.
    with pytest.raises(ValueError):
        handler.create_modified_dataset(source_frame, [1, 2], [101, 102])
|
||||
|
||||
|
||||
@patch("pygrex.data_reader.DataReader", spec=DataReader)
def test_get_rated_items_by_all_group_members(mock_data_reader, test_environment):
    """The union of items rated by members 1, 2 and 3 is returned."""
    # Mocked reader with the sample data and identity ID translations.
    reader_stub = mock_data_reader()
    reader_stub.dataset = test_environment["test_data"].copy()
    reader_stub.get_new_user_id.side_effect = lambda x: x
    reader_stub.get_original_item_id.side_effect = lambda x: x

    handler = test_environment["handler"]
    rated_items = handler.get_rated_items_by_all_group_members(
        [1, 2, 3], reader_stub
    )

    # Compare order-insensitively by sorting both sides.
    expected_items = np.array([101, 102, 103, 104])
    np.testing.assert_array_equal(np.sort(rated_items), np.sort(expected_items))
|
||||
|
||||
|
||||
@patch("pygrex.data_reader.DataReader", spec=DataReader)
def test_get_common_rated_items(mock_data_reader, test_environment):
    """Only items rated by every member of the group are returned."""
    # Mocked reader with the sample data and identity ID translations.
    reader_stub = mock_data_reader()
    reader_stub.dataset = test_environment["test_data"].copy()
    reader_stub.get_new_user_id.side_effect = lambda x: x
    reader_stub.get_original_item_id.side_effect = lambda x: x

    handler = test_environment["handler"]

    # Users 1 and 2 share exactly one item: 102.
    shared = handler.get_common_rated_items([1, 2], reader_stub)
    np.testing.assert_array_equal(shared, np.array([102]))

    # Users 1 and 4 have no item in common.
    disjoint = handler.get_common_rated_items([1, 4], reader_stub)
    assert len(disjoint) == 0

    # An empty group yields an empty result.
    for_nobody = handler.get_common_rated_items([], reader_stub)
    assert len(for_nobody) == 0
    np.testing.assert_array_equal(for_nobody, np.array([]))
|
||||
|
||||
|
||||
def test_get_items_for_group_recommendation(test_environment):
    """Items already rated by a group member are excluded from candidates."""
    handler = test_environment["handler"]

    interactions = pd.DataFrame(
        dict(userId=[1, 1, 2, 2, 3], itemId=[101, 102, 103, 104, 105])
    )
    catalogue = np.array([101, 102, 103, 104, 105, 106, 107])

    candidates = handler.get_items_for_group_recommendation(
        interactions, catalogue, [1, 2]
    )

    # Users 1 and 2 rated 101-104, leaving 105-107; order is not asserted.
    remaining = np.array([105, 106, 107])
    np.testing.assert_array_equal(np.sort(candidates), np.sort(remaining))
|
||||
|
||||
|
||||
@patch("pygrex.data_reader.DataReader", spec=DataReader)
def test_get_group_preferences(mock_data_reader, test_environment):
    """Only the interactions of the group's members are returned."""
    # Mocked reader with the sample data and an identity user-ID map.
    reader_stub = mock_data_reader()
    reader_stub.dataset = test_environment["test_data"].copy()
    reader_stub.get_new_user_id.side_effect = lambda x: x

    handler = test_environment["handler"]
    preferences = handler.get_group_preferences([1, 2], reader_stub)

    # Users 1 and 2 contribute two rows each in the sample data.
    assert len(preferences) == 4
    assert preferences.userId.isin([1, 2]).all()
|
||||
|
||||
|
||||
@patch("pygrex.data_reader.DataReader", spec=DataReader)
def test_get_group_preferences_edge_cases(mock_data_reader, test_environment):
    """String member IDs and an empty group are both handled."""

    def to_new_user_id(x):
        """Convert string IDs to int and pass everything else through."""
        if isinstance(x, str):
            return int(x)
        return x

    reader_stub = mock_data_reader()
    reader_stub.dataset = test_environment["test_data"].copy()
    reader_stub.get_new_user_id.side_effect = to_new_user_id

    handler = test_environment["handler"]

    # Members given as strings select the same four rows as ints would.
    from_strings = handler.get_group_preferences(["1", "2"], reader_stub)
    assert len(from_strings) == 4
    assert from_strings.userId.isin([1, 2]).all()

    # An empty group selects nothing.
    from_nobody = handler.get_group_preferences([], reader_stub)
    assert len(from_nobody) == 0
    assert from_nobody.empty
|
||||
@@ -0,0 +1,36 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from pygrex.config import cfg
|
||||
from pygrex.data_reader.data_reader import DataReader
|
||||
from pygrex.data_reader.user_item_dict import UserItemDict
|
||||
|
||||
|
||||
@pytest.fixture
def setup_data():
    """Load the configured test dataset and make its IDs consecutive."""
    reader = DataReader(**cfg.data.testdata)
    reader.make_consecutive_ids_in_dataset()
    return reader
|
||||
|
||||
|
||||
def test_user_item_matrix(setup_data):
    """The rating stored under index label 0 is retrievable from ``UserItemDict``."""
    data = setup_data

    # The explanation matrix is all zeros, one cell per (user, item) pair.
    matrix_shape = (
        data.dataset.userId.nunique(),
        data.dataset.itemId.nunique(),
    )
    expl_matrix = torch.zeros(matrix_shape)

    # Explanations are switched off for this test.
    user_dict = UserItemDict(data.dataset, expl_matrix, False)

    # Label-based lookup of row 0 - presumably the first row; TODO confirm
    # the dataset keeps a default index after ID remapping.
    user_id = data.dataset.userId[0]
    item_id = data.dataset.itemId[0]
    stored_rating = data.dataset.rating[0]
    assert user_dict[user_id][item_id] == stored_rating
|
||||
@@ -0,0 +1,256 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# Import the class to be tested
|
||||
from pygrex.data_reader.data_reader import DataReader
|
||||
from pygrex.utils.sliding_window_ranker import SlidingWindowRanker
|
||||
|
||||
|
||||
@pytest.fixture
def mock_data_reader():
    """Build a ``DataReader``-spec'd mock backed by seven sample interactions."""
    # Timestamps start at 1700000000 and advance by 600000 per row.
    interactions = pd.DataFrame(
        dict(
            userId=[1, 1, 2, 2, 3, 3, 4],
            itemId=[101, 102, 101, 103, 102, 104, 105],
            rating=[4.0, 3.5, 5.0, 2.0, 4.5, 3.0, 3.8],
            timestamp=[1700000000 + 600000 * step for step in range(7)],
        )
    )

    reader = MagicMock(spec=DataReader)
    reader.dataset = interactions

    # ID translation is the identity: new IDs equal the given IDs.
    reader.get_new_user_id = lambda user_id: user_id
    reader.get_new_item_id = lambda item_id: item_id
    return reader
|
||||
|
||||
|
||||
@pytest.fixture
def evaluator():
    """Return a ``SlidingWindowRanker`` built from a one-entry dummy config."""
    return SlidingWindowRanker({"test_param": "test_value"})
|
||||
|
||||
|
||||
@pytest.fixture
def group_predictions():
    """Return per-user predicted scores, keyed by user ID then item ID."""
    predictions = {}
    predictions[1] = {101: 4.2, 102: 3.7, 103: 2.5, 104: 1.8}
    predictions[2] = {101: 4.8, 102: 3.1, 103: 2.2}
    predictions[3] = {101: 3.9, 102: 4.3, 104: 3.2}
    predictions[4] = {101: 3.5, 105: 4.0}
    return predictions
|
||||
|
||||
|
||||
class TestSlidingWindowEvaluator:
    """Tests for the scoring and ranking methods of ``SlidingWindowRanker``."""

    def test_initialization(self, evaluator):
        """A fresh ranker keeps its config and holds no predictions yet."""
        assert evaluator.config == {"test_param": "test_value"}
        assert evaluator.group_predictions is None
        assert evaluator.top_recommendation is None

    def test_set_group_recommender_values(self, evaluator, group_predictions):
        """Predictions and the top recommendation are stored as given."""
        evaluator.set_group_recommender_values(group_predictions, 101)

        assert evaluator.group_predictions == group_predictions
        assert evaluator.top_recommendation == 101

    def test_calculate_item_popularity_score(self, evaluator, mock_data_reader):
        """Items with more interactions receive higher popularity scores."""
        # Interaction counts in the mock dataset:
        # 101 -> 2, 102 -> 2, 103 -> 1, 104 -> 1, 105 -> 1
        scores = evaluator.calculate_item_popularity_score(
            [101, 102, 103, 104, 105], mock_data_reader
        )

        # Equal counts give equal scores; a higher count gives a higher score.
        assert scores[101] == scores[102]
        assert scores[101] > scores[103]

        # Every score lies in the closed interval [0, 1].
        for _, score in scores.items():
            assert 0 <= score <= 1

    def test_calculate_relevance_mask_with_predictions(
        self, evaluator, group_predictions
    ):
        """Missing predictions show up as 0 in the relevance mask."""
        evaluator.set_group_recommender_values(group_predictions, 101)

        expected_masks = {
            # Every user has a prediction for item 101.
            101: {1: 4.2, 2: 4.8, 3: 3.9, 4: 3.5},
            # Only user 4 has a prediction for item 105.
            105: {1: 0, 2: 0, 3: 0, 4: 4.0},
            # Nobody has a prediction for item 999.
            999: {1: 0, 2: 0, 3: 0, 4: 0},
        }
        for item_id, expected_mask in expected_masks.items():
            assert evaluator.calculate_relevance_mask(item_id) == expected_mask

    def test_calculate_relevance_mask_without_predictions(self, evaluator):
        """Asking for a mask before predictions are set raises ``ValueError``."""
        with pytest.raises(ValueError, match="User predictions not set"):
            evaluator.calculate_relevance_mask(101)

    def test_calculate_relevance_score(
        self, evaluator, mock_data_reader, group_predictions
    ):
        """Relevance is within [0, 1] and 0 for degenerate inputs."""
        evaluator.set_group_recommender_values(group_predictions, 101)
        prediction_scores = {1: 4.2, 2: 4.8, 3: 3.9, 4: 3.5}
        members = [1, 2, 3, 4]

        # Regular case: item 101 with the full group.
        regular = evaluator.calculate_relevance_score(
            101, mock_data_reader, prediction_scores, members
        )
        assert 0 <= regular <= 1

        # No members at all.
        without_members = evaluator.calculate_relevance_score(
            101, mock_data_reader, prediction_scores, []
        )
        assert without_members == 0

        # Item 999 does not appear in the mock dataset.
        unknown_item = evaluator.calculate_relevance_score(
            999, mock_data_reader, prediction_scores, members
        )
        assert unknown_item == 0

    def test_calculate_item_intensity_score(self, evaluator, mock_data_reader):
        """Intensity is the share of members who interacted with the item."""
        members = [1, 2, 3, 4]

        cases = (
            (101, 0.5),  # users 1 and 2 -> 2/4
            (102, 0.5),  # users 1 and 3 -> 2/4
            (105, 0.25),  # user 4 only -> 1/4
        )
        for item_id, expected in cases:
            observed = evaluator.calculate_item_intensity_score(
                item_id, members, mock_data_reader
            )
            assert observed == expected

        # An empty group has zero intensity.
        observed = evaluator.calculate_item_intensity_score(101, [], mock_data_reader)
        assert observed == 0

    def test_calculate_rating_score(self, evaluator, mock_data_reader):
        """The rating score is within [0, 1] and 0 for an empty group."""
        # Item 101 is rated 4.0 by user 1 and 5.0 by user 2.
        # Averaged over all four members: (4.0 + 5.0) / 4 = 2.25.
        for_group = evaluator.calculate_rating_score(
            101, [1, 2, 3, 4], mock_data_reader
        )
        assert 0 <= for_group <= 1

        for_nobody = evaluator.calculate_rating_score(101, [], mock_data_reader)
        assert for_nobody == 0

    def test_generate_ranked_items(
        self, evaluator, mock_data_reader, group_predictions
    ):
        """Ranking returns every input item plus a metrics dict."""
        evaluator.set_group_recommender_values(group_predictions, 101)
        all_rated_items = [101, 102, 103, 104, 105]
        members = [1, 2, 3, 4]

        # Default weights: the ranking is a permutation of the input items.
        ranked, metrics = evaluator.generate_ranked_items(
            all_rated_items, mock_data_reader, members
        )
        assert isinstance(ranked, list)
        assert isinstance(metrics, dict)
        assert len(ranked) == len(all_rated_items)
        assert set(ranked) == set(all_rated_items)

        # Custom weights, including a zero weight for the trend component.
        custom_weights = dict(
            popularity=2.0,
            intensity=0.5,
            rating=1.0,
            relevance=1.5,
            trend=0.0,
        )
        ranked_custom, metrics_custom = evaluator.generate_ranked_items(
            all_rated_items, mock_data_reader, members, custom_weights
        )
        assert isinstance(ranked_custom, list)
        assert isinstance(metrics_custom, dict)
        assert len(ranked_custom) == len(all_rated_items)

        # Without predictions the ranker must refuse to rank.
        evaluator.group_predictions = None
        with pytest.raises(ValueError, match="User predictions not set"):
            evaluator.generate_ranked_items(all_rated_items, mock_data_reader, members)

    def test_evaluate_not_implemented(self, evaluator, mock_data_reader):
        """``evaluate`` currently returns an empty dict placeholder."""
        assert evaluator.evaluate(mock_data_reader) == {}
|
||||
|
||||
|
||||
# Additional tests for edge cases
|
||||
|
||||
|
||||
def test_with_numpy_user_ids(evaluator, mock_data_reader):
    """Scores stay within [0, 1] when member IDs are ``numpy.int64``."""
    np_members = [np.int64(member) for member in (1, 2, 3)]

    # Neither call may raise on numpy integers.
    intensity = evaluator.calculate_item_intensity_score(
        101, np_members, mock_data_reader
    )
    assert 0 <= intensity <= 1

    rating_score = evaluator.calculate_rating_score(101, np_members, mock_data_reader)
    assert 0 <= rating_score <= 1
|
||||
|
||||
|
||||
def test_with_different_rating_scale(evaluator, mock_data_reader):
    """Scores stay within [0, 1] under a custom 1-10 rating scale."""
    group = [1, 2, 3, 4]
    one_to_ten = (1, 10)

    rating_score = evaluator.calculate_rating_score(
        101, group, mock_data_reader, rating_scale=one_to_ten
    )
    assert 0 <= rating_score <= 1

    # Predictions are supplied for users 1 and 2 only.
    partial_predictions = {1: 8, 2: 9}
    relevance_score = evaluator.calculate_relevance_score(
        101, mock_data_reader, partial_predictions, group, rating_scale=one_to_ten
    )
    assert 0 <= relevance_score <= 1
|
||||
@@ -0,0 +1,431 @@
|
||||
"""Test suite for RuleBasedGroupRecExplainer."""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from unittest.mock import Mock, MagicMock
|
||||
import logging
|
||||
|
||||
from pygrex.explain.groups.rule_based_group_rec_explainer import RuleBasedGroupRecExplainer
|
||||
|
||||
|
||||
"""Test cases for RuleBasedGroupRecExplainer class."""
|
||||
|
||||
|
||||
@pytest.fixture
def mock_data_reader():
    """Return a ``Mock`` reader whose ``get_new_item_id`` echoes its argument."""
    reader = Mock()
    # Identity mapping: the "new" item ID equals the ID passed in.
    reader.get_new_item_id.side_effect = lambda x: x
    return reader
|
||||
|
||||
|
||||
@pytest.fixture
def sample_rules():
    """Return four association rules as a DataFrame.

    Columns are ``antecedents`` and ``consequents`` (sets of item names)
    plus the numeric ``confidence`` and ``support`` of each rule.
    """
    antecedents = [
        {"item1", "item2"},
        {"item3"},
        {"item1"},
        {"item4", "item5"},
    ]
    consequents = [{"rec1"}, {"rec2"}, {"rec1"}, {"rec3"}]
    return pd.DataFrame(
        {
            "antecedents": antecedents,
            "consequents": consequents,
            "confidence": [0.8, 0.9, 0.7, 0.6],
            "support": [0.3, 0.4, 0.2, 0.1],
        }
    )
|
||||
|
||||
|
||||
@pytest.fixture
def sample_user_history():
    """Return the set of item names associated with each of four users."""
    history = {}
    history["user1"] = {"item1", "item2", "item6"}
    history["user2"] = {"item3", "item7"}
    history["user3"] = {"item1", "item4", "item5"}
    history["user4"] = {"item2", "item8"}
    return history
|
||||
|
||||
|
||||
@pytest.fixture
def explainer(sample_rules, mock_data_reader, sample_user_history):
    """Build an explainer for a three-member group with a threshold of two."""
    settings = {
        "rules": sample_rules,
        "data": mock_data_reader,
        "pool_recommendations": ["rec1", "rec2", "rec3"],
        "members": ["user1", "user2", "user3"],
        "user_history": sample_user_history,
        "min_members_threshold": 2,
    }
    return RuleBasedGroupRecExplainer(**settings)
|
||||
|
||||
|
||||
class TestInitialization:
    """Constructor behaviour: stored values, defaults and validation."""

    def test_init_with_valid_parameters(self, sample_rules, mock_data_reader):
        """Every constructor argument ends up on the matching attribute."""
        settings = {
            "rules": sample_rules,
            "data": mock_data_reader,
            "pool_recommendations": ["rec1", "rec2"],
            "members": ["user1", "user2"],
            "user_history": {"user1": {"item1"}},
            "min_members_threshold": 1,
        }
        explainer = RuleBasedGroupRecExplainer(**settings)

        # Rules and data are kept by identity, not copied.
        assert explainer.rules is sample_rules
        assert explainer.data is mock_data_reader
        assert explainer.pool_recommendations == ["rec1", "rec2"]
        assert explainer.members == ["user1", "user2"]
        assert explainer.user_history == {"user1": {"item1"}}
        assert explainer.min_members_threshold == 1

    def test_init_with_invalid_threshold(self, sample_rules, mock_data_reader):
        """A threshold of zero is rejected with ``ValueError``."""
        expected_error = "min_members_threshold must be at least 1"
        with pytest.raises(ValueError, match=expected_error):
            RuleBasedGroupRecExplainer(
                rules=sample_rules,
                data=mock_data_reader,
                min_members_threshold=0,
            )

    def test_init_with_defaults(self, sample_rules, mock_data_reader):
        """Omitted optional arguments fall back to empty values and 1."""
        explainer = RuleBasedGroupRecExplainer(
            rules=sample_rules,
            data=mock_data_reader,
        )

        assert explainer.pool_recommendations == []
        assert explainer.members == []
        assert explainer.user_history == {}
        assert explainer.min_members_threshold == 1

    def test_normalize_recommendations_single_item(
        self, sample_rules, mock_data_reader
    ):
        """A single recommendation is wrapped into a one-element list."""
        explainer = RuleBasedGroupRecExplainer(
            rules=sample_rules,
            data=mock_data_reader,
            pool_recommendations="rec1",
        )
        assert explainer.pool_recommendations == ["rec1"]

    def test_normalize_recommendations_list(self, sample_rules, mock_data_reader):
        """A list of recommendations is stored with equal contents."""
        explainer = RuleBasedGroupRecExplainer(
            rules=sample_rules,
            data=mock_data_reader,
            pool_recommendations=["rec1", "rec2"],
        )
        assert explainer.pool_recommendations == ["rec1", "rec2"]

    def test_normalize_recommendations_none(self, sample_rules, mock_data_reader):
        """``None`` recommendations become an empty list."""
        explainer = RuleBasedGroupRecExplainer(
            rules=sample_rules,
            data=mock_data_reader,
            pool_recommendations=None,
        )
        assert explainer.pool_recommendations == []
|
||||
|
||||
|
||||
class TestPrivateMethods:
    """Exercise the explainer's private helper methods directly."""

    def test_is_rule_satisfied_by_member_true(self, explainer):
        """A member whose history contains the whole antecedent satisfies it."""
        antecedent = {"item1", "item2"}  # both are in user1's history
        assert explainer._is_rule_satisfied_by_member("user1", antecedent) is True

    def test_is_rule_satisfied_by_member_false(self, explainer):
        """A member missing antecedent items does not satisfy the rule."""
        antecedent = {"item1", "item2"}  # user2's history is item3 and item7
        assert explainer._is_rule_satisfied_by_member("user2", antecedent) is False

    def test_is_rule_satisfied_by_member_empty_history(self, explainer):
        """A user without any recorded history never satisfies a rule."""
        outcome = explainer._is_rule_satisfied_by_member(
            "nonexistent_user", {"item1"}
        )
        assert outcome is False

    def test_count_satisfied_members(self, explainer):
        """Count the members that satisfy a single-item antecedent."""
        # Among the members, user1 and user3 hold item1.
        satisfied = explainer._count_satisfied_members({"item1"})
        assert satisfied == 2

    def test_count_satisfied_members_complex_antecedent(self, explainer):
        """Count the members that satisfy a two-item antecedent."""
        # user1 is the only member holding both item1 and item2.
        satisfied = explainer._count_satisfied_members({"item1", "item2"})
        assert satisfied == 1

    def test_find_applicable_rules(self, explainer):
        """Only rules whose consequents contain the item are returned."""
        matching_rules = explainer._find_applicable_rules("rec1")

        # Two rules are expected to have rec1 as a consequent.
        assert len(matching_rules) == 2
        for _, rule in matching_rules.iterrows():
            assert "rec1" in rule["consequents"]

    def test_can_explain_item_true(self, explainer):
        """rec1 is explainable at the fixture's threshold of two members."""
        assert explainer._can_explain_item("rec1") is True

    def test_can_explain_item_false(self, explainer):
        """No rule can explain rec1 when the threshold exceeds the group size."""
        # Raise the threshold on the fixture instance beyond its three members.
        explainer.min_members_threshold = 5
        assert explainer._can_explain_item("rec1") is False
|
||||
|
||||
|
||||
class TestAdvancedMethods:
    """Exercise the helpers behind the advanced explanation check."""

    def test_member_has_antecedent_item_true(self, explainer):
        """One shared item between history and antecedent is enough."""
        # user1 holds item1; item9 is in nobody's history.
        antecedent = {"item1", "item9"}
        assert explainer._member_has_antecedent_item("user1", antecedent) is True

    def test_member_has_antecedent_item_false(self, explainer):
        """A member holding none of the antecedent items yields False."""
        # user2 holds neither item1 nor item9.
        antecedent = {"item1", "item9"}
        assert explainer._member_has_antecedent_item("user2", antecedent) is False

    def test_can_explain_item_advanced_both_conditions_true(self, explainer):
        """rec1 is explainable when both conditions hold."""
        # Rule {item1} -> {rec1}:
        #   condition 1: user1 and user3 both hold item1
        #   condition 2: item1 is part of the seen items
        group = {"user1", "user3"}
        seen_items = {"item1", "item2", "item4", "item5", "item6"}

        outcome = explainer._can_explain_item_advanced("rec1", group, seen_items)

        assert outcome is True

    def test_can_explain_item_advanced_condition1_false(self, explainer):
        """Smoke check for a group in which one member lacks item1."""
        group = {"user1", "user2"}  # user2 does not hold item1
        seen_items = {"item1", "item2", "item3"}

        outcome = explainer._can_explain_item_advanced("rec1", group, seen_items)

        # Another rule might still apply, so only the result type is checked.
        assert isinstance(outcome, bool)

    def test_can_explain_item_advanced_condition2_false(self, explainer):
        """rec1 is not explainable when the seen items lack the antecedents."""
        group = {"user1", "user3"}
        seen_items = {"item6", "item7"}  # no antecedent item among them

        outcome = explainer._can_explain_item_advanced("rec1", group, seen_items)

        assert outcome is False
|
||||
|
||||
|
||||
class TestPublicMethods:
    """Exercise the explainer's public API."""

    def test_find_explanation_with_recommendations(self, explainer):
        """find_explanation returns a float fidelity within [0, 1]."""
        score = explainer.find_explanation()

        assert isinstance(score, float)
        assert 0.0 <= score <= 1.0

    def test_find_explanation_empty_recommendations(
        self, sample_rules, mock_data_reader
    ):
        """With an empty recommendation pool the fidelity is 0.0."""
        empty_pool_explainer = RuleBasedGroupRecExplainer(
            rules=sample_rules,
            data=mock_data_reader,
            pool_recommendations=[],
        )

        assert empty_pool_explainer.find_explanation() == 0.0

    def test_compute_group_fidelity_advanced_with_data(self, explainer):
        """The advanced fidelity is a float within [0, 1]."""
        score = explainer.compute_group_fidelity_advanced()

        assert isinstance(score, float)
        assert 0.0 <= score <= 1.0

    def test_compute_group_fidelity_advanced_empty_recommendations(
        self, sample_rules, mock_data_reader
    ):
        """With an empty recommendation pool the advanced fidelity is 0.0."""
        empty_pool_explainer = RuleBasedGroupRecExplainer(
            rules=sample_rules,
            data=mock_data_reader,
            pool_recommendations=[],
            members=["user1", "user2"],
        )

        assert empty_pool_explainer.compute_group_fidelity_advanced() == 0.0

    def test_compute_group_fidelity_advanced_no_members(
        self, sample_rules, mock_data_reader
    ):
        """With no group members the advanced fidelity is 0.0."""
        memberless_explainer = RuleBasedGroupRecExplainer(
            rules=sample_rules,
            data=mock_data_reader,
            pool_recommendations=["rec1"],
            members=[],
        )

        assert memberless_explainer.compute_group_fidelity_advanced() == 0.0

    def test_get_explanation_details(self, explainer):
        """Details map pooled items to lists of explanation records."""
        required_keys = (
            "antecedent",
            "consequent",
            "satisfied_members",
            "confidence",
            "support",
        )

        details = explainer.get_explanation_details()

        assert isinstance(details, dict)
        assert all(item_id in explainer.pool_recommendations for item_id in details)

        # Every explanation record must expose the full set of keys.
        for records in details.values():
            assert isinstance(records, list)
            for record in records:
                for key in required_keys:
                    assert key in record
|
||||
|
||||
|
||||
class TestEdgeCases:
    """Unusual inputs: mixed ID types, missing history, unrelated rules."""

    def test_with_string_and_int_ids(self, sample_rules, mock_data_reader):
        """Integer and string IDs may be mixed without raising."""
        mixed_history = {
            1: {10, 20},  # integer user with integer items
            "user2": {"item1", "item2"},  # string user with string items
        }
        mixed_explainer = RuleBasedGroupRecExplainer(
            rules=sample_rules,
            data=mock_data_reader,
            pool_recommendations=[1, "rec1"],
            members=[1, "user2"],
            user_history=mixed_history,
        )

        # The call itself must succeed; only the result type is checked.
        score = mixed_explainer.find_explanation()
        assert isinstance(score, float)

    def test_empty_user_history(self, sample_rules, mock_data_reader):
        """With an empty history the fidelity is 0.0."""
        no_history_explainer = RuleBasedGroupRecExplainer(
            rules=sample_rules,
            data=mock_data_reader,
            pool_recommendations=["rec1"],
            members=["user1", "user2"],
            user_history={},
        )

        assert no_history_explainer.find_explanation() == 0.0

    def test_rules_with_no_matches(self, mock_data_reader):
        """Rules whose consequents miss every recommendation give 0.0."""
        unrelated_rules = pd.DataFrame(
            {
                "antecedents": [{"item1"}],
                "consequents": [{"other_rec"}],
                "confidence": [0.8],
                "support": [0.3],
            }
        )
        unmatched_explainer = RuleBasedGroupRecExplainer(
            rules=unrelated_rules,  # type: ignore
            data=mock_data_reader,
            pool_recommendations=["rec1"],
            members=["user1"],
            user_history={"user1": {"item1"}},
        )

        assert unmatched_explainer.find_explanation() == 0.0
|
||||
|
||||
|
||||
class TestLogging:
    """Check the log output produced by find_explanation."""

    def test_logging_warnings(self, sample_rules, mock_data_reader, caplog):
        """An empty recommendation pool produces a warning message."""
        empty_explainer = RuleBasedGroupRecExplainer(
            rules=sample_rules,
            data=mock_data_reader,
            pool_recommendations=[],
            members=[],
        )

        with caplog.at_level(logging.WARNING):
            empty_explainer.find_explanation()

        assert "No recommendations to explain" in caplog.text

    def test_logging_info(self, explainer, caplog):
        """A regular run logs a summary that mentions the fidelity."""
        with caplog.at_level(logging.INFO):
            explainer.find_explanation()

        for fragment in ("Explained", "fidelity"):
            assert fragment in caplog.text
|
||||
|
||||
|
||||
# Integration tests
|
||||
class TestIntegration:
    """End-to-end checks that combine several public methods."""

    def test_full_workflow(self, sample_rules, mock_data_reader):
        """Run basic fidelity, advanced fidelity and details on one instance."""
        history = {
            "user1": {"item1", "item2"},
            "user2": {"item3"},
            "user3": {"item1", "item4", "item5"},
        }
        workflow_explainer = RuleBasedGroupRecExplainer(
            rules=sample_rules,
            data=mock_data_reader,
            pool_recommendations=["rec1", "rec2", "rec3"],
            members=["user1", "user2", "user3"],
            user_history=history,  # type: ignore
            min_members_threshold=1,
        )

        # Basic explanation
        basic_score = workflow_explainer.find_explanation()
        assert isinstance(basic_score, float)

        # Advanced explanation
        advanced_score = workflow_explainer.compute_group_fidelity_advanced()
        assert isinstance(advanced_score, float)

        # Detailed explanations
        explanation_details = workflow_explainer.get_explanation_details()
        assert isinstance(explanation_details, dict)

        # Both fidelities must lie within [0, 1].
        for score in (basic_score, advanced_score):
            assert 0.0 <= score <= 1.0
|
||||
@@ -0,0 +1,392 @@
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from typing import List, Union
|
||||
|
||||
from pygrex.config import cfg
|
||||
from pygrex.data_reader.data_reader import DataReader
|
||||
from pygrex.data_reader.group_interaction_handler import GroupInteractionHandler
|
||||
from pygrex.explain.groups.sliding_window_explainer import SlidingWindowExplainer
|
||||
from pygrex.models.recommender_model import RecommenderModel
|
||||
from pygrex.recommender.group_recommender import GroupRecommender
|
||||
|
||||
|
||||
class TestSlidingWindowExplainer:
    """Unit tests for SlidingWindowExplainer with mocked collaborators."""

    @staticmethod
    def _ranking_weights():
        """Return a fresh copy of the ranking weights passed to find_explanation."""
        return {
            "popularity": 1,
            "intensity": 1,
            "rating": 1,
            "relevance": 1,
            "trend": 0,
        }

    @pytest.fixture
    def mock_config(self):
        """Provide a Mock constrained to the attributes of cfg."""
        return Mock(spec=cfg)

    @pytest.fixture
    def mock_data_reader(self):
        """Provide a DataReader mock backed by a small interaction table."""
        interactions = pd.DataFrame(
            {
                "userId": [1, 1, 1, 2, 2, 3, 3, 3],
                "itemId": [101, 102, 103, 101, 104, 102, 103, 105],
                "rating": [5, 4, 3, 4, 5, 3, 4, 5],
                # Eight timestamps: 1700000000, 1700100000, ..., 1700700000.
                "timestamp": list(range(1700000000, 1700800000, 100000)),
            }
        )

        reader = Mock(spec=DataReader)
        reader.dataset = interactions

        # The ID mapping is the identity (cast to int) for these tests.
        reader.get_new_user_id = lambda x: int(x)
        reader.get_new_item_id = lambda x: int(x)
        reader.make_consecutive_ids_in_dataset = Mock()
        reader.binarize = Mock()
        return reader

    @pytest.fixture
    def mock_group_handler(self):
        """Provide a GroupInteractionHandler mock.

        create_modified_dataset is replaced by a plain function that returns
        a copy of its ``original_data`` keyword argument.
        """
        handler = Mock(spec=GroupInteractionHandler)

        def _copy_of_original(*args, **kwargs):
            return kwargs.get("original_data").copy()

        handler.create_modified_dataset = _copy_of_original
        return handler

    @pytest.fixture
    def mock_recommender_model(self):
        """Provide a RecommenderModel mock whose fit() returns None."""
        model = Mock(spec=RecommenderModel)
        model.fit = Mock(return_value=None)
        return model

    @pytest.fixture
    def mock_sliding_window(self):
        """Provide a window mock yielding two windows and then None."""
        windows = [
            [101, 102],  # first window
            [103, 104],  # second window
            None,  # no more windows
        ]
        window = Mock()
        window.get_next_window = Mock(side_effect=windows)
        return window

    @pytest.fixture
    def mock_group_recommender(self):
        """Patch GroupRecommender in pygrex.recommender.group_recommender.

        Yields the patched class; its instances always recommend
        [200, 201, 202], which includes the target item 200.
        """
        patch_target = "pygrex.recommender.group_recommender.GroupRecommender"
        with patch(patch_target, autospec=True) as mock_gr:
            recommender_instance = mock_gr.return_value
            recommender_instance.setup_recommendation = Mock()

            def get_recommendations(n):
                # The requested size is ignored; the list is fixed.
                return [200, 201, 202]

            recommender_instance.get_group_recommendations = Mock(
                side_effect=get_recommendations
            )
            yield mock_gr

    @pytest.fixture
    def explainer(
        self,
        mock_config,
        mock_data_reader,
        mock_group_handler,
        mock_recommender_model,
        mock_sliding_window,
    ):
        """Build a SlidingWindowExplainer for members 1-3 and target item 200.

        mock_sliding_window is requested but not passed to the constructor.
        """
        return SlidingWindowExplainer(
            config=mock_config,
            data=mock_data_reader,
            group_handler=mock_group_handler,
            members=[1, 2, 3],
            target_item=200,
            model=mock_recommender_model,
        )

    def test_initialization(self, explainer):
        """A new explainer stores its arguments and starts with empty state."""
        assert explainer.members == [1, 2, 3]
        assert explainer.target_item == 200
        # candidate_items is no longer part of the explainer API.
        assert explainer.calls == 0
        assert explainer.explanations_found == {}

    def test_set_sliding_window(self, explainer):
        """set_sliding_window stores the given window on the explainer."""
        replacement_window = Mock()
        explainer.set_sliding_window(replacement_window)
        assert explainer.sliding_window == replacement_window

    def test_find_explanation_no_sliding_window(self, explainer):
        pytest.skip("find_explanation no longer depends on a preset sliding window")

    @patch.object(SlidingWindowExplainer, "_test_window_removal")
    @patch.object(SlidingWindowExplainer, "_find_minimal_subset")
    def test_find_explanation_no_explanations_found(
        self, mock_find_minimal, mock_test_window, explainer
    ):
        """No window changes the recommendations, so nothing is explained."""
        # Every window test reports "no effect on the recommendations".
        mock_test_window.return_value = False

        group_items = [101, 102, 103, 104]
        predictions = {1: {101: 4.0}, 2: {102: 3.5}, 3: {103: 4.2}}
        top_item = 200

        outcome = explainer.find_explanation(
            group_items, predictions, top_item, self._ranking_weights()
        )

        # Windows were tested, but no minimal-subset search was started.
        assert mock_test_window.call_count > 0
        assert mock_find_minimal.call_count == 0
        assert outcome == {}

    @patch.object(SlidingWindowExplainer, "_test_window_removal")
    @patch.object(SlidingWindowExplainer, "_find_minimal_subset")
    def test_find_explanation_found(
        self, mock_find_minimal, mock_test_window, explainer
    ):
        """The minimal-subset search runs for the first effective window."""
        # The first window has no effect, the second one does.
        mock_test_window.side_effect = [False, True]

        group_items = [101, 102, 103, 104]
        predictions = {1: {101: 4.0}, 2: {102: 3.5}, 3: {103: 4.2}}
        top_item = 200

        explainer.find_explanation(
            group_items, predictions, top_item, self._ranking_weights()
        )

        assert mock_test_window.call_count == 2
        assert mock_find_minimal.call_count == 1

        # The window handed to the search must contain the later items.
        searched_window = mock_find_minimal.call_args[0][0]
        assert 103 in searched_window and 104 in searched_window

    @patch("pygrex.explain.groups.sliding_window_explainer.GroupRecommender")
    def test_get_recommendations_after_removal(
        self,
        mock_group_recommender_cls,
        explainer,
        mock_data_reader,
    ):
        """Recommendations come from a GroupRecommender built on retrained data."""
        expected_recommendations = [201, 202, 203]

        # Instance returned whenever the patched GroupRecommender class is called.
        recommender = Mock()
        mock_group_recommender_cls.return_value = recommender
        recommender.get_group_recommendations.return_value = expected_recommendations
        recommender.setup_recommendation = Mock()

        # Stub the explainer's own collaborators so no real retraining happens.
        explainer._create_data_reader_and_prepare = Mock(return_value=mock_data_reader)
        explainer._retrain_model = Mock(return_value=Mock())
        explainer.group_handler.create_modified_dataset = Mock(
            return_value=mock_data_reader.dataset
        )

        outcome = explainer._get_recommendations_after_removal([101, 102])

        # The result is whatever the recommender instance returned.
        assert outcome == [201, 202, 203]

        # The class was instantiated once, with the prepared data reader.
        mock_group_recommender_cls.assert_called_once_with(mock_data_reader)

        # The instance was set up once and queried once with top_n=10.
        recommender.setup_recommendation.assert_called_once()
        recommender.get_group_recommendations.assert_called_once_with(10)

    def test_test_window_removal_target_removed(self, explainer):
        """_test_window_removal is True when the target item is no longer recommended."""
        recommendations_without_target = [201, 202, 203]
        explainer._get_recommendations_after_removal = Mock(
            return_value=recommendations_without_target
        )

        assert explainer._test_window_removal([101, 102], 200) is True

    def test_test_window_removal_target_still_present(self, explainer):
        """_test_window_removal is False when the target item is still recommended."""
        recommendations_with_target = [200, 201, 202]
        explainer._get_recommendations_after_removal = Mock(
            return_value=recommendations_with_target
        )

        assert explainer._test_window_removal([103, 104], 200) is False

    @patch.object(SlidingWindowExplainer, "_get_recommendations_after_removal")
    @patch.object(SlidingWindowExplainer, "_record_explanation")
    def test_find_minimal_subset_found(self, mock_record, mock_get_recs, explainer):
        """The subset [101] is recorded when it alone drops the target item."""

        def recommendations_after_removing(items, top_n=10):
            # Only removing exactly [101] pushes target item 200 out.
            if items == [101]:
                return [201, 202, 203]
            return [200, 201, 202]

        mock_get_recs.side_effect = recommendations_after_removing

        # Call the private method directly to isolate the subset search.
        explainer._find_minimal_subset([101, 102], 200)

        assert mock_record.called
        recorded_subset = mock_record.call_args[0][0]
        assert recorded_subset == [101]

    @patch.object(SlidingWindowExplainer, "_get_recommendations_after_removal")
    @patch.object(SlidingWindowExplainer, "_record_explanation")
    def test_find_minimal_subset_not_found(self, mock_record, mock_get_recs, explainer):
        """Nothing is recorded when every subset keeps the target item."""
        # The target item 200 is always part of the recommendations.
        mock_get_recs.return_value = [200, 201, 202]

        explainer._find_minimal_subset([101, 102], 200)

        assert not mock_record.called

    @patch.object(SlidingWindowExplainer, "_calculate_item_intensity")
    @patch.object(SlidingWindowExplainer, "_calculate_user_intensity")
    def test_record_explanation(
        self, mock_user_intensity, mock_item_intensity, explainer, capfd
    ):
        """_record_explanation stores the items and prints a message."""
        mock_item_intensity.return_value = [0.5, 0.7]
        mock_user_intensity.return_value = [0.3, 0.6, 0.8]

        explainer._record_explanation([101, 102], 200, 201)

        # The explanation is stored under the current call counter.
        stored = explainer.explanations_found[explainer.calls]
        assert stored["items"] == [101, 102]

        # The explanation is also printed to stdout.
        captured = capfd.readouterr()
        assert "If the group had not interacted with these items" in captured.out

    def test_calculate_average_item_intensity_score(self, mock_data_reader):
        """Check the score returned for each explanation item."""
        scores = SlidingWindowExplainer._calculate_average_item_intensity_score(
            explanation=[101, 102],
            members=[1, 2, 3],
            data=mock_data_reader,
        )

        # Item 101 was rated by users 1 and 2, item 102 by users 1 and 3,
        # so each item reaches two of the three members.
        assert len(scores) == 2
        for position, expected in enumerate((2 / 3, 2 / 3)):
            assert scores[position] == pytest.approx(expected)

    def test_calculate_user_intensity_score(self, mock_data_reader):
        """Check the score returned for each group member."""
        scores = SlidingWindowExplainer._calculate_user_intensity_score(
            explanation_items=[101, 102, 103],
            members=[1, 2, 3],
            data=mock_data_reader,
        )

        # User 1 rated 101, 102 and 103 (3/3), user 2 only 101 (1/3),
        # user 3 rated 102 and 103 (2/3).
        assert len(scores) == 3
        for position, expected in enumerate((1.0, 1 / 3, 2 / 3)):
            assert scores[position] == pytest.approx(expected)

    def test_create_data_reader_and_prepare(self, explainer, mock_data_reader):
        """A new DataReader is created, prepared and returned."""
        patch_target = "pygrex.explain.groups.sliding_window_explainer.DataReader"
        with patch(patch_target) as mock_reader_class:
            prepared_reader = Mock(spec=DataReader)
            mock_reader_class.return_value = prepared_reader

            outcome = explainer._create_data_reader_and_prepare(
                mock_data_reader.dataset
            )

        # The class was instantiated and both preparation steps were called.
        assert mock_reader_class.called
        assert prepared_reader.make_consecutive_ids_in_dataset.called
        assert prepared_reader.binarize.called
        assert outcome == prepared_reader

    def test_retrain_model(self, explainer, mock_data_reader):
        """_retrain_model fits the explainer's model on the data and returns it."""
        original_model = explainer.model

        retrained = explainer._retrain_model(mock_data_reader)

        assert original_model.fit.called
        assert original_model.fit.call_args[0][0] == mock_data_reader
        assert retrained == original_model

    def test_max_calls_limit(self, explainer):
        """find_explanation stops once max_calls is reached."""
        explainer.max_calls = 1

        # No window has any effect, so only the call limit ends the search.
        explainer._test_window_removal = Mock(return_value=False)

        group_items = [101, 102, 103]
        predictions = {1: {101: 4.0}}
        top_item = 200

        outcome = explainer.find_explanation(
            group_items, predictions, top_item, self._ranking_weights()
        )

        # Exactly one call was counted and no explanation was found.
        assert explainer.calls == 1
        assert outcome == {}
|
||||
@@ -0,0 +1,52 @@
|
||||
import pytest
|
||||
from pygrex.config import cfg
|
||||
from pygrex.data_reader import DataReader
|
||||
from pygrex.models import ALS
|
||||
from pygrex.recommender import Recommender
|
||||
from pygrex.evaluator import ModelEvaluator, Splitter
|
||||
from pygrex.explain import KNNPostHocExplainer
|
||||
# from threadpoolctl import threadpool_limits
|
||||
|
||||
|
||||
# # Fixture to replace setUp method
|
||||
# @pytest.fixture(autouse=True)
|
||||
# def limit_blas_threads():
|
||||
# threadpool_limits(1, "blas")
|
||||
|
||||
|
||||
@pytest.fixture()
def setup_data():
    """Return an ALS model and the ml100k DataReader as ``(model, reader)``.

    The model is built from ``cfg.model.als`` and is not fitted here; the
    reader has consecutive IDs and has been binarized.
    """
    model = ALS(**cfg.model.als)

    reader = DataReader(**cfg.data.ml100k)
    reader.make_consecutive_ids_in_dataset()
    reader.binarize()

    return model, reader
|
||||
|
||||
|
||||
def test_explain_als(setup_data):
    """Smoke test: split the data, fit ALS and recommend without raising."""
    model, reader = setup_data

    # Leave-n-out split with n=1; only the training part is used below.
    train_split, _held_out = Splitter().split_leave_n_out(reader, n=1)

    # Fitting must complete without raising.
    model.fit(train_split)

    # Recommend over at most the first 50 item indices the fitted model knows.
    recommender = Recommender(train_split, model)
    known_items = model.model.item_factors.shape[0]
    candidate_items = list(range(min(known_items, 50)))
    recommender.recommend(user_id=0, target_item_id=candidate_items)

    # The pipeline ran without exceptions; evaluation lives elsewhere.
    # Not exercised in this smoke test:
    #   explainer = ALSExplainer(als, recommendations, data)
    #   explainer.explain_recommendations()
    # The KNN post-hoc explainer is skipped here as well.
|
||||
@@ -0,0 +1,24 @@
|
||||
import pytest
|
||||
|
||||
from pygrex.config import cfg
|
||||
from pygrex.data_reader.data_reader import DataReader
|
||||
from pygrex.models.als_model import ALS
|
||||
from pygrex.recommender import Recommender
|
||||
|
||||
|
||||
@pytest.fixture
def setup_data():
    """Return an ALS model fitted on the prepared test data as ``(model, reader)``."""
    model = ALS(**cfg.model.als)

    reader = DataReader(**cfg.data.testdata)
    reader.make_consecutive_ids_in_dataset()
    reader.binarize()

    model.fit(reader)
    return model, reader
|
||||
|
||||
|
||||
def test_train_recommend_als(setup_data):
    """Smoke test: a fitted ALS model can recommend for user 0."""
    model, reader = setup_data
    recommender = Recommender(reader, model)

    # Restrict the pool to at most the first 50 item indices of the model.
    known_items = model.model.item_factors.shape[0]
    candidate_items = list(range(min(known_items, 50)))
    recommender.recommend(user_id=0, target_item_id=candidate_items)
|
||||
@@ -0,0 +1,22 @@
|
||||
import pytest
|
||||
|
||||
from pygrex.config import cfg
|
||||
from pygrex.data_reader.data_reader import DataReader
|
||||
from pygrex.models.autoencoder_model import ExplAutoencoderTorch
|
||||
from pygrex.recommender import Recommender
|
||||
|
||||
|
||||
@pytest.fixture
def setup_data():
    """Provide an unfitted explainable autoencoder and the preprocessed test dataset.

    Returns:
        tuple: ``(model, dataset)`` where ``model`` is an
        ``ExplAutoencoderTorch`` built from ``cfg.model.autoencoder`` and
        ``dataset`` is built from ``cfg.data.testdata`` with consecutive ids
        and binarized ratings.
    """
    model = ExplAutoencoderTorch(**cfg.model.autoencoder)

    dataset = DataReader(**cfg.data.testdata)
    dataset.make_consecutive_ids_in_dataset()
    dataset.binarize()

    return model, dataset
|
||||
|
||||
|
||||
def test_train_autoencoder(setup_data):
    """Check that the autoencoder fits and can predict for one user/item pair."""
    model, dataset = setup_data
    model.fit(dataset)

    # A single prediction for user 0 / item 0 must not raise after fitting.
    model.predict(0, 0)
|
||||
@@ -0,0 +1,21 @@
|
||||
import pytest
|
||||
|
||||
from pygrex.config import cfg
|
||||
from pygrex.data_reader.data_reader import DataReader
|
||||
from pygrex.models import EMFModel
|
||||
from pygrex.recommender import Recommender
|
||||
|
||||
|
||||
@pytest.fixture
def setup_data():
    """Provide an unfitted EMF model and the test dataset with consecutive ids.

    Returns:
        tuple: ``(model, dataset)``. Unlike the sibling fixtures for other
        models, the dataset is not binarized here.
    """
    model = EMFModel(**cfg.model.emf)

    dataset = DataReader(**cfg.data.testdata)
    dataset.make_consecutive_ids_in_dataset()

    return model, dataset
|
||||
|
||||
|
||||
def test_train_emf(setup_data):
    """Check that EMF fits and ``recommend_all`` runs without raising."""
    model, dataset = setup_data
    model.fit(dataset)

    Recommender(dataset, model).recommend_all()
|
||||
@@ -0,0 +1,22 @@
|
||||
import pytest
|
||||
|
||||
from pygrex.config import cfg
|
||||
from pygrex.data_reader.data_reader import DataReader
|
||||
from pygrex.models.gmf_model import GMFModel
|
||||
from pygrex.recommender import Recommender
|
||||
|
||||
|
||||
@pytest.fixture
def setup_data():
    """Provide an unfitted GMF model and the preprocessed test dataset.

    Returns:
        tuple: ``(model, dataset)`` where ``dataset`` is built from
        ``cfg.data.testdata`` with consecutive ids and binarized ratings.
    """
    model = GMFModel(**cfg.model.gmf)

    dataset = DataReader(**cfg.data.testdata)
    dataset.make_consecutive_ids_in_dataset()
    dataset.binarize()

    return model, dataset
|
||||
|
||||
|
||||
def test_train_gmf(setup_data):
    """Check that GMF fits and ``recommend_all`` runs without raising."""
    model, dataset = setup_data
    model.fit(dataset)

    Recommender(dataset, model).recommend_all()
|
||||
@@ -0,0 +1,35 @@
|
||||
import pytest
|
||||
|
||||
from pygrex.config import cfg
|
||||
from pygrex.data_reader.data_reader import DataReader
|
||||
from pygrex.models.als_model import ALS
|
||||
from pygrex.models.bpr_model import BPR
|
||||
from pygrex.recommender import Recommender
|
||||
|
||||
|
||||
@pytest.fixture
def setup_data():
    """Provide unfitted ALS and BPR models plus the preprocessed test dataset.

    Returns:
        tuple: ``(als_model, bpr_model, dataset)``; the dataset is built from
        ``cfg.data.testdata`` with consecutive ids and binarized ratings.
    """
    als_model = ALS(**cfg.model.als)
    bpr_model = BPR(**cfg.model.bpr)

    dataset = DataReader(**cfg.data.testdata)
    dataset.make_consecutive_ids_in_dataset()
    dataset.binarize()

    return als_model, bpr_model, dataset
|
||||
|
||||
|
||||
def test_train_als(setup_data):
    """Check that ALS fits and recommends over a small item pool without raising."""
    model, _unused_bpr, dataset = setup_data
    model.fit(dataset)

    # Candidate ids are limited to what the model has item factors for (max 50).
    n_items = model.model.item_factors.shape[0]
    candidates = list(range(min(n_items, 50)))

    rec = Recommender(dataset, model)
    rec.recommend(user_id=0, target_item_id=candidates)
|
||||
|
||||
|
||||
def test_train_bpr(setup_data):
    """Check that BPR fits and recommends over a small item pool without raising."""
    _unused_als, model, dataset = setup_data
    model.fit(dataset)

    # Candidate ids are limited to what the model has item factors for (max 50).
    n_items = model.model.item_factors.shape[0]
    candidates = list(range(min(n_items, 50)))

    rec = Recommender(dataset, model)
    rec.recommend(user_id=0, target_item_id=candidates)
|
||||
@@ -0,0 +1,22 @@
|
||||
import pytest
|
||||
|
||||
from pygrex.config import cfg
|
||||
from pygrex.data_reader.data_reader import DataReader
|
||||
from pygrex.models.mlp_model import MLPModel
|
||||
from pygrex.recommender import Recommender
|
||||
|
||||
|
||||
@pytest.fixture
def setup_data():
    """Provide an unfitted MLP model and the preprocessed test dataset.

    Returns:
        tuple: ``(model, dataset)`` where ``dataset`` is built from
        ``cfg.data.testdata`` with consecutive ids and binarized ratings.
    """
    model = MLPModel(**cfg.model.mlp)

    dataset = DataReader(**cfg.data.testdata)
    dataset.make_consecutive_ids_in_dataset()
    dataset.binarize()

    return model, dataset
|
||||
|
||||
|
||||
def test_train_mlp(setup_data):
    """Check that the MLP model fits and can predict for one user/item pair."""
    model, dataset = setup_data
    model.fit(dataset)

    # A single prediction for user 0 / item 0 must not raise after fitting.
    model.predict(0, 0)
|
||||
@@ -0,0 +1,458 @@
|
||||
import itertools
|
||||
import pytest
|
||||
import numpy as np
|
||||
from unittest.mock import MagicMock, patch
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from pygrex.data_reader.data_reader import DataReader
|
||||
from pygrex.models.recommender_model import RecommenderModel
|
||||
from pygrex.recommender.group_recommender import GroupRecommender
|
||||
from pygrex.utils.aggregation_strategy import AggregationStrategy, ScoreAggregator
|
||||
from pygrex.utils.scale import Scale
|
||||
|
||||
|
||||
class TestGroupRecommender:
    """Unit tests for ``GroupRecommender`` using mocked data and models."""

    # ------------------------------------------------------------------
    # Fixtures and helpers
    # ------------------------------------------------------------------

    @pytest.fixture
    def mock_data_reader(self):
        """Return a ``DataReader``-spec'd mock whose ``dataset`` is a MagicMock."""
        reader = MagicMock(spec=DataReader)
        reader.dataset = MagicMock()
        return reader

    @pytest.fixture
    def mock_model(self):
        """Return a mock constrained to the ``RecommenderModel`` interface."""
        return MagicMock(spec=RecommenderModel)

    @pytest.fixture
    def group_recommender(self, mock_data_reader):
        """Return a ``GroupRecommender`` wired to the mocked data reader."""
        return GroupRecommender(mock_data_reader)

    @staticmethod
    def _stub_pipeline(recommender, item_pool):
        """Replace the recommender's pipeline steps with canned MagicMocks."""
        recommender.get_non_interacted_items_for_recommendation = MagicMock(
            return_value=item_pool
        )
        recommender._generate_group_predictions = MagicMock(
            return_value={1: {101: 4.5}, 2: {103: 3.2}, 3: {105: 4.0}}
        )
        recommender._aggregate_group_scores = MagicMock(
            return_value={101: 4.5, 103: 3.2, 105: 4.0}
        )

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------

    def test_init(self, mock_data_reader):
        """A fresh instance keeps the data reader and has no other state set."""
        recommender = GroupRecommender(mock_data_reader)

        assert recommender.data == mock_data_reader
        unset_attributes = (
            "_group_predictions",
            "_members",
            "_item_pool",
            "_model",
            "_aggregation_strategy",
            "_score_aggregator",
            "_aggregated_scores",
        )
        for attribute in unset_attributes:
            assert getattr(recommender, attribute) is None

    # ------------------------------------------------------------------
    # setup_recommendation
    # ------------------------------------------------------------------

    def test_setup_recommendation(
        self, group_recommender, mock_model, mock_data_reader
    ):
        """setup_recommendation stores its inputs and runs each pipeline step once."""
        # Arrange
        group = [1, 2, 3]
        catalogue = [101, 102, 103, 104, 105]
        pool = np.array([101, 103, 105])
        strategy = AggregationStrategy.AVG_PREDICTIONS
        mock_data_reader.dataset.__getitem__().unique.return_value = catalogue
        self._stub_pipeline(group_recommender, pool)

        # Act
        with patch.object(
            GroupRecommender, "_get_max_valid_item_id", return_value=100000
        ):
            group_recommender.setup_recommendation(
                mock_model, group, mock_data_reader, strategy
            )

        # Assert: stored state
        assert group_recommender._members == group
        assert group_recommender._model == mock_model
        assert group_recommender._aggregation_strategy == strategy
        assert isinstance(group_recommender._score_aggregator, ScoreAggregator)
        assert np.array_equal(group_recommender._item_pool, pool)

        # Assert: collaborators were invoked as expected
        pool_lookup = group_recommender.get_non_interacted_items_for_recommendation
        pool_lookup.assert_called_once_with(group_recommender.data, catalogue, group)
        group_recommender._generate_group_predictions.assert_called_once()
        group_recommender._aggregate_group_scores.assert_called_once()

    def test_setup_recommendation_with_most_respected_person(
        self, group_recommender, mock_model, mock_data_reader
    ):
        """The most respected person is forwarded to the score aggregator."""
        # Arrange
        group = [1, 2, 3]
        catalogue = [101, 102, 103, 104, 105]
        pool = np.array([101, 103, 105])
        strategy = AggregationStrategy.MOST_RESPECTED_PERSON
        respected = 1
        mock_data_reader.dataset.__getitem__().unique.return_value = catalogue
        self._stub_pipeline(group_recommender, pool)

        # Act
        with patch.object(
            GroupRecommender, "_get_max_valid_item_id", return_value=100000
        ):
            group_recommender.setup_recommendation(
                mock_model,
                group,
                mock_data_reader,
                strategy,
                respected,
            )

        # Assert
        aggregator = group_recommender._score_aggregator
        assert aggregator.most_respected_person == respected

    # ------------------------------------------------------------------
    # _generate_group_predictions
    # ------------------------------------------------------------------

    def test_generate_group_predictions(self, group_recommender, mock_model):
        """Per-member predictions are collected into a dict keyed by member."""
        # Arrange
        group_recommender._members = [1, 2]
        group_recommender._model = mock_model
        group_recommender._item_pool = np.array([101, 103])

        # One canned prediction dict per member, returned in call order.
        first_member_scores = {101: 4.5, 103: 3.8}
        second_member_scores = {101: 3.2, 103: 4.7}
        group_recommender.generate_recommendation = MagicMock(
            side_effect=[first_member_scores, second_member_scores]
        )

        # Act
        predictions = group_recommender._generate_group_predictions()

        # Assert
        assert predictions == {
            1: {101: 4.5, 103: 3.8},
            2: {101: 3.2, 103: 4.7},
        }
        assert group_recommender.generate_recommendation.call_count == 2

    def test_generate_group_predictions_error(self, group_recommender):
        """A missing model makes _generate_group_predictions raise ValueError."""
        # Arrange - members and item pool are set, the model is not
        group_recommender._members = [1, 2]
        group_recommender._model = None
        group_recommender._item_pool = np.array([101, 103])

        # Act & Assert
        expected_message = (
            "You must call setup_recommendation before generating predictions"
        )
        with pytest.raises(ValueError, match=expected_message):
            group_recommender._generate_group_predictions()

    # ------------------------------------------------------------------
    # _aggregate_group_scores
    # ------------------------------------------------------------------

    def test_aggregate_group_scores(self, group_recommender):
        """Aggregation delegates to the score aggregator without rankings."""
        # Arrange
        group_recommender._group_predictions = {
            1: {101: 4.5, 102: 3.8},
            2: {101: 3.2, 102: 4.7},
        }
        group_recommender._aggregation_strategy = AggregationStrategy.AVG_PREDICTIONS
        aggregator = MagicMock(spec=ScoreAggregator)
        aggregator.aggregate_scores.return_value = {102: 4.25, 101: 3.85}
        group_recommender._score_aggregator = aggregator

        # Act
        scores = group_recommender._aggregate_group_scores()

        # Assert - the canned scores are already in descending order
        assert scores == {102: 4.25, 101: 3.85}
        aggregator.aggregate_scores.assert_called_once_with(
            evaluations=group_recommender._group_predictions,
            strategy=AggregationStrategy.AVG_PREDICTIONS,
            rankings=None,
        )

    def test_aggregate_group_scores_borda_count(self, group_recommender):
        """With Borda Count, rankings are built and passed to the aggregator."""
        # Arrange
        group_recommender._group_predictions = {
            1: {101: 4.5, 102: 3.8},
            2: {101: 3.2, 102: 4.7},
        }
        group_recommender._aggregation_strategy = AggregationStrategy.BORDA_COUNT
        aggregator = MagicMock(spec=ScoreAggregator)
        aggregator.aggregate_scores.return_value = {102: 2.0, 101: 1.0}
        group_recommender._score_aggregator = aggregator

        # Stub the ranking step so the aggregator input is known exactly.
        canned_rankings = {1: [101, 102], 2: [102, 101]}
        group_recommender._create_rankings_from_predictions = MagicMock(
            return_value=canned_rankings
        )

        # Act
        scores = group_recommender._aggregate_group_scores()

        # Assert - the canned scores are already in descending order
        assert scores == {102: 2.0, 101: 1.0}
        group_recommender._create_rankings_from_predictions.assert_called_once()
        aggregator.aggregate_scores.assert_called_once_with(
            evaluations=group_recommender._group_predictions,
            strategy=AggregationStrategy.BORDA_COUNT,
            rankings=canned_rankings,
        )

    def test_aggregate_group_scores_error(self, group_recommender):
        """Without group predictions, _aggregate_group_scores raises ValueError."""
        # Arrange - incomplete setup
        group_recommender._group_predictions = None

        # Act & Assert
        expected_message = (
            "You must call setup_recommendation before aggregating scores"
        )
        with pytest.raises(ValueError, match=expected_message):
            group_recommender._aggregate_group_scores()

    # ------------------------------------------------------------------
    # _create_rankings_from_predictions
    # ------------------------------------------------------------------

    def test_create_rankings_from_predictions(self, group_recommender):
        """Each member's items are ranked by predicted score, highest first."""
        # Arrange
        group_recommender._group_predictions = {
            1: {101: 4.5, 102: 3.8, 103: 2.1},
            2: {101: 3.2, 102: 4.7, 103: 3.9},
        }

        # Act
        rankings = group_recommender._create_rankings_from_predictions()

        # Assert
        assert rankings == {
            1: [101, 102, 103],  # scores 4.5 > 3.8 > 2.1
            2: [102, 103, 101],  # scores 4.7 > 3.9 > 3.2
        }

    def test_create_rankings_from_predictions_error(self, group_recommender):
        """Without predictions, _create_rankings_from_predictions raises ValueError."""
        # Arrange - no predictions
        group_recommender._group_predictions = None

        # Act & Assert
        with pytest.raises(ValueError, match="Group predictions not available"):
            group_recommender._create_rankings_from_predictions()

    # ------------------------------------------------------------------
    # get_non_interacted_items_for_recommendation
    # ------------------------------------------------------------------

    def test_get_non_interacted_items(self, group_recommender, mock_data_reader):
        """Items the members interacted with are removed via numpy.setdiff1d."""
        # Arrange
        group = [1, 2]
        catalogue = [101, 102, 103, 104, 105]

        # The mocked dataset reports items 102 and 104 as already interacted.
        seen_items = np.array([102, 104])
        mock_data_reader.dataset.loc = MagicMock()
        loc_lookup = mock_data_reader.dataset.loc.__getitem__
        loc_lookup.return_value.unique.return_value = seen_items

        # Act
        with patch(
            "numpy.setdiff1d", return_value=np.array([101, 103, 105])
        ) as mock_setdiff:
            remaining = group_recommender.get_non_interacted_items_for_recommendation(
                mock_data_reader, catalogue, group
            )

        # Assert
        assert np.array_equal(remaining, np.array([101, 103, 105]))
        mock_setdiff.assert_called_once_with(
            catalogue, seen_items, assume_unique=True
        )

    # ------------------------------------------------------------------
    # generate_recommendation
    # ------------------------------------------------------------------

    def test_generate_recommendation(
        self, group_recommender, mock_data_reader, mock_model
    ):
        """Predictions are scaled and keyed by original item id, best first."""
        # Arrange
        member = "1"  # passed as a string to exercise the int conversion
        member_as_int = 1
        internal_member_id = 101  # id the data reader maps the member to
        pool = [201, 202]

        mock_data_reader.get_new_user_id.return_value = internal_member_id
        # Simple, recognisable mapping from internal to original item ids.
        mock_data_reader.get_original_item_id.side_effect = lambda x: x + 1000
        mock_model.predict.return_value = [3.5, 4.2]  # one score per pool item

        # Act - Scale.linear and the max-item-id lookup are both stubbed
        with patch(
            "pygrex.utils.scale.Scale.linear", return_value=np.array([3.0, 4.0])
        ) as mock_scale, patch.object(
            GroupRecommender, "_get_max_valid_item_id", return_value=100000
        ):
            scores = group_recommender.generate_recommendation(
                mock_model, member, pool, mock_data_reader
            )

        # Assert
        mock_data_reader.get_new_user_id.assert_called_once_with(member_as_int)
        assert mock_model.predict.call_count == 1
        mock_scale.assert_called_once()

        # Expected structure: {original_item_id: scaled_score}, sorted descending
        assert scores == {1202: 4.0, 1201: 3.0}

    # ------------------------------------------------------------------
    # get_group_recommendations / get_top_recommendation
    # ------------------------------------------------------------------

    def test_get_group_recommendations_all(self, group_recommender):
        """Without top_k, every item id from the aggregated scores is returned."""
        # Arrange
        group_recommender._aggregated_scores = {102: 4.25, 101: 3.85}

        # Act
        recommended = group_recommender.get_group_recommendations()

        # Assert
        assert recommended == [102, 101]

    def test_get_group_recommendations_top_k(self, group_recommender):
        """top_k=2 returns the first two item ids from the aggregated scores."""
        # Arrange
        group_recommender._aggregated_scores = {102: 4.25, 101: 3.85, 103: 3.2}

        # Act
        recommended = group_recommender.get_group_recommendations(top_k=2)

        # Assert
        assert recommended == [102, 101]

    def test_get_group_recommendations_top_one(self, group_recommender):
        """top_k=1 returns the single top item id rather than a list."""
        # Arrange
        group_recommender._aggregated_scores = {102: 4.25, 101: 3.85}

        # Act
        recommended = group_recommender.get_group_recommendations(top_k=1)

        # Assert
        assert recommended == 102

    def test_get_group_recommendations_error(self, group_recommender):
        """Without aggregated scores, get_group_recommendations raises ValueError."""
        # Arrange - incomplete setup
        group_recommender._aggregated_scores = None

        # Act & Assert
        expected_message = (
            "You must call setup_recommendation before getting recommendations"
        )
        with pytest.raises(ValueError, match=expected_message):
            group_recommender.get_group_recommendations()

    def test_get_top_recommendation(self, group_recommender):
        """get_top_recommendation delegates to get_group_recommendations(top_k=1)."""
        # Arrange
        best_item = 102
        group_recommender.get_group_recommendations = MagicMock(
            return_value=best_item
        )

        # Act
        returned = group_recommender.get_top_recommendation()

        # Assert
        assert returned == best_item
        group_recommender.get_group_recommendations.assert_called_once_with(top_k=1)

    # ------------------------------------------------------------------
    # Accessors
    # ------------------------------------------------------------------

    def test_get_recommendation_scores(self, group_recommender):
        """get_recommendation_scores returns the aggregated score mapping."""
        # Arrange
        group_recommender._aggregated_scores = {103: 3.5, 101: 3.5, 102: 3.5}

        # Act
        scores = group_recommender.get_recommendation_scores()

        # Assert
        assert scores == {103: 3.5, 101: 3.5, 102: 3.5}

    def test_get_recommendation_scores_error(self, group_recommender):
        """Without aggregated scores, get_recommendation_scores raises ValueError."""
        # Arrange - incomplete setup
        group_recommender._aggregated_scores = None

        # Act & Assert
        expected_message = (
            "You must call setup_recommendation before getting recommendation scores"
        )
        with pytest.raises(ValueError, match=expected_message):
            group_recommender.get_recommendation_scores()

    def test_get_aggregation_strategy(self, group_recommender):
        """get_aggregation_strategy returns the stored strategy."""
        # Arrange
        chosen = AggregationStrategy.LEAST_MISERY
        group_recommender._aggregation_strategy = chosen

        # Act
        returned = group_recommender.get_aggregation_strategy()

        # Assert
        assert returned == chosen

    def test_get_group_members(self, group_recommender):
        """get_group_members returns an equal but distinct list."""
        # Arrange
        group = [1, 2, 3]
        group_recommender._members = group

        # Act
        returned = group_recommender.get_group_members()

        # Assert - equal content, but a copy rather than the stored object
        assert returned == group
        assert returned is not group

    def test_get_individual_predictions(self, group_recommender):
        """get_individual_predictions returns an equal but distinct mapping."""
        # Arrange
        stored = {1: {101: 4.5, 102: 3.8}, 2: {101: 3.2, 102: 4.7}}
        group_recommender._group_predictions = stored

        # Act
        returned = group_recommender.get_individual_predictions()

        # Assert - equal content, but a copy rather than the stored object
        assert returned == stored
        assert returned is not stored

    def test_get_individual_predictions_none(self, group_recommender):
        """get_individual_predictions returns None when nothing was predicted."""
        # Arrange
        group_recommender._group_predictions = None

        # Act
        returned = group_recommender.get_individual_predictions()

        # Assert
        assert returned is None
|
||||
@@ -0,0 +1,614 @@
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, MagicMock, patch
|
||||
|
||||
from pygrex.data_reader.data_reader import DataReader
|
||||
from pygrex.utils.association_rules import AssociationRules
|
||||
|
||||
|
||||
class TestAssociationRulesInitialization:
    """Tests for constructing ``AssociationRules`` and its argument checks."""

    def setup_method(self):
        """Create a mocked DataReader backed by a small ratings DataFrame."""
        self.sample_dataset = pd.DataFrame(
            {
                "userId": [1, 1, 1, 2, 2, 3, 3, 3, 4, 4],
                "itemId": ["A", "B", "C", "A", "B", "B", "C", "D", "A", "D"],
                "rating": [4.5, 3.0, 5.0, 4.0, 4.5, 3.5, 4.0, 4.5, 5.0, 3.0],
            }
        )
        self.mock_data_reader = Mock(spec=DataReader)
        self.mock_data_reader.dataset = self.sample_dataset

    def test_init_with_default_parameters(self):
        """Defaults are 0.2 / 0.2 / 4.0 and no results are cached yet."""
        rules = AssociationRules(self.mock_data_reader)

        assert rules.data == self.mock_data_reader
        assert rules.min_support == 0.2
        assert rules.min_confidence == 0.2
        assert rules.rating_threshold == 4.0
        assert rules._frequent_itemsets is None
        assert rules._association_rules is None

    def test_init_with_custom_parameters(self):
        """Explicit keyword arguments are stored on the instance."""
        custom = {
            "min_support": 0.1,
            "min_confidence": 0.3,
            "rating_threshold": 3.5,
        }
        rules = AssociationRules(self.mock_data_reader, **custom)

        assert rules.min_support == 0.1
        assert rules.min_confidence == 0.3
        assert rules.rating_threshold == 3.5

    def test_init_with_invalid_min_support(self):
        """min_support of 0, above 1, or negative is rejected."""
        for bad_support in (0, 1.5, -0.1):
            with pytest.raises(
                ValueError, match="min_support must be between 0 and 1"
            ):
                AssociationRules(self.mock_data_reader, min_support=bad_support)

    def test_init_with_invalid_min_confidence(self):
        """min_confidence of 0 or above 1 is rejected."""
        for bad_confidence in (0, 2.0):
            with pytest.raises(
                ValueError, match="min_confidence must be between 0 and 1"
            ):
                AssociationRules(
                    self.mock_data_reader, min_confidence=bad_confidence
                )

    def test_init_with_invalid_rating_threshold(self):
        """A negative rating_threshold is rejected."""
        with pytest.raises(ValueError, match="rating_threshold must be non-negative"):
            AssociationRules(self.mock_data_reader, rating_threshold=-1.0)
|
||||
|
||||
|
||||
class TestAssociationRulesValidation:
    """Tests for ``AssociationRules._validate_parameters``."""

    def setup_method(self):
        """Create a mocked DataReader with a two-row ratings DataFrame."""
        reader = Mock(spec=DataReader)
        reader.dataset = pd.DataFrame(
            {"userId": [1, 2], "itemId": ["A", "B"], "rating": [4.0, 5.0]}
        )
        self.mock_data_reader = reader

    def test_validate_parameters_valid_inputs(self):
        """In-range support, confidence and threshold pass silently."""
        rules = AssociationRules(self.mock_data_reader)
        # Passing means no exception is raised.
        rules._validate_parameters(0.1, 0.2, 3.0)

    def test_validate_parameters_invalid_support(self):
        """Support of 0 or above 1 raises ValueError."""
        rules = AssociationRules(self.mock_data_reader)

        for bad_support in (0, 1.1):
            with pytest.raises(ValueError):
                rules._validate_parameters(bad_support, 0.5, 3.0)

    def test_validate_parameters_invalid_confidence(self):
        """Confidence of 0 or above 1 raises ValueError."""
        rules = AssociationRules(self.mock_data_reader)

        for bad_confidence in (0, 1.5):
            with pytest.raises(ValueError):
                rules._validate_parameters(0.1, bad_confidence, 3.0)

    def test_validate_parameters_invalid_rating_threshold(self):
        """A negative rating threshold raises ValueError."""
        rules = AssociationRules(self.mock_data_reader)

        with pytest.raises(ValueError):
            rules._validate_parameters(0.1, 0.5, -1.0)
|
||||
|
||||
|
||||
class TestAssociationRulesTransactionPreparation:
    """Tests for ``AssociationRules._prepare_transactions``."""

    def setup_method(self):
        """Create a mocked DataReader; each test attaches its own dataset."""
        self.mock_data_reader = Mock(spec=DataReader)

    def test_prepare_transactions_normal_case(self):
        """Only ratings at or above the threshold end up in a user's transaction."""
        self.mock_data_reader.dataset = pd.DataFrame(
            {
                "userId": [1, 1, 1, 2, 2, 3, 3],
                "itemId": ["A", "B", "C", "A", "B", "B", "C"],
                "rating": [4.5, 3.0, 5.0, 4.0, 4.5, 3.5, 4.0],
            }
        )
        rules = AssociationRules(self.mock_data_reader, rating_threshold=4.0)

        transactions = rules._prepare_transactions()

        per_user_items = [
            ["A", "C"],  # user 1 rated A=4.5 and C=5.0 (B=3.0 is dropped)
            ["A", "B"],  # user 2 rated A=4.0 and B=4.5
            ["C"],  # user 3 rated C=4.0 (B=3.5 is dropped)
        ]
        assert len(transactions) == 3
        assert transactions == per_user_items

    def test_prepare_transactions_empty_after_filter(self):
        """If no rating reaches the threshold, a ValueError is raised."""
        self.mock_data_reader.dataset = pd.DataFrame(
            {
                "userId": [1, 2, 3],
                "itemId": ["A", "B", "C"],
                "rating": [2.0, 3.0, 3.5],
            }
        )
        rules = AssociationRules(self.mock_data_reader, rating_threshold=4.0)

        with pytest.raises(
            ValueError, match="No interactions found with rating >= 4.0"
        ):
            rules._prepare_transactions()

    def test_prepare_transactions_string_conversion(self):
        """Numeric item ids come back as strings in the transactions."""
        self.mock_data_reader.dataset = pd.DataFrame(
            {
                "userId": [1, 1],
                "itemId": [123, 456],  # numeric ids on purpose
                "rating": [4.0, 5.0],
            }
        )
        rules = AssociationRules(self.mock_data_reader)

        transactions = rules._prepare_transactions()

        assert transactions == [["123", "456"]]
        for transaction in transactions:
            for item in transaction:
                assert isinstance(item, str)
|
||||
|
||||
|
||||
class TestAssociationRulesMining:
    """Tests for ``AssociationRules._mine_frequent_itemsets``."""

    def setup_method(self):
        """Create a mocked DataReader; each test attaches its own dataset."""
        self.mock_data_reader = Mock(spec=DataReader)

    @patch("pygrex.utils.association_rules.fpgrowth")
    @patch("pygrex.utils.association_rules.TransactionEncoder")
    def test_mine_frequent_itemsets_success(self, mock_encoder_class, mock_fpgrowth):
        """A non-empty fpgrowth result is returned and the encoder sees the input."""
        # Stub the TransactionEncoder instance created inside the module.
        encoder = MagicMock()
        encoder.fit_transform.return_value = np.array(
            [[True, False], [False, True]]
        )
        encoder.columns_ = ["A", "B"]
        mock_encoder_class.return_value = encoder

        # Stub fpgrowth with two single-item frequent itemsets.
        mock_fpgrowth.return_value = pd.DataFrame(
            {"support": [0.3, 0.4], "itemsets": [{"A"}, {"B"}]}
        )

        self.mock_data_reader.dataset = pd.DataFrame(
            {"userId": [1, 2], "itemId": ["A", "B"], "rating": [4.0, 5.0]}
        )
        rules = AssociationRules(self.mock_data_reader, min_support=0.2)
        baskets = [["A"], ["B"]]

        itemsets = rules._mine_frequent_itemsets(baskets)  # type: ignore

        assert not itemsets.empty
        mock_fpgrowth.assert_called_once()
        encoder.fit_transform.assert_called_once_with(baskets)

    @patch("pygrex.utils.association_rules.fpgrowth")
    @patch("pygrex.utils.association_rules.TransactionEncoder")
    def test_mine_frequent_itemsets_empty_result(
        self, mock_encoder_class, mock_fpgrowth
    ):
        """An empty fpgrowth result is reported as a ValueError."""
        # Stub the TransactionEncoder instance created inside the module.
        encoder = MagicMock()
        encoder.fit_transform.return_value = np.array([[True, False]])
        encoder.columns_ = ["A", "B"]
        mock_encoder_class.return_value = encoder

        # fpgrowth finds nothing.
        mock_fpgrowth.return_value = pd.DataFrame()

        self.mock_data_reader.dataset = pd.DataFrame(
            {"userId": [1], "itemId": ["A"], "rating": [4.0]}
        )
        rules = AssociationRules(self.mock_data_reader, min_support=0.9)
        baskets = [["A"]]

        with pytest.raises(
            ValueError, match="No frequent itemsets found with min_support=0.9"
        ):
            rules._mine_frequent_itemsets(baskets)  # type: ignore
|
||||
|
||||
|
||||
class TestAssociationRulesGeneration:
    """Tests for ``AssociationRules._generate_association_rules``."""

    def setup_method(self):
        """Prepare a small frequent-itemsets DataFrame shared by the tests."""
        self.sample_frequent_itemsets = pd.DataFrame(
            {"support": [0.3, 0.4, 0.2], "itemsets": [{"A"}, {"B"}, {"A", "B"}]}
        )

    @patch("pygrex.utils.association_rules.association_rules")
    def test_generate_association_rules_success(self, mock_association_rules):
        """Rules are generated with the confidence metric and the configured threshold."""
        mock_association_rules.return_value = pd.DataFrame(
            {
                "antecedents": [{"A"}],
                "consequents": [{"B"}],
                "confidence": [0.8],
                "support": [0.2],
            }
        )
        reader = Mock(spec=DataReader)
        rules = AssociationRules(reader, min_confidence=0.5)

        generated = rules._generate_association_rules(self.sample_frequent_itemsets)

        assert not generated.empty
        mock_association_rules.assert_called_once_with(
            self.sample_frequent_itemsets, metric="confidence", min_threshold=0.5
        )

    @patch("pygrex.utils.association_rules.association_rules")
    def test_generate_association_rules_empty_result(self, mock_association_rules):
        """An empty rules DataFrame is reported as a ValueError."""
        mock_association_rules.return_value = pd.DataFrame()
        reader = Mock(spec=DataReader)
        rules = AssociationRules(reader, min_confidence=0.9)

        with pytest.raises(
            ValueError, match="No association rules found with min_confidence=0.9"
        ):
            rules._generate_association_rules(self.sample_frequent_itemsets)
|
||||
|
||||
|
||||
class TestAssociationRulesCompute:
    """Tests for the top-level ``AssociationRules.compute`` pipeline."""

    def setup_method(self):
        """Attach a small three-user interaction frame to a mocked DataReader."""
        interactions = pd.DataFrame(
            {
                "userId": [1, 1, 2, 2, 3, 3],
                "itemId": ["A", "B", "A", "C", "B", "C"],
                "rating": [4.0, 5.0, 4.5, 4.0, 4.5, 5.0],
            }
        )
        reader = Mock(spec=DataReader)
        reader.dataset = interactions

        self.mock_data_reader = reader
        self.sample_dataset = interactions

    def test_compute_empty_dataset(self):
        """compute() rejects an empty dataset with ValueError."""
        self.mock_data_reader.dataset = pd.DataFrame()
        miner = AssociationRules(self.mock_data_reader)

        with pytest.raises(ValueError, match="Dataset is empty"):
            miner.compute()

    @patch.object(AssociationRules, "_generate_association_rules")
    @patch.object(AssociationRules, "_mine_frequent_itemsets")
    @patch.object(AssociationRules, "_prepare_transactions")
    def test_compute_success(self, mock_prepare, mock_mine, mock_generate):
        """compute() chains the three stages and stores their outputs."""
        # Stub each pipeline stage with a canned return value.
        itemsets_stub = pd.DataFrame({"support": [0.3], "itemsets": [{"A"}]})
        rules_stub = pd.DataFrame(
            {"antecedents": [{"A"}], "consequents": [{"B"}], "confidence": [0.8]}
        )
        mock_prepare.return_value = [["A", "B"], ["A", "C"]]
        mock_mine.return_value = itemsets_stub
        mock_generate.return_value = rules_stub

        miner = AssociationRules(self.mock_data_reader)
        returned = miner.compute()

        # The returned frame and both cached attributes mirror the stubs.
        assert returned.equals(rules_stub)
        assert miner._frequent_itemsets.equals(itemsets_stub)  # type: ignore
        assert miner._association_rules.equals(rules_stub)  # type: ignore

        # Each stage ran exactly once and received its predecessor's output.
        mock_prepare.assert_called_once()
        mock_mine.assert_called_once_with([["A", "B"], ["A", "C"]])
        mock_generate.assert_called_once_with(itemsets_stub)

    def test_compute_no_transactions(self):
        """compute() raises when transaction preparation yields nothing."""
        with patch.object(
            AssociationRules, "_prepare_transactions", return_value=[]
        ):
            miner = AssociationRules(self.mock_data_reader)

            with pytest.raises(
                ValueError, match="No transactions found after filtering"
            ):
                miner.compute()
|
||||
|
||||
|
||||
class TestAssociationRulesAccessors:
    """Tests for the read-only accessors of ``AssociationRules``."""

    def setup_method(self):
        """Create an instance backed by a mocked DataReader."""
        self.mock_data_reader = Mock(spec=DataReader)
        self.ar = AssociationRules(self.mock_data_reader)

    def test_get_frequent_itemsets_before_compute(self):
        """With nothing computed yet the accessor returns None."""
        assert self.ar.get_frequent_itemsets() is None

    def test_get_frequent_itemsets_after_compute(self):
        """The accessor returns whatever is stored in ``_frequent_itemsets``."""
        stored = pd.DataFrame({"support": [0.3], "itemsets": [{"A"}]})
        # Simulate a finished compute() by setting the cached frame directly.
        self.ar._frequent_itemsets = stored

        fetched = self.ar.get_frequent_itemsets()

        assert fetched.equals(stored)  # type: ignore
|
||||
|
||||
|
||||
class TestAssociationRulesRecommendations:
    """Tests for ``AssociationRules.get_recommendations_for_items``."""

    def setup_method(self):
        """Create an instance plus a three-rule frame used as precomputed rules."""
        self.mock_data_reader = Mock(spec=DataReader)
        self.ar = AssociationRules(self.mock_data_reader)

        # Rules: A -> B, B -> C and {A, B} -> C.
        antecedents = [frozenset(["A"]), frozenset(["B"]), frozenset(["A", "B"])]
        consequents = [frozenset(["B"]), frozenset(["C"]), frozenset(["C"])]
        self.mock_rules = pd.DataFrame(
            {
                "antecedents": antecedents,
                "consequents": consequents,
                "confidence": [0.8, 0.6, 0.9],
                "lift": [1.2, 1.1, 1.5],
                "support": [0.4, 0.3, 0.2],
            }
        )

    def test_get_recommendations_before_compute(self):
        """Asking for recommendations without computed rules is a RuntimeError."""
        with pytest.raises(
            RuntimeError, match="Must call compute\\(\\) before getting recommendations"
        ):
            self.ar.get_recommendations_for_items(["A"])

    def test_get_recommendations_empty_items(self):
        """An empty item list is rejected with ValueError."""
        self.ar._association_rules = self.mock_rules

        with pytest.raises(ValueError, match="Items list cannot be empty"):
            self.ar.get_recommendations_for_items([])

    def test_get_recommendations_success(self):
        """A known item yields a non-empty, bounded frame with the rule columns."""
        self.ar._association_rules = self.mock_rules

        recommended = self.ar.get_recommendations_for_items(["A"], top_k=5)

        assert not recommended.empty
        assert len(recommended) <= 5
        for column in ("confidence", "antecedents", "consequents"):
            assert column in recommended.columns

    def test_get_recommendations_no_matching_rules(self):
        """An item absent from every rule produces an empty frame."""
        self.ar._association_rules = self.mock_rules

        # "Z" does not occur in any rule of the fixture.
        recommended = self.ar.get_recommendations_for_items(["Z"])

        assert recommended.empty

    def test_get_recommendations_top_k_limit(self):
        """The result never holds more rows than ``top_k``."""
        # Repeat the rule set five times so more rules exist than top_k allows.
        self.ar._association_rules = pd.concat(
            [self.mock_rules] * 5, ignore_index=True
        )

        recommended = self.ar.get_recommendations_for_items(["A"], top_k=2)

        assert len(recommended) <= 2
|
||||
|
||||
|
||||
class TestAssociationRulesStringRepresentations:
    """Tests for ``str()`` and ``repr()`` of ``AssociationRules``."""

    def setup_method(self):
        """Provide a mocked DataReader carrying a single-row dataset."""
        reader = Mock(spec=DataReader)
        reader.dataset = pd.DataFrame(
            {"userId": [1], "itemId": ["A"], "rating": [4.0]}
        )
        self.mock_data_reader = reader

    def test_str_representation(self):
        """str() shows the explicitly configured thresholds."""
        configured = AssociationRules(
            self.mock_data_reader,
            min_support=0.1,
            min_confidence=0.3,
            rating_threshold=3.5,
        )

        assert (
            str(configured)
            == "AssociationRules(min_support=0.1, min_confidence=0.3, rating_threshold=3.5)"
        )

    def test_repr_representation(self):
        """repr() of a default-constructed instance shows the default thresholds."""
        default_instance = AssociationRules(self.mock_data_reader)

        assert (
            repr(default_instance)
            == "AssociationRules(min_support=0.2, min_confidence=0.2, rating_threshold=4.0)"
        )
|
||||
|
||||
|
||||
class TestAssociationRulesIntegration:
    """End-to-end style tests that drive ``AssociationRules`` through compute()."""

    def setup_method(self):
        """Build a five-user interaction dataset on a mocked DataReader."""
        # One (userId, itemId, rating) row per interaction, grouped by user.
        interactions = [
            (1, "A", 4.5), (1, "B", 4.0), (1, "C", 5.0),
            (2, "A", 4.0), (2, "B", 4.5), (2, "D", 3.5),
            (3, "B", 4.0), (3, "C", 4.5),
            (4, "A", 5.0), (4, "C", 4.0), (4, "D", 4.5),
            (5, "B", 4.0), (5, "D", 4.0),
        ]
        self.mock_data_reader = Mock(spec=DataReader)
        self.integration_dataset = pd.DataFrame(
            {
                "userId": [user for user, _, _ in interactions],
                "itemId": [item for _, item, _ in interactions],
                "rating": [rating for _, _, rating in interactions],
            }
        )
        self.mock_data_reader.dataset = self.integration_dataset

    @patch("pygrex.utils.association_rules.association_rules")
    @patch("pygrex.utils.association_rules.fpgrowth")
    def test_full_workflow_integration(self, mock_fpgrowth, mock_association_rules):
        """Run construction, compute() and a recommendation lookup in sequence."""
        # Canned output for the patched frequent-itemset miner.
        mock_fpgrowth.return_value = pd.DataFrame(
            {
                "support": [0.4, 0.6, 0.3],
                "itemsets": [frozenset(["A"]), frozenset(["B"]), frozenset(["A", "B"])],
            }
        )
        # Canned output for the patched rule generator: a single rule A -> B.
        mock_association_rules.return_value = pd.DataFrame(
            {
                "antecedents": [frozenset(["A"])],
                "consequents": [frozenset(["B"])],
                "confidence": [0.8],
                "lift": [1.2],
                "support": [0.3],
            }
        )

        miner = AssociationRules(
            self.mock_data_reader,
            min_support=0.2,
            min_confidence=0.5,
            rating_threshold=4.0,
        )
        computed_rules = miner.compute()

        # compute() produced rules and cached the frequent itemsets.
        assert not computed_rules.empty
        assert miner.get_frequent_itemsets() is not None

        # The recommendation lookup works on the computed rules.
        suggestions = miner.get_recommendations_for_items(["A"])
        assert isinstance(suggestions, pd.DataFrame)
|
||||
|
||||
|
||||
# Pytest configuration and fixtures
|
||||
@pytest.fixture
def sample_data_reader():
    """Return a mocked DataReader whose dataset holds six sample interactions."""
    reader = Mock(spec=DataReader)
    columns = {
        "userId": [1, 1, 2, 2, 3, 3],
        "itemId": ["A", "B", "A", "C", "B", "C"],
        "rating": [4.0, 5.0, 4.5, 4.0, 4.5, 5.0],
    }
    reader.dataset = pd.DataFrame(columns)
    return reader
|
||||
|
||||
|
||||
@pytest.fixture
def association_rules_instance(sample_data_reader):
    """Return an AssociationRules built with default thresholds on the sample reader."""
    instance = AssociationRules(sample_data_reader)
    return instance
|
||||
|
||||
|
||||
# Parametrized tests
|
||||
@pytest.mark.parametrize(
    "support,confidence,threshold,should_raise",
    [
        (0.1, 0.2, 4.0, False),
        (0.0, 0.2, 4.0, True),
        (1.1, 0.2, 4.0, True),
        (0.1, 0.0, 4.0, True),
        (0.1, 1.5, 4.0, True),
        (0.1, 0.2, -1.0, True),
    ],
)
def test_parameter_validation_parametrized(
    sample_data_reader, support, confidence, threshold, should_raise
):
    """Check constructor validation for each (support, confidence, threshold) case.

    Cases flagged ``should_raise`` must fail with ValueError; the remaining
    case must store all three values unchanged on the instance.
    """
    if not should_raise:
        accepted = AssociationRules(sample_data_reader, support, confidence, threshold)
        assert accepted.min_support == support
        assert accepted.min_confidence == confidence
        assert accepted.rating_threshold == threshold
        return

    with pytest.raises(ValueError):
        AssociationRules(sample_data_reader, support, confidence, threshold)
|
||||
@@ -0,0 +1,135 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from pygrex.utils.scale import Scale
|
||||
|
||||
|
||||
class TestScale:
    """Test suite for the ``Scale.quantile`` and ``Scale.linear`` helpers.

    Every comparison of scaled output is tolerance-based (``np.allclose`` /
    ``np.isclose``) so the tests do not depend on exact floating-point
    results of the scaling arithmetic.
    """

    def test_quantile_basic(self):
        """Five distinct values map by rank onto the default 1-5 range."""
        raw_values = [3.0, 1.0, 4.0, 2.0, 6.0]
        scaled = Scale.quantile(raw_values)
        # With 5 elements, we expect ranks to map evenly across 1-5.
        assert np.allclose(scaled, [3, 1, 4, 2, 5])

    def test_quantile_empty(self):
        """Empty input raises ValueError."""
        with pytest.raises(ValueError, match="Raw predictions array is empty"):
            Scale.quantile([])

    def test_quantile_single_element(self):
        """A single value lands on the middle of the default range [1, 5]."""
        result = Scale.quantile([7.5])
        assert np.allclose(result, [3])

    def test_quantile_custom_range(self):
        """target_min/target_max replace the default output range."""
        raw_values = [10, 20, 30, 40, 50]
        scaled = Scale.quantile(raw_values, target_min=0, target_max=1)
        # Should map to [0, 0.25, 0.5, 0.75, 1].
        assert np.allclose(scaled, [0, 0.25, 0.5, 0.75, 1])

    def test_quantile_equal_values(self):
        """All-equal input maps every element to the middle of [1, 5]."""
        raw_values = [5, 5, 5, 5]
        scaled = Scale.quantile(raw_values)
        # All values share the same (average) rank, which maps to the middle.
        assert np.allclose(scaled, [3, 3, 3, 3])

    def test_quantile_numpy_input(self):
        """A numpy array is accepted and scaled like the equivalent list."""
        raw_values = np.array([3.0, 1.0, 4.0, 2.0, 6.0])
        scaled = Scale.quantile(raw_values)
        assert np.allclose(scaled, [3, 1, 4, 2, 5])

    def test_linear_basic(self):
        """Evenly spaced values map linearly from [2, 10] onto [1, 5]."""
        raw_values = [2.0, 4.0, 6.0, 8.0, 10.0]
        scaled = Scale.linear(raw_values)
        assert np.allclose(scaled, [1, 2, 3, 4, 5])

    def test_linear_empty(self):
        """Empty input raises ValueError."""
        with pytest.raises(ValueError, match="Raw predictions array is empty"):
            Scale.linear([])

    def test_linear_single_element(self):
        """A single value lands on the middle of the default range [1, 5]."""
        result = Scale.linear([7.5])
        assert np.allclose(result, [3])

    def test_linear_custom_range(self):
        """target_min/target_max replace the default output range."""
        raw_values = [0, 5, 10]
        scaled = Scale.linear(raw_values, target_min=0, target_max=100)
        assert np.allclose(scaled, [0, 50, 100])

    def test_linear_custom_ref_range(self):
        """ref_min/ref_max define the input range instead of the data's min/max."""
        raw_values = [2, 5, 8]
        scaled = Scale.linear(raw_values, ref_min=0, ref_max=10)
        # Should map from [0, 10] to [1, 5] regardless of actual min/max.
        expected = [1 + (2 / 10) * 4, 1 + (5 / 10) * 4, 1 + (8 / 10) * 4]
        assert np.allclose(scaled, expected)

    def test_linear_with_outliers(self):
        """Outlier handling changes how the non-outlier values are spread."""
        raw_values = [5, 6, 7, 8, 20]  # 20 is an outlier

        # With outlier handling (default).
        scaled_with_handling = Scale.linear(raw_values)
        # Without outlier handling.
        scaled_without_handling = Scale.linear(raw_values, handle_outliers=False)

        # The two modes must not produce the same output.
        assert not np.allclose(scaled_with_handling, scaled_without_handling)

        # In both modes the outlier ends up at the top of the target range;
        # compared with a tolerance rather than exact float equality.
        assert np.isclose(scaled_without_handling[-1], 5)
        assert np.isclose(scaled_with_handling[-1], 5)

        # But the other values should be more spread out with handling.
        assert np.max(scaled_with_handling[:-1]) > np.max(scaled_without_handling[:-1])

    def test_linear_equal_values(self):
        """All-equal input maps every element to the middle of the target range."""
        raw_values = [7, 7, 7, 7]
        scaled = Scale.linear(raw_values)
        assert np.allclose(scaled, [3, 3, 3, 3])

    def test_linear_equal_ref_bounds(self):
        """Equal ref_min and ref_max map everything to the middle of the range."""
        raw_values = [5, 6, 7]
        scaled = Scale.linear(raw_values, ref_min=5, ref_max=5)
        assert np.allclose(scaled, [3, 3, 3])

    def test_linear_numpy_input(self):
        """A numpy array is accepted and scaled like the equivalent list."""
        raw_values = np.array([2.0, 4.0, 6.0, 8.0, 10.0])
        scaled = Scale.linear(raw_values)
        assert np.allclose(scaled, [1, 2, 3, 4, 5])

    def test_linear_clipping(self):
        """Values outside the reference range are clipped to the target bounds."""
        raw_values = [0, 5, 10, 15]
        scaled = Scale.linear(raw_values, ref_min=5, ref_max=10)
        # Below ref_min -> target_min; above ref_max -> target_max.
        assert np.isclose(scaled[0], 1)
        assert np.isclose(scaled[3], 5)

    def test_linear_negative_values(self):
        """Negative inputs are handled: [-10, 10] maps onto [1, 5]."""
        raw_values = [-10, -5, 0, 5, 10]
        scaled = Scale.linear(raw_values)
        assert np.allclose(scaled, [1, 2, 3, 4, 5])
|
||||
@@ -0,0 +1,406 @@
|
||||
import unittest
|
||||
import pytest
|
||||
import numpy as np
|
||||
from unittest.mock import patch
|
||||
from typing import Dict, List
|
||||
from pygrex.utils.aggregation_strategy import (
|
||||
ScoreAggregator,
|
||||
AggregationStrategy, # type: ignore
|
||||
)
|
||||
from enum import Enum
|
||||
from typing import TypeAlias
|
||||
|
||||
|
||||
UserID: TypeAlias = str
|
||||
ItemID: TypeAlias = str
|
||||
EvaluationScore: TypeAlias = float
|
||||
AggregatedScore: TypeAlias = float
|
||||
UserEvaluations: TypeAlias = Dict[UserID, Dict[ItemID, EvaluationScore]]
|
||||
UserRankings: TypeAlias = Dict[UserID, List[ItemID]]
|
||||
AggregatedScores: TypeAlias = Dict[ItemID, AggregatedScore]
|
||||
|
||||
|
||||
class TestScoreAggregatorUnittest(unittest.TestCase):
    """unittest-style tests for ``ScoreAggregator``."""

    def setUp(self):
        """Create an aggregator (MRP = user1) and the shared evaluation data."""
        self.aggregator = ScoreAggregator(most_respected_person="user1")

        # Three users scoring the same three items.
        self.sample_evaluations: UserEvaluations = {
            "user1": {"item_A": 4.5, "item_B": 3.0, "item_C": 5.0},
            "user2": {"item_A": 3.0, "item_B": 4.0, "item_C": 2.0},
            "user3": {"item_A": 4.0, "item_B": 2.0, "item_C": 3.0},
        }
        # Per-user item orderings, consumed by the Borda count tests.
        self.sample_rankings: UserRankings = {
            "user1": ["item_C", "item_A", "item_B"],
            "user2": ["item_B", "item_A", "item_C"],
            "user3": ["item_A", "item_C", "item_B"],
        }
        self.empty_evaluations: UserEvaluations = {}
        self.single_user_evaluations: UserEvaluations = {
            "user1": {"item_A": 3.5, "item_B": 4.0}
        }

    def _assert_scores(self, actual, expected):
        """Assert ``actual[item] == score`` for every pair in ``expected``."""
        for item_id, score in expected.items():
            self.assertEqual(actual[item_id], score)

    def _assert_scores_almost(self, actual, expected):
        """Assert ``actual[item]`` matches ``expected[item]`` to 5 decimal places."""
        for item_id, score in expected.items():
            self.assertAlmostEqual(actual[item_id], score, places=5)

    def test_init_with_mrp(self):
        """The constructor stores the given most respected person."""
        created = ScoreAggregator(most_respected_person="user1")
        self.assertEqual(created.most_respected_person, "user1")

    def test_init_without_mrp(self):
        """Without an argument the most respected person is None."""
        created = ScoreAggregator()
        self.assertIsNone(created.most_respected_person)

    def test_empty_evaluations(self):
        """Aggregating no evaluations yields an empty dict."""
        scores = self.aggregator.aggregate_scores(
            self.empty_evaluations,  # type: ignore
            AggregationStrategy.AVG_PREDICTIONS,  # type: ignore
        )
        self.assertEqual(scores, {})

    def test_avg_predictions(self):
        """AVG_PREDICTIONS returns the per-item mean across users."""
        scores = self.aggregator.aggregate_scores(
            self.sample_evaluations,  # type: ignore
            AggregationStrategy.AVG_PREDICTIONS,  # type: ignore
        )
        # item_A: (4.5+3.0+4.0)/3, item_B: (3.0+4.0+2.0)/3, item_C: (5.0+2.0+3.0)/3
        self._assert_scores_almost(
            scores,
            {
                "item_A": 3.833333333333333,
                "item_B": 3.0,
                "item_C": 3.333333333333333,
            },
        )

    def test_least_misery(self):
        """LEAST_MISERY returns the per-item minimum across users."""
        scores = self.aggregator.aggregate_scores(
            self.sample_evaluations,  # type: ignore
            AggregationStrategy.LEAST_MISERY,  # type: ignore
        )
        # item_A: min(4.5, 3.0, 4.0), item_B: min(3.0, 4.0, 2.0), item_C: min(5.0, 2.0, 3.0)
        self._assert_scores(scores, {"item_A": 3.0, "item_B": 2.0, "item_C": 2.0})

    def test_most_pleasure(self):
        """MOST_PLEASURE returns the per-item maximum across users."""
        scores = self.aggregator.aggregate_scores(
            self.sample_evaluations,  # type: ignore
            AggregationStrategy.MOST_PLEASURE,  # type: ignore
        )
        # item_A: max(4.5, 3.0, 4.0), item_B: max(3.0, 4.0, 2.0), item_C: max(5.0, 2.0, 3.0)
        self._assert_scores(scores, {"item_A": 4.5, "item_B": 4.0, "item_C": 5.0})

    def test_most_respected_person(self):
        """MOST_RESPECTED_PERSON returns exactly user1's own evaluations."""
        scores = self.aggregator.aggregate_scores(
            self.sample_evaluations,  # type: ignore
            AggregationStrategy.MOST_RESPECTED_PERSON,  # type: ignore
        )
        self._assert_scores(scores, {"item_A": 4.5, "item_B": 3.0, "item_C": 5.0})

    def test_mrp_without_mrp_set(self):
        """The MRP strategy fails when no most respected person was configured."""
        unconfigured = ScoreAggregator()  # No MRP set

        with self.assertRaises(ValueError) as raised:
            unconfigured.aggregate_scores(
                self.sample_evaluations,  # type: ignore
                AggregationStrategy.MOST_RESPECTED_PERSON,  # type: ignore
            )

        self.assertIn("Most respected person not specified", str(raised.exception))

    def test_mrp_user_not_in_evaluations(self):
        """The MRP strategy fails when the configured user has no evaluations."""
        stranger = ScoreAggregator(most_respected_person="nonexistent_user")

        with self.assertRaises(ValueError) as raised:
            stranger.aggregate_scores(
                self.sample_evaluations,  # type: ignore
                AggregationStrategy.MOST_RESPECTED_PERSON,  # type: ignore
            )

        self.assertIn("not found in evaluations", str(raised.exception))

    def test_additive_utilitarian(self):
        """ADDITIVE_UTILITARIAN returns the per-item sum across users."""
        scores = self.aggregator.aggregate_scores(
            self.sample_evaluations,  # type: ignore
            AggregationStrategy.ADDITIVE_UTILITARIAN,  # type: ignore
        )
        # item_A: 4.5+3.0+4.0, item_B: 3.0+4.0+2.0, item_C: 5.0+2.0+3.0
        self._assert_scores(scores, {"item_A": 11.5, "item_B": 9.0, "item_C": 10.0})

    def test_multiplicative(self):
        """MULTIPLICATIVE returns the per-item product across users."""
        scores = self.aggregator.aggregate_scores(
            self.sample_evaluations,  # type: ignore
            AggregationStrategy.MULTIPLICATIVE,  # type: ignore
        )
        # item_A: 4.5*3.0*4.0, item_B: 3.0*4.0*2.0, item_C: 5.0*2.0*3.0
        self._assert_scores(scores, {"item_A": 54.0, "item_B": 24.0, "item_C": 30.0})

    def test_borda_count(self):
        """BORDA_COUNT sums rank-derived points over the supplied rankings."""
        scores = self.aggregator.aggregate_scores(
            self.sample_evaluations,  # type: ignore
            AggregationStrategy.BORDA_COUNT,
            self.sample_rankings,  # type: ignore
        )
        # Points per user (first place = 2, last place = 0):
        #   item_A: user1(1) + user2(1) + user3(2) = 4
        #   item_B: user1(0) + user2(2) + user3(0) = 2
        #   item_C: user1(2) + user2(0) + user3(1) = 3
        self._assert_scores(scores, {"item_A": 4.0, "item_B": 2.0, "item_C": 3.0})

    def test_borda_count_without_rankings(self):
        """BORDA_COUNT without rankings raises ValueError."""
        with self.assertRaises(ValueError) as raised:
            self.aggregator.aggregate_scores(
                self.sample_evaluations,  # type: ignore
                AggregationStrategy.BORDA_COUNT,  # type: ignore
            )

        self.assertIn("Rankings required for Borda Count", str(raised.exception))

    def test_unknown_strategy(self):
        """A value that is not a known strategy raises ValueError."""
        with self.assertRaises(ValueError) as raised:
            self.aggregator.aggregate_scores(
                self.sample_evaluations,  # type: ignore
                "invalid_strategy",  # type: ignore
            )

        self.assertIn("Unknown aggregation strategy", str(raised.exception))

    def test_get_top_recommendation(self):
        """The top recommendation under MOST_PLEASURE is item_C (max score 5.0)."""
        winner = self.aggregator.get_top_recommendation(
            self.sample_evaluations,  # type: ignore
            AggregationStrategy.MOST_PLEASURE,  # type: ignore
        )
        self.assertEqual(winner, "item_C")

    def test_get_top_recommendation_with_rankings(self):
        """The top recommendation under BORDA_COUNT is item_A (count 4.0)."""
        winner = self.aggregator.get_top_recommendation(
            self.sample_evaluations,  # type: ignore
            AggregationStrategy.BORDA_COUNT,
            self.sample_rankings,  # type: ignore
        )
        self.assertEqual(winner, "item_A")

    def test_single_user_evaluation(self):
        """With one user the average equals that user's own scores."""
        scores = self.aggregator.aggregate_scores(
            self.single_user_evaluations,  # type: ignore
            AggregationStrategy.AVG_PREDICTIONS,  # type: ignore
        )
        self._assert_scores(scores, {"item_A": 3.5, "item_B": 4.0})

    def test_missing_items_in_evaluations(self):
        """Items not rated by every user are averaged over the users who did."""
        incomplete_evaluations: UserEvaluations = {
            "user1": {"item_A": 4.0, "item_B": 3.0},
            "user2": {"item_A": 3.0, "item_C": 2.0},  # Missing item_B
            "user3": {"item_B": 2.0, "item_C": 3.0},  # Missing item_A
        }

        scores = self.aggregator.aggregate_scores(
            incomplete_evaluations,  # type: ignore
            AggregationStrategy.AVG_PREDICTIONS,  # type: ignore
        )

        self._assert_scores_almost(
            scores,
            {
                "item_A": 3.5,  # (4.0 + 3.0) / 2
                "item_B": 2.5,  # (3.0 + 2.0) / 2
                "item_C": 2.5,  # (2.0 + 3.0) / 2
            },
        )
|
||||
|
||||
|
||||
class TestScoreAggregatorPytest:
    """pytest-style tests for ``ScoreAggregator``, driven by fixtures."""

    @pytest.fixture
    def aggregator(self):
        """Aggregator whose most respected person is user1."""
        return ScoreAggregator(most_respected_person="user1")

    @pytest.fixture
    def sample_evaluations(self):
        """Scores from three users for the same three items."""
        scores_by_user = {
            "user1": {"item_A": 4.5, "item_B": 3.0, "item_C": 5.0},
            "user2": {"item_A": 3.0, "item_B": 4.0, "item_C": 2.0},
            "user3": {"item_A": 4.0, "item_B": 2.0, "item_C": 3.0},
        }
        return scores_by_user

    @pytest.fixture
    def sample_rankings(self):
        """Per-user item orderings for the Borda count tests."""
        order_by_user = {
            "user1": ["item_C", "item_A", "item_B"],
            "user2": ["item_B", "item_A", "item_C"],
            "user3": ["item_A", "item_C", "item_B"],
        }
        return order_by_user

    def test_avg_predictions_pytest(self, aggregator, sample_evaluations):
        """AVG_PREDICTIONS yields the per-item mean within 1e-5."""
        scores = aggregator.aggregate_scores(
            sample_evaluations, AggregationStrategy.AVG_PREDICTIONS
        )

        expected_means = {
            "item_A": 3.833333333333333,
            "item_B": 3.0,
            "item_C": 3.333333333333333,
        }
        for item_id, mean in expected_means.items():
            assert abs(scores[item_id] - mean) < 1e-5

    def test_least_misery_pytest(self, aggregator, sample_evaluations):
        """LEAST_MISERY yields the per-item minimum."""
        scores = aggregator.aggregate_scores(
            sample_evaluations, AggregationStrategy.LEAST_MISERY
        )

        for item_id, minimum in (("item_A", 3.0), ("item_B", 2.0), ("item_C", 2.0)):
            assert scores[item_id] == minimum

    def test_borda_count_pytest(self, aggregator, sample_evaluations, sample_rankings):
        """BORDA_COUNT yields the summed rank points per item."""
        scores = aggregator.aggregate_scores(
            sample_evaluations, AggregationStrategy.BORDA_COUNT, sample_rankings
        )

        for item_id, points in (("item_A", 4.0), ("item_B", 2.0), ("item_C", 3.0)):
            assert scores[item_id] == points

    def test_mrp_error_pytest(self, sample_evaluations):
        """The MRP strategy raises when no most respected person is set."""
        unconfigured = ScoreAggregator()  # No MRP set

        with pytest.raises(ValueError, match="Most respected person not specified"):
            unconfigured.aggregate_scores(
                sample_evaluations, AggregationStrategy.MOST_RESPECTED_PERSON
            )

    def test_borda_count_error_pytest(self, aggregator, sample_evaluations):
        """BORDA_COUNT raises when rankings are omitted."""
        with pytest.raises(ValueError, match="Rankings required for Borda Count"):
            aggregator.aggregate_scores(
                sample_evaluations, AggregationStrategy.BORDA_COUNT
            )

    @pytest.mark.parametrize(
        "strategy,expected_top",
        [
            (AggregationStrategy.LEAST_MISERY, "item_A"),
            (AggregationStrategy.MOST_PLEASURE, "item_C"),
            (AggregationStrategy.ADDITIVE_UTILITARIAN, "item_A"),
        ],
    )
    def test_top_recommendation_parametrized(
        self, aggregator, sample_evaluations, strategy, expected_top
    ):
        """Each strategy picks its expected best item from the sample data."""
        assert (
            aggregator.get_top_recommendation(sample_evaluations, strategy)
            == expected_top
        )

    def test_empty_evaluations_pytest(self, aggregator):
        """Aggregating no evaluations yields an empty dict."""
        assert aggregator.aggregate_scores({}, AggregationStrategy.AVG_PREDICTIONS) == {}

    def test_multiplicative_with_zero(self, aggregator):
        """A zero score makes the whole product for that item zero."""
        evaluations_with_zero = {
            "user1": {"item_A": 0.0, "item_B": 3.0},
            "user2": {"item_A": 4.0, "item_B": 2.0},
        }

        products = aggregator.aggregate_scores(
            evaluations_with_zero, AggregationStrategy.MULTIPLICATIVE
        )

        assert products["item_A"] == 0.0  # 0.0 * 4.0 = 0.0
        assert products["item_B"] == 6.0  # 3.0 * 2.0 = 6.0
|
||||
|
||||
|
||||
# Integration tests
|
||||
class TestScoreAggregatorIntegration(unittest.TestCase):
    """Integration tests that run every aggregation strategy on one dataset."""

    def test_all_strategies_consistency(self):
        """Every strategy returns a dict with one numeric score per item.

        Strategies that need only evaluations are checked in a loop, each in
        its own ``subTest`` so a failure names the offending strategy.
        Borda count is checked separately because it also needs rankings.
        """
        evaluations = {
            "user1": {"item_A": 5.0, "item_B": 3.0, "item_C": 4.0},
            "user2": {"item_A": 4.0, "item_B": 5.0, "item_C": 3.0},
            "user3": {"item_A": 3.0, "item_B": 4.0, "item_C": 5.0},
        }

        rankings = {
            "user1": ["item_A", "item_C", "item_B"],
            "user2": ["item_B", "item_A", "item_C"],
            "user3": ["item_C", "item_B", "item_A"],
        }

        aggregator = ScoreAggregator(most_respected_person="user1")

        # Each strategy appears exactly once (the original list repeated
        # AVG_PREDICTIONS, which only re-ran an identical check).
        strategies_without_rankings = [
            AggregationStrategy.AVG_PREDICTIONS,
            AggregationStrategy.LEAST_MISERY,
            AggregationStrategy.MOST_PLEASURE,
            AggregationStrategy.MOST_RESPECTED_PERSON,
            AggregationStrategy.ADDITIVE_UTILITARIAN,
            AggregationStrategy.MULTIPLICATIVE,
        ]

        for strategy in strategies_without_rankings:
            with self.subTest(strategy=strategy):
                result = aggregator.aggregate_scores(evaluations, strategy)  # type: ignore
                self.assertIsInstance(result, dict)
                self.assertEqual(len(result), 3)  # Should have 3 items
                for score in result.values():
                    self.assertIsInstance(score, (int, float))

        # Borda count needs the rankings as a third argument.
        borda_result = aggregator.aggregate_scores(
            evaluations,  # type: ignore
            AggregationStrategy.BORDA_COUNT,
            rankings,  # type: ignore
        )
        self.assertIsInstance(borda_result, dict)
        self.assertEqual(len(borda_result), 3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Direct execution hands control to the unittest runner with verbose
    # per-test output.
    unittest.main(verbosity=2)
|
||||
@@ -0,0 +1,162 @@
|
||||
import pytest
|
||||
from typing import List, Any
|
||||
|
||||
from pygrex.utils.sliding_window import SlidingWindow
|
||||
|
||||
|
||||
class TestSlidingWindow:
    """Tests covering SlidingWindow construction, traversal and sizing."""

    def test_basic_functionality(self):
        """Windows advance one element at a time until exhausted."""
        sliding_window = SlidingWindow([1, 2, 3, 4, 5], 3)

        # First, second and third windows, in order.
        for expected in ([1, 2, 3], [2, 3, 4], [3, 4, 5]):
            assert sliding_window.get_next_window() == expected

        # No more windows
        assert sliding_window.get_next_window() is None

    def test_window_equal_to_sequence_length(self):
        """A window as long as the sequence yields exactly one window."""
        sliding_window = SlidingWindow([1, 2, 3], 3)

        assert sliding_window.get_next_window() == [1, 2, 3]
        assert sliding_window.get_next_window() is None

    def test_window_larger_than_sequence(self):
        """A window longer than the sequence yields nothing."""
        sliding_window = SlidingWindow([1, 2, 3], 4)

        assert sliding_window.get_next_window() is None

    def test_empty_sequence(self):
        """An empty sequence yields no window."""
        nothing: List[Any] = []
        sliding_window = SlidingWindow(nothing, 2)

        assert sliding_window.get_next_window() is None

    def test_invalid_window_size(self):
        """Zero and negative window sizes are rejected with ValueError."""
        data = [1, 2, 3, 4, 5]

        # Zero first, then a negative size.
        for bad_size in (0, -1):
            with pytest.raises(ValueError):
                SlidingWindow(data, bad_size)

    def test_non_iterable_sequence(self):
        """Constructing from a non-iterable object raises TypeError."""
        not_a_sequence = 123  # Integer is not iterable

        with pytest.raises(TypeError):
            SlidingWindow(not_a_sequence, 2)

    def test_reset_functionality(self):
        """reset() rewinds so the first window is returned again."""
        sliding_window = SlidingWindow([1, 2, 3, 4], 2)

        # Consume the first two windows.
        for expected in ([1, 2], [2, 3]):
            assert sliding_window.get_next_window() == expected

        # After a reset the first window comes back.
        sliding_window.reset()
        assert sliding_window.get_next_window() == [1, 2]

    def test_has_next(self):
        """has_next() is True until every window has been consumed."""
        sliding_window = SlidingWindow([1, 2, 3], 2)

        # Two windows are available; has_next() is checked before each
        # one is taken.
        for _ in range(2):
            assert sliding_window.has_next() is True
            sliding_window.get_next_window()

        assert sliding_window.has_next() is False

    def test_iterator_protocol(self):
        """The object can be consumed with a for-loop style iteration."""
        sliding_window = SlidingWindow([1, 2, 3, 4], 2)

        # Iterate the object directly and collect what it yields.
        windows = [window for window in sliding_window]
        assert windows == [[1, 2], [2, 3], [3, 4]]

        # After iteration, the index should be at the end
        assert sliding_window.has_next() is False

        # reset() must still work once iteration has finished.
        sliding_window.reset()
        assert sliding_window.has_next() is True
        assert sliding_window.get_next_window() == [1, 2]

    def test_len_functionality(self):
        """Check __len__ across normal, boundary and empty inputs."""
        five_items = [1, 2, 3, 4, 5]
        cases = [
            (five_items, 2, 4),  # Normal case
            (five_items, 5, 1),  # Window size equals sequence length
            (five_items, 6, 0),  # Window size greater than sequence length
            ([], 2, 0),  # Empty sequence
        ]

        for sequence, window_size, expected_len in cases:
            sliding_window = SlidingWindow(sequence, window_size)
            assert len(sliding_window) == expected_len

    def test_with_string_data(self):
        """String input produces substring windows."""
        sliding_window = SlidingWindow("abcde", 3)

        for expected in ("abc", "bcd", "cde"):
            assert sliding_window.get_next_window() == expected

        assert sliding_window.get_next_window() is None

    def test_multiple_data_types(self):
        """A list mixing several element types is windowed like any other."""
        mixed = [1, "two", 3.0, [4, 5], {"six": 6}]
        sliding_window = SlidingWindow(mixed, 2)

        expected_windows = [
            [1, "two"],
            ["two", 3.0],
            [3.0, [4, 5]],
            [[4, 5], {"six": 6}],
        ]
        for expected in expected_windows:
            assert sliding_window.get_next_window() == expected

        assert sliding_window.get_next_window() is None
|
||||
Reference in New Issue
Block a user