public code v1
This commit is contained in:
@@ -0,0 +1,614 @@
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, MagicMock, patch
|
||||
|
||||
from pygrex.data_reader.data_reader import DataReader
|
||||
from pygrex.utils.association_rules import AssociationRules
|
||||
|
||||
|
||||
class TestAssociationRulesInitialization:
|
||||
"""Test class for AssociationRules initialization."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
# Create mock DataReader with sample data
|
||||
self.mock_data_reader = Mock(spec=DataReader)
|
||||
self.sample_dataset = pd.DataFrame(
|
||||
{
|
||||
"userId": [1, 1, 1, 2, 2, 3, 3, 3, 4, 4],
|
||||
"itemId": ["A", "B", "C", "A", "B", "B", "C", "D", "A", "D"],
|
||||
"rating": [4.5, 3.0, 5.0, 4.0, 4.5, 3.5, 4.0, 4.5, 5.0, 3.0],
|
||||
}
|
||||
)
|
||||
self.mock_data_reader.dataset = self.sample_dataset
|
||||
|
||||
def test_init_with_default_parameters(self):
|
||||
"""Test initialization with default parameters."""
|
||||
ar = AssociationRules(self.mock_data_reader)
|
||||
|
||||
assert ar.data == self.mock_data_reader
|
||||
assert ar.min_support == 0.2
|
||||
assert ar.min_confidence == 0.2
|
||||
assert ar.rating_threshold == 4.0
|
||||
assert ar._frequent_itemsets is None
|
||||
assert ar._association_rules is None
|
||||
|
||||
def test_init_with_custom_parameters(self):
|
||||
"""Test initialization with custom parameters."""
|
||||
ar = AssociationRules(
|
||||
self.mock_data_reader,
|
||||
min_support=0.1,
|
||||
min_confidence=0.3,
|
||||
rating_threshold=3.5,
|
||||
)
|
||||
|
||||
assert ar.min_support == 0.1
|
||||
assert ar.min_confidence == 0.3
|
||||
assert ar.rating_threshold == 3.5
|
||||
|
||||
def test_init_with_invalid_min_support(self):
|
||||
"""Test initialization with invalid min_support values."""
|
||||
with pytest.raises(ValueError, match="min_support must be between 0 and 1"):
|
||||
AssociationRules(self.mock_data_reader, min_support=0)
|
||||
|
||||
with pytest.raises(ValueError, match="min_support must be between 0 and 1"):
|
||||
AssociationRules(self.mock_data_reader, min_support=1.5)
|
||||
|
||||
with pytest.raises(ValueError, match="min_support must be between 0 and 1"):
|
||||
AssociationRules(self.mock_data_reader, min_support=-0.1)
|
||||
|
||||
def test_init_with_invalid_min_confidence(self):
|
||||
"""Test initialization with invalid min_confidence values."""
|
||||
with pytest.raises(ValueError, match="min_confidence must be between 0 and 1"):
|
||||
AssociationRules(self.mock_data_reader, min_confidence=0)
|
||||
|
||||
with pytest.raises(ValueError, match="min_confidence must be between 0 and 1"):
|
||||
AssociationRules(self.mock_data_reader, min_confidence=2.0)
|
||||
|
||||
def test_init_with_invalid_rating_threshold(self):
|
||||
"""Test initialization with invalid rating_threshold."""
|
||||
with pytest.raises(ValueError, match="rating_threshold must be non-negative"):
|
||||
AssociationRules(self.mock_data_reader, rating_threshold=-1.0)
|
||||
|
||||
|
||||
class TestAssociationRulesValidation:
|
||||
"""Test class for parameter validation methods."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.mock_data_reader = Mock(spec=DataReader)
|
||||
self.mock_data_reader.dataset = pd.DataFrame(
|
||||
{"userId": [1, 2], "itemId": ["A", "B"], "rating": [4.0, 5.0]}
|
||||
)
|
||||
|
||||
def test_validate_parameters_valid_inputs(self):
|
||||
"""Test parameter validation with valid inputs."""
|
||||
ar = AssociationRules(self.mock_data_reader)
|
||||
# Should not raise any exception
|
||||
ar._validate_parameters(0.1, 0.2, 3.0)
|
||||
|
||||
def test_validate_parameters_invalid_support(self):
|
||||
"""Test parameter validation with invalid support values."""
|
||||
ar = AssociationRules(self.mock_data_reader)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
ar._validate_parameters(0, 0.5, 3.0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
ar._validate_parameters(1.1, 0.5, 3.0)
|
||||
|
||||
def test_validate_parameters_invalid_confidence(self):
|
||||
"""Test parameter validation with invalid confidence values."""
|
||||
ar = AssociationRules(self.mock_data_reader)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
ar._validate_parameters(0.1, 0, 3.0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
ar._validate_parameters(0.1, 1.5, 3.0)
|
||||
|
||||
def test_validate_parameters_invalid_rating_threshold(self):
|
||||
"""Test parameter validation with invalid rating threshold."""
|
||||
ar = AssociationRules(self.mock_data_reader)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
ar._validate_parameters(0.1, 0.5, -1.0)
|
||||
|
||||
|
||||
class TestAssociationRulesTransactionPreparation:
|
||||
"""Test class for transaction preparation methods."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.mock_data_reader = Mock(spec=DataReader)
|
||||
|
||||
def test_prepare_transactions_normal_case(self):
|
||||
"""Test transaction preparation with normal data."""
|
||||
dataset = pd.DataFrame(
|
||||
{
|
||||
"userId": [1, 1, 1, 2, 2, 3, 3],
|
||||
"itemId": ["A", "B", "C", "A", "B", "B", "C"],
|
||||
"rating": [4.5, 3.0, 5.0, 4.0, 4.5, 3.5, 4.0],
|
||||
}
|
||||
)
|
||||
self.mock_data_reader.dataset = dataset
|
||||
|
||||
ar = AssociationRules(self.mock_data_reader, rating_threshold=4.0)
|
||||
transactions = ar._prepare_transactions()
|
||||
|
||||
expected_transactions = [
|
||||
["A", "C"], # User 1: ratings 4.5, 5.0
|
||||
["A", "B"], # User 2: ratings 4.0, 4.5
|
||||
["C"], # User 3: rating 4.0
|
||||
]
|
||||
|
||||
assert len(transactions) == 3
|
||||
assert transactions == expected_transactions
|
||||
|
||||
def test_prepare_transactions_empty_after_filter(self):
|
||||
"""Test transaction preparation when no ratings meet threshold."""
|
||||
dataset = pd.DataFrame(
|
||||
{
|
||||
"userId": [1, 2, 3],
|
||||
"itemId": ["A", "B", "C"],
|
||||
"rating": [2.0, 3.0, 3.5],
|
||||
}
|
||||
)
|
||||
self.mock_data_reader.dataset = dataset
|
||||
|
||||
ar = AssociationRules(self.mock_data_reader, rating_threshold=4.0)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="No interactions found with rating >= 4.0"
|
||||
):
|
||||
ar._prepare_transactions()
|
||||
|
||||
def test_prepare_transactions_string_conversion(self):
|
||||
"""Test that movie IDs are converted to strings."""
|
||||
dataset = pd.DataFrame(
|
||||
{
|
||||
"userId": [1, 1],
|
||||
"itemId": [123, 456], # Numeric movie IDs
|
||||
"rating": [4.0, 5.0],
|
||||
}
|
||||
)
|
||||
self.mock_data_reader.dataset = dataset
|
||||
|
||||
ar = AssociationRules(self.mock_data_reader)
|
||||
transactions = ar._prepare_transactions()
|
||||
|
||||
assert transactions == [["123", "456"]]
|
||||
assert all(
|
||||
isinstance(item, str)
|
||||
for transaction in transactions
|
||||
for item in transaction
|
||||
)
|
||||
|
||||
|
||||
class TestAssociationRulesMining:
|
||||
"""Test class for frequent itemsets mining methods."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.mock_data_reader = Mock(spec=DataReader)
|
||||
|
||||
@patch("pygrex.utils.association_rules.fpgrowth")
|
||||
@patch("pygrex.utils.association_rules.TransactionEncoder")
|
||||
def test_mine_frequent_itemsets_success(self, mock_encoder_class, mock_fpgrowth):
|
||||
"""Test successful frequent itemsets mining."""
|
||||
# Mock TransactionEncoder
|
||||
mock_encoder = MagicMock()
|
||||
mock_encoder_class.return_value = mock_encoder
|
||||
mock_encoder.fit_transform.return_value = np.array(
|
||||
[[True, False], [False, True]]
|
||||
)
|
||||
mock_encoder.columns_ = ["A", "B"]
|
||||
|
||||
# Mock fpgrowth result
|
||||
mock_frequent_itemsets = pd.DataFrame(
|
||||
{"support": [0.3, 0.4], "itemsets": [{"A"}, {"B"}]}
|
||||
)
|
||||
mock_fpgrowth.return_value = mock_frequent_itemsets
|
||||
|
||||
dataset = pd.DataFrame(
|
||||
{"userId": [1, 2], "itemId": ["A", "B"], "rating": [4.0, 5.0]}
|
||||
)
|
||||
self.mock_data_reader.dataset = dataset
|
||||
|
||||
ar = AssociationRules(self.mock_data_reader, min_support=0.2)
|
||||
transactions = [["A"], ["B"]]
|
||||
|
||||
result = ar._mine_frequent_itemsets(transactions) # type: ignore
|
||||
|
||||
assert not result.empty
|
||||
mock_fpgrowth.assert_called_once()
|
||||
mock_encoder.fit_transform.assert_called_once_with(transactions)
|
||||
|
||||
@patch("pygrex.utils.association_rules.fpgrowth")
|
||||
@patch("pygrex.utils.association_rules.TransactionEncoder")
|
||||
def test_mine_frequent_itemsets_empty_result(
|
||||
self, mock_encoder_class, mock_fpgrowth
|
||||
):
|
||||
"""Test frequent itemsets mining with empty result."""
|
||||
# Mock TransactionEncoder
|
||||
mock_encoder = MagicMock()
|
||||
mock_encoder_class.return_value = mock_encoder
|
||||
mock_encoder.fit_transform.return_value = np.array([[True, False]])
|
||||
mock_encoder.columns_ = ["A", "B"]
|
||||
|
||||
# Mock empty fpgrowth result
|
||||
mock_fpgrowth.return_value = pd.DataFrame()
|
||||
|
||||
dataset = pd.DataFrame({"userId": [1], "itemId": ["A"], "rating": [4.0]})
|
||||
self.mock_data_reader.dataset = dataset
|
||||
|
||||
ar = AssociationRules(self.mock_data_reader, min_support=0.9)
|
||||
transactions = [["A"]]
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="No frequent itemsets found with min_support=0.9"
|
||||
):
|
||||
ar._mine_frequent_itemsets(transactions) # type: ignore
|
||||
|
||||
|
||||
class TestAssociationRulesGeneration:
|
||||
"""Test class for association rules generation methods."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.sample_frequent_itemsets = pd.DataFrame(
|
||||
{"support": [0.3, 0.4, 0.2], "itemsets": [{"A"}, {"B"}, {"A", "B"}]}
|
||||
)
|
||||
|
||||
@patch("pygrex.utils.association_rules.association_rules")
|
||||
def test_generate_association_rules_success(self, mock_association_rules):
|
||||
"""Test successful association rules generation."""
|
||||
mock_rules = pd.DataFrame(
|
||||
{
|
||||
"antecedents": [{"A"}],
|
||||
"consequents": [{"B"}],
|
||||
"confidence": [0.8],
|
||||
"support": [0.2],
|
||||
}
|
||||
)
|
||||
mock_association_rules.return_value = mock_rules
|
||||
|
||||
mock_data_reader = Mock(spec=DataReader)
|
||||
ar = AssociationRules(mock_data_reader, min_confidence=0.5)
|
||||
|
||||
result = ar._generate_association_rules(self.sample_frequent_itemsets)
|
||||
|
||||
assert not result.empty
|
||||
mock_association_rules.assert_called_once_with(
|
||||
self.sample_frequent_itemsets, metric="confidence", min_threshold=0.5
|
||||
)
|
||||
|
||||
@patch("pygrex.utils.association_rules.association_rules")
|
||||
def test_generate_association_rules_empty_result(self, mock_association_rules):
|
||||
"""Test association rules generation with empty result."""
|
||||
mock_association_rules.return_value = pd.DataFrame()
|
||||
|
||||
mock_data_reader = Mock(spec=DataReader)
|
||||
ar = AssociationRules(mock_data_reader, min_confidence=0.9)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="No association rules found with min_confidence=0.9"
|
||||
):
|
||||
ar._generate_association_rules(self.sample_frequent_itemsets)
|
||||
|
||||
|
||||
class TestAssociationRulesCompute:
|
||||
"""Test class for the main compute method."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.mock_data_reader = Mock(spec=DataReader)
|
||||
self.sample_dataset = pd.DataFrame(
|
||||
{
|
||||
"userId": [1, 1, 2, 2, 3, 3],
|
||||
"itemId": ["A", "B", "A", "C", "B", "C"],
|
||||
"rating": [4.0, 5.0, 4.5, 4.0, 4.5, 5.0],
|
||||
}
|
||||
)
|
||||
self.mock_data_reader.dataset = self.sample_dataset
|
||||
|
||||
def test_compute_empty_dataset(self):
|
||||
"""Test compute method with empty dataset."""
|
||||
self.mock_data_reader.dataset = pd.DataFrame()
|
||||
ar = AssociationRules(self.mock_data_reader)
|
||||
|
||||
with pytest.raises(ValueError, match="Dataset is empty"):
|
||||
ar.compute()
|
||||
|
||||
@patch.object(AssociationRules, "_generate_association_rules")
|
||||
@patch.object(AssociationRules, "_mine_frequent_itemsets")
|
||||
@patch.object(AssociationRules, "_prepare_transactions")
|
||||
def test_compute_success(self, mock_prepare, mock_mine, mock_generate):
|
||||
"""Test successful compute execution."""
|
||||
# Mock method returns
|
||||
mock_prepare.return_value = [["A", "B"], ["A", "C"]]
|
||||
mock_frequent_itemsets = pd.DataFrame({"support": [0.3], "itemsets": [{"A"}]})
|
||||
mock_mine.return_value = mock_frequent_itemsets
|
||||
mock_rules = pd.DataFrame(
|
||||
{"antecedents": [{"A"}], "consequents": [{"B"}], "confidence": [0.8]}
|
||||
)
|
||||
mock_generate.return_value = mock_rules
|
||||
|
||||
ar = AssociationRules(self.mock_data_reader)
|
||||
result = ar.compute()
|
||||
|
||||
assert result.equals(mock_rules)
|
||||
assert ar._frequent_itemsets.equals(mock_frequent_itemsets) # type: ignore
|
||||
assert ar._association_rules.equals(mock_rules) # type: ignore
|
||||
|
||||
mock_prepare.assert_called_once()
|
||||
mock_mine.assert_called_once_with([["A", "B"], ["A", "C"]])
|
||||
mock_generate.assert_called_once_with(mock_frequent_itemsets)
|
||||
|
||||
@patch.object(AssociationRules, "_prepare_transactions")
|
||||
def test_compute_no_transactions(self, mock_prepare):
|
||||
"""Test compute method when no transactions are found."""
|
||||
mock_prepare.return_value = []
|
||||
|
||||
ar = AssociationRules(self.mock_data_reader)
|
||||
|
||||
with pytest.raises(ValueError, match="No transactions found after filtering"):
|
||||
ar.compute()
|
||||
|
||||
|
||||
class TestAssociationRulesAccessors:
|
||||
"""Test class for accessor methods."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.mock_data_reader = Mock(spec=DataReader)
|
||||
self.ar = AssociationRules(self.mock_data_reader)
|
||||
|
||||
def test_get_frequent_itemsets_before_compute(self):
|
||||
"""Test getting frequent itemsets before compute is called."""
|
||||
result = self.ar.get_frequent_itemsets()
|
||||
assert result is None
|
||||
|
||||
def test_get_frequent_itemsets_after_compute(self):
|
||||
"""Test getting frequent itemsets after compute is called."""
|
||||
mock_itemsets = pd.DataFrame({"support": [0.3], "itemsets": [{"A"}]})
|
||||
self.ar._frequent_itemsets = mock_itemsets
|
||||
|
||||
result = self.ar.get_frequent_itemsets()
|
||||
assert result.equals(mock_itemsets) # type: ignore
|
||||
|
||||
|
||||
class TestAssociationRulesRecommendations:
|
||||
"""Test class for recommendation methods."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.mock_data_reader = Mock(spec=DataReader)
|
||||
self.ar = AssociationRules(self.mock_data_reader)
|
||||
|
||||
# Mock association rules
|
||||
self.mock_rules = pd.DataFrame(
|
||||
{
|
||||
"antecedents": [
|
||||
frozenset(["A"]),
|
||||
frozenset(["B"]),
|
||||
frozenset(["A", "B"]),
|
||||
],
|
||||
"consequents": [frozenset(["B"]), frozenset(["C"]), frozenset(["C"])],
|
||||
"confidence": [0.8, 0.6, 0.9],
|
||||
"lift": [1.2, 1.1, 1.5],
|
||||
"support": [0.4, 0.3, 0.2],
|
||||
}
|
||||
)
|
||||
|
||||
def test_get_recommendations_before_compute(self):
|
||||
"""Test getting recommendations before compute is called."""
|
||||
with pytest.raises(
|
||||
RuntimeError, match="Must call compute\\(\\) before getting recommendations"
|
||||
):
|
||||
self.ar.get_recommendations_for_items(["A"])
|
||||
|
||||
def test_get_recommendations_empty_items(self):
|
||||
"""Test getting recommendations with empty items list."""
|
||||
self.ar._association_rules = self.mock_rules
|
||||
|
||||
with pytest.raises(ValueError, match="Items list cannot be empty"):
|
||||
self.ar.get_recommendations_for_items([])
|
||||
|
||||
def test_get_recommendations_success(self):
|
||||
"""Test successful recommendation generation."""
|
||||
self.ar._association_rules = self.mock_rules
|
||||
|
||||
result = self.ar.get_recommendations_for_items(["A"], top_k=5)
|
||||
|
||||
assert not result.empty
|
||||
assert len(result) <= 5
|
||||
assert "confidence" in result.columns
|
||||
assert "antecedents" in result.columns
|
||||
assert "consequents" in result.columns
|
||||
|
||||
def test_get_recommendations_no_matching_rules(self):
|
||||
"""Test recommendations when no rules match the items."""
|
||||
self.ar._association_rules = self.mock_rules
|
||||
|
||||
result = self.ar.get_recommendations_for_items(["Z"]) # Item not in rules
|
||||
|
||||
assert result.empty
|
||||
|
||||
def test_get_recommendations_top_k_limit(self):
|
||||
"""Test that recommendations respect top_k limit."""
|
||||
# Create more rules than top_k
|
||||
extended_rules = pd.concat([self.mock_rules] * 5, ignore_index=True)
|
||||
self.ar._association_rules = extended_rules
|
||||
|
||||
result = self.ar.get_recommendations_for_items(["A"], top_k=2)
|
||||
|
||||
assert len(result) <= 2
|
||||
|
||||
|
||||
class TestAssociationRulesStringRepresentations:
|
||||
"""Test class for string representation methods."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.mock_data_reader = Mock(spec=DataReader)
|
||||
self.mock_data_reader.dataset = pd.DataFrame(
|
||||
{"userId": [1], "itemId": ["A"], "rating": [4.0]}
|
||||
)
|
||||
|
||||
def test_str_representation(self):
|
||||
"""Test string representation of AssociationRules object."""
|
||||
ar = AssociationRules(
|
||||
self.mock_data_reader,
|
||||
min_support=0.1,
|
||||
min_confidence=0.3,
|
||||
rating_threshold=3.5,
|
||||
)
|
||||
|
||||
expected = "AssociationRules(min_support=0.1, min_confidence=0.3, rating_threshold=3.5)"
|
||||
assert str(ar) == expected
|
||||
|
||||
def test_repr_representation(self):
|
||||
"""Test repr representation of AssociationRules object."""
|
||||
ar = AssociationRules(self.mock_data_reader)
|
||||
|
||||
expected = "AssociationRules(min_support=0.2, min_confidence=0.2, rating_threshold=4.0)"
|
||||
assert repr(ar) == expected
|
||||
|
||||
|
||||
class TestAssociationRulesIntegration:
|
||||
"""Integration tests for AssociationRules class."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures for integration tests."""
|
||||
self.mock_data_reader = Mock(spec=DataReader)
|
||||
|
||||
# Create a more comprehensive dataset for integration testing
|
||||
self.integration_dataset = pd.DataFrame(
|
||||
{
|
||||
"userId": [1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 5, 5],
|
||||
"itemId": [
|
||||
"A",
|
||||
"B",
|
||||
"C",
|
||||
"A",
|
||||
"B",
|
||||
"D",
|
||||
"B",
|
||||
"C",
|
||||
"A",
|
||||
"C",
|
||||
"D",
|
||||
"B",
|
||||
"D",
|
||||
],
|
||||
"rating": [
|
||||
4.5,
|
||||
4.0,
|
||||
5.0,
|
||||
4.0,
|
||||
4.5,
|
||||
3.5,
|
||||
4.0,
|
||||
4.5,
|
||||
5.0,
|
||||
4.0,
|
||||
4.5,
|
||||
4.0,
|
||||
4.0,
|
||||
],
|
||||
}
|
||||
)
|
||||
self.mock_data_reader.dataset = self.integration_dataset
|
||||
|
||||
@patch("pygrex.utils.association_rules.association_rules")
|
||||
@patch("pygrex.utils.association_rules.fpgrowth")
|
||||
def test_full_workflow_integration(self, mock_fpgrowth, mock_association_rules):
|
||||
"""Test the complete workflow from initialization to recommendations."""
|
||||
# Mock fpgrowth result
|
||||
mock_frequent_itemsets = pd.DataFrame(
|
||||
{
|
||||
"support": [0.4, 0.6, 0.3],
|
||||
"itemsets": [frozenset(["A"]), frozenset(["B"]), frozenset(["A", "B"])],
|
||||
}
|
||||
)
|
||||
mock_fpgrowth.return_value = mock_frequent_itemsets
|
||||
|
||||
# Mock association rules result
|
||||
mock_rules = pd.DataFrame(
|
||||
{
|
||||
"antecedents": [frozenset(["A"])],
|
||||
"consequents": [frozenset(["B"])],
|
||||
"confidence": [0.8],
|
||||
"lift": [1.2],
|
||||
"support": [0.3],
|
||||
}
|
||||
)
|
||||
mock_association_rules.return_value = mock_rules
|
||||
|
||||
# Initialize and compute
|
||||
ar = AssociationRules(
|
||||
self.mock_data_reader,
|
||||
min_support=0.2,
|
||||
min_confidence=0.5,
|
||||
rating_threshold=4.0,
|
||||
)
|
||||
|
||||
# Run compute
|
||||
rules = ar.compute()
|
||||
|
||||
# Verify results
|
||||
assert not rules.empty
|
||||
assert ar.get_frequent_itemsets() is not None
|
||||
|
||||
# Test recommendations
|
||||
recommendations = ar.get_recommendations_for_items(["A"])
|
||||
assert isinstance(recommendations, pd.DataFrame)
|
||||
|
||||
|
||||
# Pytest configuration and fixtures
|
||||
@pytest.fixture
|
||||
def sample_data_reader():
|
||||
"""Fixture providing a sample DataReader for tests."""
|
||||
mock_data_reader = Mock(spec=DataReader)
|
||||
mock_data_reader.dataset = pd.DataFrame(
|
||||
{
|
||||
"userId": [1, 1, 2, 2, 3, 3],
|
||||
"itemId": ["A", "B", "A", "C", "B", "C"],
|
||||
"rating": [4.0, 5.0, 4.5, 4.0, 4.5, 5.0],
|
||||
}
|
||||
)
|
||||
return mock_data_reader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def association_rules_instance(sample_data_reader):
|
||||
"""Fixture providing an AssociationRules instance for tests."""
|
||||
return AssociationRules(sample_data_reader)
|
||||
|
||||
|
||||
# Parametrized tests
|
||||
@pytest.mark.parametrize(
|
||||
"support,confidence,threshold,should_raise",
|
||||
[
|
||||
(0.1, 0.2, 4.0, False),
|
||||
(0.0, 0.2, 4.0, True),
|
||||
(1.1, 0.2, 4.0, True),
|
||||
(0.1, 0.0, 4.0, True),
|
||||
(0.1, 1.5, 4.0, True),
|
||||
(0.1, 0.2, -1.0, True),
|
||||
],
|
||||
)
|
||||
def test_parameter_validation_parametrized(
|
||||
sample_data_reader, support, confidence, threshold, should_raise
|
||||
):
|
||||
"""Parametrized test for parameter validation."""
|
||||
if should_raise:
|
||||
with pytest.raises(ValueError):
|
||||
AssociationRules(sample_data_reader, support, confidence, threshold)
|
||||
else:
|
||||
ar = AssociationRules(sample_data_reader, support, confidence, threshold)
|
||||
assert ar.min_support == support
|
||||
assert ar.min_confidence == confidence
|
||||
assert ar.rating_threshold == threshold
|
||||
@@ -0,0 +1,135 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from pygrex.utils.scale import Scale
|
||||
|
||||
|
||||
class TestScale:
|
||||
"""Test suite for the Scale class."""
|
||||
|
||||
def test_quantile_basic(self):
|
||||
"""Test basic functionality of the quantile method."""
|
||||
raw_values = [3.0, 1.0, 4.0, 2.0, 6.0]
|
||||
scaled = Scale.quantile(raw_values)
|
||||
# With 5 elements, we expect ranks to map evenly across 1-5
|
||||
assert np.allclose(scaled, [3, 1, 4, 2, 5])
|
||||
|
||||
def test_quantile_empty(self):
|
||||
"""Test quantile method with empty input."""
|
||||
with pytest.raises(ValueError, match="Raw predictions array is empty"):
|
||||
Scale.quantile([])
|
||||
|
||||
def test_quantile_single_element(self):
|
||||
"""Test quantile method with a single element."""
|
||||
result = Scale.quantile([7.5])
|
||||
assert np.allclose(result, [3]) # Middle of the default range [1, 5]
|
||||
|
||||
def test_quantile_custom_range(self):
|
||||
"""Test quantile method with custom target range."""
|
||||
raw_values = [10, 20, 30, 40, 50]
|
||||
scaled = Scale.quantile(raw_values, target_min=0, target_max=1)
|
||||
# Should map to [0, 0.25, 0.5, 0.75, 1]
|
||||
assert np.allclose(scaled, [0, 0.25, 0.5, 0.75, 1])
|
||||
|
||||
def test_quantile_equal_values(self):
|
||||
"""Test quantile method with all equal values."""
|
||||
raw_values = [5, 5, 5, 5]
|
||||
scaled = Scale.quantile(raw_values)
|
||||
# All values should get the average rank (2.5), which maps to the middle of [1, 5]
|
||||
assert np.allclose(scaled, [3, 3, 3, 3])
|
||||
|
||||
def test_quantile_numpy_input(self):
|
||||
"""Test quantile method with numpy array input."""
|
||||
raw_values = np.array([3.0, 1.0, 4.0, 2.0, 6.0])
|
||||
scaled = Scale.quantile(raw_values)
|
||||
assert np.allclose(scaled, [3, 1, 4, 2, 5])
|
||||
|
||||
def test_linear_basic(self):
|
||||
"""Test basic functionality of the linear method."""
|
||||
raw_values = [2.0, 4.0, 6.0, 8.0, 10.0]
|
||||
scaled = Scale.linear(raw_values)
|
||||
# Should map linearly from [2, 10] to [1, 5]
|
||||
assert np.allclose(scaled, [1, 2, 3, 4, 5])
|
||||
|
||||
def test_linear_empty(self):
|
||||
"""Test linear method with empty input."""
|
||||
with pytest.raises(ValueError, match="Raw predictions array is empty"):
|
||||
Scale.linear([])
|
||||
|
||||
def test_linear_single_element(self):
|
||||
"""Test linear method with a single element."""
|
||||
result = Scale.linear([7.5])
|
||||
assert np.allclose(result, [3]) # Middle of the default range [1, 5]
|
||||
|
||||
def test_linear_custom_range(self):
|
||||
"""Test linear method with custom target range."""
|
||||
raw_values = [0, 5, 10]
|
||||
scaled = Scale.linear(raw_values, target_min=0, target_max=100)
|
||||
assert np.allclose(scaled, [0, 50, 100])
|
||||
|
||||
def test_linear_custom_ref_range(self):
|
||||
"""Test linear method with custom reference range."""
|
||||
raw_values = [2, 5, 8]
|
||||
scaled = Scale.linear(raw_values, ref_min=0, ref_max=10)
|
||||
# Should map from [0, 10] to [1, 5] regardless of actual min/max
|
||||
expected = [1 + (2 / 10) * 4, 1 + (5 / 10) * 4, 1 + (8 / 10) * 4]
|
||||
assert np.allclose(scaled, expected)
|
||||
|
||||
def test_linear_with_outliers(self):
|
||||
"""Test linear method with outlier handling."""
|
||||
# Values with outliers
|
||||
raw_values = [5, 6, 7, 8, 20] # 20 is an outlier
|
||||
|
||||
# With outlier handling (default)
|
||||
scaled_with_handling = Scale.linear(raw_values)
|
||||
|
||||
# Without outlier handling
|
||||
scaled_without_handling = Scale.linear(raw_values, handle_outliers=False)
|
||||
|
||||
# The result with outlier handling should be different
|
||||
assert not np.allclose(scaled_with_handling, scaled_without_handling)
|
||||
|
||||
# The outlier should be scaled to the max value (5) without handling
|
||||
assert scaled_without_handling[-1] == 5
|
||||
|
||||
# With handling, the outlier should still be clamped to the max
|
||||
assert scaled_with_handling[-1] == 5
|
||||
|
||||
# But other values should be more spread out with handling
|
||||
assert np.max(scaled_with_handling[:-1]) > np.max(scaled_without_handling[:-1])
|
||||
|
||||
def test_linear_equal_values(self):
|
||||
"""Test linear method with all equal values."""
|
||||
raw_values = [7, 7, 7, 7]
|
||||
scaled = Scale.linear(raw_values)
|
||||
# All equal values should map to the middle of target range
|
||||
assert np.allclose(scaled, [3, 3, 3, 3])
|
||||
|
||||
def test_linear_equal_ref_bounds(self):
|
||||
"""Test linear method with equal reference bounds."""
|
||||
raw_values = [5, 6, 7]
|
||||
scaled = Scale.linear(raw_values, ref_min=5, ref_max=5)
|
||||
# When ref bounds are equal, should map to middle of target range
|
||||
assert np.allclose(scaled, [3, 3, 3])
|
||||
|
||||
def test_linear_numpy_input(self):
|
||||
"""Test linear method with numpy array input."""
|
||||
raw_values = np.array([2.0, 4.0, 6.0, 8.0, 10.0])
|
||||
scaled = Scale.linear(raw_values)
|
||||
assert np.allclose(scaled, [1, 2, 3, 4, 5])
|
||||
|
||||
def test_linear_clipping(self):
|
||||
"""Test that linear scaling properly clips out-of-bounds values."""
|
||||
# Values outside the reference range
|
||||
raw_values = [0, 5, 10, 15]
|
||||
scaled = Scale.linear(raw_values, ref_min=5, ref_max=10)
|
||||
# Values below ref_min should be clipped to target_min
|
||||
assert scaled[0] == 1
|
||||
# Values above ref_max should be clipped to target_max
|
||||
assert scaled[3] == 5
|
||||
|
||||
def test_linear_negative_values(self):
|
||||
"""Test linear scaling with negative values."""
|
||||
raw_values = [-10, -5, 0, 5, 10]
|
||||
scaled = Scale.linear(raw_values)
|
||||
# Should map [-10, 10] to [1, 5]
|
||||
assert np.allclose(scaled, [1, 2, 3, 4, 5])
|
||||
@@ -0,0 +1,406 @@
|
||||
import unittest
|
||||
import pytest
|
||||
import numpy as np
|
||||
from unittest.mock import patch
|
||||
from typing import Dict, List
|
||||
from pygrex.utils.aggregation_strategy import (
|
||||
ScoreAggregator,
|
||||
AggregationStrategy, # type: ignore
|
||||
)
|
||||
from enum import Enum
|
||||
from typing import TypeAlias
|
||||
|
||||
|
||||
UserID: TypeAlias = str
|
||||
ItemID: TypeAlias = str
|
||||
EvaluationScore: TypeAlias = float
|
||||
AggregatedScore: TypeAlias = float
|
||||
UserEvaluations: TypeAlias = Dict[UserID, Dict[ItemID, EvaluationScore]]
|
||||
UserRankings: TypeAlias = Dict[UserID, List[ItemID]]
|
||||
AggregatedScores: TypeAlias = Dict[ItemID, AggregatedScore]
|
||||
|
||||
|
||||
class TestScoreAggregatorUnittest(unittest.TestCase):
|
||||
"""Unit tests using unittest framework."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures before each test method."""
|
||||
self.aggregator = ScoreAggregator(most_respected_person="user1")
|
||||
|
||||
self.sample_evaluations: UserEvaluations = {
|
||||
"user1": {"item_A": 4.5, "item_B": 3.0, "item_C": 5.0},
|
||||
"user2": {"item_A": 3.0, "item_B": 4.0, "item_C": 2.0},
|
||||
"user3": {"item_A": 4.0, "item_B": 2.0, "item_C": 3.0},
|
||||
}
|
||||
|
||||
self.sample_rankings: UserRankings = {
|
||||
"user1": ["item_C", "item_A", "item_B"],
|
||||
"user2": ["item_B", "item_A", "item_C"],
|
||||
"user3": ["item_A", "item_C", "item_B"],
|
||||
}
|
||||
|
||||
self.empty_evaluations: UserEvaluations = {}
|
||||
|
||||
self.single_user_evaluations: UserEvaluations = {
|
||||
"user1": {"item_A": 3.5, "item_B": 4.0}
|
||||
}
|
||||
|
||||
def test_init_with_mrp(self):
|
||||
"""Test initialization with most respected person."""
|
||||
aggregator = ScoreAggregator(most_respected_person="user1")
|
||||
self.assertEqual(aggregator.most_respected_person, "user1")
|
||||
|
||||
def test_init_without_mrp(self):
|
||||
"""Test initialization without most respected person."""
|
||||
aggregator = ScoreAggregator()
|
||||
self.assertIsNone(aggregator.most_respected_person)
|
||||
|
||||
def test_empty_evaluations(self):
|
||||
"""Test aggregation with empty evaluations."""
|
||||
result = self.aggregator.aggregate_scores(
|
||||
self.empty_evaluations, # type: ignore
|
||||
AggregationStrategy.AVG_PREDICTIONS, # type: ignore
|
||||
)
|
||||
self.assertEqual(result, {})
|
||||
|
||||
def test_avg_predictions(self):
|
||||
"""Test average predictions aggregation."""
|
||||
result = self.aggregator.aggregate_scores(
|
||||
self.sample_evaluations, # type: ignore
|
||||
AggregationStrategy.AVG_PREDICTIONS, # type: ignore
|
||||
)
|
||||
|
||||
# Expected: item_A: (4.5+3.0+4.0)/3 = 3.833..., item_B: (3.0+4.0+2.0)/3 = 3.0, item_C: (5.0+2.0+3.0)/3 = 3.333...
|
||||
self.assertAlmostEqual(result["item_A"], 3.833333333333333, places=5)
|
||||
self.assertAlmostEqual(result["item_B"], 3.0, places=5)
|
||||
self.assertAlmostEqual(result["item_C"], 3.333333333333333, places=5)
|
||||
|
||||
def test_least_misery(self):
|
||||
"""Test least misery aggregation."""
|
||||
result = self.aggregator.aggregate_scores(
|
||||
self.sample_evaluations, # type: ignore
|
||||
AggregationStrategy.LEAST_MISERY, # type: ignore
|
||||
)
|
||||
|
||||
# Expected: item_A: min(4.5, 3.0, 4.0) = 3.0, item_B: min(3.0, 4.0, 2.0) = 2.0, item_C: min(5.0, 2.0, 3.0) = 2.0
|
||||
self.assertEqual(result["item_A"], 3.0)
|
||||
self.assertEqual(result["item_B"], 2.0)
|
||||
self.assertEqual(result["item_C"], 2.0)
|
||||
|
||||
def test_most_pleasure(self):
|
||||
"""Test most pleasure aggregation."""
|
||||
result = self.aggregator.aggregate_scores(
|
||||
self.sample_evaluations, # type: ignore
|
||||
AggregationStrategy.MOST_PLEASURE, # type: ignore
|
||||
)
|
||||
|
||||
# Expected: item_A: max(4.5, 3.0, 4.0) = 4.5, item_B: max(3.0, 4.0, 2.0) = 4.0, item_C: max(5.0, 2.0, 3.0) = 5.0
|
||||
self.assertEqual(result["item_A"], 4.5)
|
||||
self.assertEqual(result["item_B"], 4.0)
|
||||
self.assertEqual(result["item_C"], 5.0)
|
||||
|
||||
def test_most_respected_person(self):
|
||||
"""Test most respected person aggregation."""
|
||||
result = self.aggregator.aggregate_scores(
|
||||
self.sample_evaluations, # type: ignore
|
||||
AggregationStrategy.MOST_RESPECTED_PERSON, # type: ignore
|
||||
)
|
||||
|
||||
# Expected: user1's evaluations: item_A: 4.5, item_B: 3.0, item_C: 5.0
|
||||
self.assertEqual(result["item_A"], 4.5)
|
||||
self.assertEqual(result["item_B"], 3.0)
|
||||
self.assertEqual(result["item_C"], 5.0)
|
||||
|
||||
def test_mrp_without_mrp_set(self):
|
||||
"""Test MRP strategy without setting most respected person."""
|
||||
aggregator = ScoreAggregator() # No MRP set
|
||||
|
||||
with self.assertRaises(ValueError) as context:
|
||||
aggregator.aggregate_scores(
|
||||
self.sample_evaluations, # type: ignore
|
||||
AggregationStrategy.MOST_RESPECTED_PERSON, # type: ignore
|
||||
)
|
||||
|
||||
self.assertIn("Most respected person not specified", str(context.exception))
|
||||
|
||||
def test_mrp_user_not_in_evaluations(self):
|
||||
"""Test MRP strategy when MRP user is not in evaluations."""
|
||||
aggregator = ScoreAggregator(most_respected_person="nonexistent_user")
|
||||
|
||||
with self.assertRaises(ValueError) as context:
|
||||
aggregator.aggregate_scores(
|
||||
self.sample_evaluations, # type: ignore
|
||||
AggregationStrategy.MOST_RESPECTED_PERSON, # type: ignore
|
||||
)
|
||||
|
||||
self.assertIn("not found in evaluations", str(context.exception))
|
||||
|
||||
def test_additive_utilitarian(self):
|
||||
"""Test additive utilitarian aggregation."""
|
||||
result = self.aggregator.aggregate_scores(
|
||||
self.sample_evaluations, # type: ignore
|
||||
AggregationStrategy.ADDITIVE_UTILITARIAN, # type: ignore
|
||||
)
|
||||
|
||||
# Expected: item_A: 4.5+3.0+4.0 = 11.5, item_B: 3.0+4.0+2.0 = 9.0, item_C: 5.0+2.0+3.0 = 10.0
|
||||
self.assertEqual(result["item_A"], 11.5)
|
||||
self.assertEqual(result["item_B"], 9.0)
|
||||
self.assertEqual(result["item_C"], 10.0)
|
||||
|
||||
def test_multiplicative(self):
|
||||
"""Test multiplicative aggregation."""
|
||||
result = self.aggregator.aggregate_scores(
|
||||
self.sample_evaluations, # type: ignore
|
||||
AggregationStrategy.MULTIPLICATIVE, # type: ignore
|
||||
)
|
||||
|
||||
# Expected: item_A: 4.5*3.0*4.0 = 54.0, item_B: 3.0*4.0*2.0 = 24.0, item_C: 5.0*2.0*3.0 = 30.0
|
||||
self.assertEqual(result["item_A"], 54.0)
|
||||
self.assertEqual(result["item_B"], 24.0)
|
||||
self.assertEqual(result["item_C"], 30.0)
|
||||
|
||||
def test_borda_count(self):
|
||||
"""Test Borda count aggregation."""
|
||||
result = self.aggregator.aggregate_scores(
|
||||
self.sample_evaluations, # type: ignore
|
||||
AggregationStrategy.BORDA_COUNT,
|
||||
self.sample_rankings, # type: ignore
|
||||
)
|
||||
|
||||
# Expected scores based on rankings:
|
||||
# item_A: user1(1) + user2(1) + user3(2) = 4
|
||||
# item_B: user1(0) + user2(2) + user3(0) = 2
|
||||
# item_C: user1(2) + user2(0) + user3(1) = 3
|
||||
self.assertEqual(result["item_A"], 4.0)
|
||||
self.assertEqual(result["item_B"], 2.0)
|
||||
self.assertEqual(result["item_C"], 3.0)
|
||||
|
||||
def test_borda_count_without_rankings(self):
|
||||
"""Test Borda count without providing rankings."""
|
||||
with self.assertRaises(ValueError) as context:
|
||||
self.aggregator.aggregate_scores(
|
||||
self.sample_evaluations, # type: ignore
|
||||
AggregationStrategy.BORDA_COUNT, # type: ignore
|
||||
)
|
||||
|
||||
self.assertIn("Rankings required for Borda Count", str(context.exception))
|
||||
|
||||
def test_unknown_strategy(self):
|
||||
"""Test with unknown aggregation strategy."""
|
||||
with self.assertRaises(ValueError) as context:
|
||||
self.aggregator.aggregate_scores(
|
||||
self.sample_evaluations, # type: ignore
|
||||
"invalid_strategy", # type: ignore
|
||||
)
|
||||
|
||||
self.assertIn("Unknown aggregation strategy", str(context.exception))
|
||||
|
||||
def test_get_top_recommendation(self):
|
||||
"""Test getting top recommendation."""
|
||||
top_item = self.aggregator.get_top_recommendation(
|
||||
self.sample_evaluations, # type: ignore
|
||||
AggregationStrategy.MOST_PLEASURE, # type: ignore
|
||||
)
|
||||
|
||||
# item_C has the highest max value (5.0)
|
||||
self.assertEqual(top_item, "item_C")
|
||||
|
||||
def test_get_top_recommendation_with_rankings(self):
|
||||
"""Test getting top recommendation with Borda count."""
|
||||
top_item = self.aggregator.get_top_recommendation(
|
||||
self.sample_evaluations, # type: ignore
|
||||
AggregationStrategy.BORDA_COUNT,
|
||||
self.sample_rankings, # type: ignore
|
||||
)
|
||||
|
||||
# item_A has the highest Borda count (4.0)
|
||||
self.assertEqual(top_item, "item_A")
|
||||
|
||||
def test_single_user_evaluation(self):
|
||||
"""Test aggregation with single user."""
|
||||
result = self.aggregator.aggregate_scores(
|
||||
self.single_user_evaluations, # type: ignore
|
||||
AggregationStrategy.AVG_PREDICTIONS, # type: ignore
|
||||
)
|
||||
|
||||
# With single user, average should equal the original values
|
||||
self.assertEqual(result["item_A"], 3.5)
|
||||
self.assertEqual(result["item_B"], 4.0)
|
||||
|
||||
def test_missing_items_in_evaluations(self):
|
||||
"""Test with missing items in some user evaluations."""
|
||||
incomplete_evaluations: UserEvaluations = {
|
||||
"user1": {"item_A": 4.0, "item_B": 3.0},
|
||||
"user2": {"item_A": 3.0, "item_C": 2.0}, # Missing item_B
|
||||
"user3": {"item_B": 2.0, "item_C": 3.0}, # Missing item_A
|
||||
}
|
||||
|
||||
result = self.aggregator.aggregate_scores(
|
||||
incomplete_evaluations, # type: ignore
|
||||
AggregationStrategy.AVG_PREDICTIONS, # type: ignore
|
||||
)
|
||||
|
||||
# Should handle missing items gracefully
|
||||
self.assertAlmostEqual(result["item_A"], 3.5, places=5) # (4.0 + 3.0) / 2
|
||||
self.assertAlmostEqual(result["item_B"], 2.5, places=5) # (3.0 + 2.0) / 2
|
||||
self.assertAlmostEqual(result["item_C"], 2.5, places=5) # (2.0 + 3.0) / 2
|
||||
|
||||
|
||||
class TestScoreAggregatorPytest:
|
||||
"""Unit tests using pytest framework."""
|
||||
|
||||
@pytest.fixture
|
||||
def aggregator(self):
|
||||
"""Fixture for ScoreAggregator instance."""
|
||||
return ScoreAggregator(most_respected_person="user1")
|
||||
|
||||
@pytest.fixture
|
||||
def sample_evaluations(self):
|
||||
"""Fixture for sample evaluations."""
|
||||
return {
|
||||
"user1": {"item_A": 4.5, "item_B": 3.0, "item_C": 5.0},
|
||||
"user2": {"item_A": 3.0, "item_B": 4.0, "item_C": 2.0},
|
||||
"user3": {"item_A": 4.0, "item_B": 2.0, "item_C": 3.0},
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def sample_rankings(self):
|
||||
"""Fixture for sample rankings."""
|
||||
return {
|
||||
"user1": ["item_C", "item_A", "item_B"],
|
||||
"user2": ["item_B", "item_A", "item_C"],
|
||||
"user3": ["item_A", "item_C", "item_B"],
|
||||
}
|
||||
|
||||
def test_avg_predictions_pytest(self, aggregator, sample_evaluations):
|
||||
"""Test average predictions using pytest."""
|
||||
result = aggregator.aggregate_scores(
|
||||
sample_evaluations, AggregationStrategy.AVG_PREDICTIONS
|
||||
)
|
||||
|
||||
assert abs(result["item_A"] - 3.833333333333333) < 1e-5
|
||||
assert abs(result["item_B"] - 3.0) < 1e-5
|
||||
assert abs(result["item_C"] - 3.333333333333333) < 1e-5
|
||||
|
||||
def test_least_misery_pytest(self, aggregator, sample_evaluations):
|
||||
"""Test least misery using pytest."""
|
||||
result = aggregator.aggregate_scores(
|
||||
sample_evaluations, AggregationStrategy.LEAST_MISERY
|
||||
)
|
||||
|
||||
assert result["item_A"] == 3.0
|
||||
assert result["item_B"] == 2.0
|
||||
assert result["item_C"] == 2.0
|
||||
|
||||
def test_borda_count_pytest(self, aggregator, sample_evaluations, sample_rankings):
|
||||
"""Test Borda count using pytest."""
|
||||
result = aggregator.aggregate_scores(
|
||||
sample_evaluations, AggregationStrategy.BORDA_COUNT, sample_rankings
|
||||
)
|
||||
|
||||
assert result["item_A"] == 4.0
|
||||
assert result["item_B"] == 2.0
|
||||
assert result["item_C"] == 3.0
|
||||
|
||||
def test_mrp_error_pytest(self, sample_evaluations):
|
||||
"""Test MRP error handling using pytest."""
|
||||
aggregator = ScoreAggregator() # No MRP set
|
||||
|
||||
with pytest.raises(ValueError, match="Most respected person not specified"):
|
||||
aggregator.aggregate_scores(
|
||||
sample_evaluations, AggregationStrategy.MOST_RESPECTED_PERSON
|
||||
)
|
||||
|
||||
def test_borda_count_error_pytest(self, aggregator, sample_evaluations):
|
||||
"""Test Borda count error handling using pytest."""
|
||||
with pytest.raises(ValueError, match="Rankings required for Borda Count"):
|
||||
aggregator.aggregate_scores(
|
||||
sample_evaluations, AggregationStrategy.BORDA_COUNT
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"strategy,expected_top",
|
||||
[
|
||||
(AggregationStrategy.LEAST_MISERY, "item_A"),
|
||||
(AggregationStrategy.MOST_PLEASURE, "item_C"),
|
||||
(AggregationStrategy.ADDITIVE_UTILITARIAN, "item_A"),
|
||||
],
|
||||
)
|
||||
def test_top_recommendation_parametrized(
|
||||
self, aggregator, sample_evaluations, strategy, expected_top
|
||||
):
|
||||
"""Test top recommendations with parametrized testing."""
|
||||
top_item = aggregator.get_top_recommendation(sample_evaluations, strategy)
|
||||
assert top_item == expected_top
|
||||
|
||||
def test_empty_evaluations_pytest(self, aggregator):
|
||||
"""Test empty evaluations using pytest."""
|
||||
result = aggregator.aggregate_scores({}, AggregationStrategy.AVG_PREDICTIONS)
|
||||
assert result == {}
|
||||
|
||||
def test_multiplicative_with_zero(self, aggregator):
|
||||
"""Test multiplicative aggregation with zero values."""
|
||||
evaluations_with_zero = {
|
||||
"user1": {"item_A": 0.0, "item_B": 3.0},
|
||||
"user2": {"item_A": 4.0, "item_B": 2.0},
|
||||
}
|
||||
|
||||
result = aggregator.aggregate_scores(
|
||||
evaluations_with_zero, AggregationStrategy.MULTIPLICATIVE
|
||||
)
|
||||
|
||||
assert result["item_A"] == 0.0 # 0.0 * 4.0 = 0.0
|
||||
assert result["item_B"] == 6.0 # 3.0 * 2.0 = 6.0
|
||||
|
||||
|
||||
# Integration tests
|
||||
class TestScoreAggregatorIntegration(unittest.TestCase):
|
||||
"""Integration tests for ScoreAggregator."""
|
||||
|
||||
def test_all_strategies_consistency(self):
|
||||
"""Test that all strategies produce consistent results."""
|
||||
evaluations = {
|
||||
"user1": {"item_A": 5.0, "item_B": 3.0, "item_C": 4.0},
|
||||
"user2": {"item_A": 4.0, "item_B": 5.0, "item_C": 3.0},
|
||||
"user3": {"item_A": 3.0, "item_B": 4.0, "item_C": 5.0},
|
||||
}
|
||||
|
||||
rankings = {
|
||||
"user1": ["item_A", "item_C", "item_B"],
|
||||
"user2": ["item_B", "item_A", "item_C"],
|
||||
"user3": ["item_C", "item_B", "item_A"],
|
||||
}
|
||||
|
||||
aggregator = ScoreAggregator(most_respected_person="user1")
|
||||
|
||||
# Test that all strategies return valid results
|
||||
strategies_without_rankings = [
|
||||
AggregationStrategy.AVG_PREDICTIONS,
|
||||
AggregationStrategy.LEAST_MISERY,
|
||||
AggregationStrategy.MOST_PLEASURE,
|
||||
AggregationStrategy.MOST_RESPECTED_PERSON,
|
||||
AggregationStrategy.AVG_PREDICTIONS,
|
||||
AggregationStrategy.ADDITIVE_UTILITARIAN,
|
||||
AggregationStrategy.MULTIPLICATIVE,
|
||||
]
|
||||
|
||||
for strategy in strategies_without_rankings:
|
||||
result = aggregator.aggregate_scores(evaluations, strategy) # type: ignore
|
||||
self.assertIsInstance(result, dict)
|
||||
self.assertEqual(len(result), 3) # Should have 3 items
|
||||
for score in result.values():
|
||||
self.assertIsInstance(score, (int, float))
|
||||
|
||||
# Test Borda count separately
|
||||
borda_result = aggregator.aggregate_scores(
|
||||
evaluations, # type: ignore
|
||||
AggregationStrategy.BORDA_COUNT,
|
||||
rankings, # type: ignore
|
||||
)
|
||||
self.assertIsInstance(borda_result, dict)
|
||||
self.assertEqual(len(borda_result), 3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run unittest tests
|
||||
unittest.main(verbosity=2)
|
||||
@@ -0,0 +1,162 @@
|
||||
import pytest
|
||||
from typing import List, Any
|
||||
|
||||
from pygrex.utils.sliding_window import SlidingWindow
|
||||
|
||||
|
||||
class TestSlidingWindow:
|
||||
"""Test suite for the SlidingWindow class."""
|
||||
|
||||
def test_basic_functionality(self):
|
||||
"""Test the basic window sliding functionality."""
|
||||
data = [1, 2, 3, 4, 5]
|
||||
window_size = 3
|
||||
sliding_window = SlidingWindow(data, window_size)
|
||||
|
||||
# First window
|
||||
assert sliding_window.get_next_window() == [1, 2, 3]
|
||||
# Second window
|
||||
assert sliding_window.get_next_window() == [2, 3, 4]
|
||||
# Third window
|
||||
assert sliding_window.get_next_window() == [3, 4, 5]
|
||||
# No more windows
|
||||
assert sliding_window.get_next_window() is None
|
||||
|
||||
def test_window_equal_to_sequence_length(self):
|
||||
"""Test when window size equals the sequence length."""
|
||||
data = [1, 2, 3]
|
||||
window_size = 3
|
||||
sliding_window = SlidingWindow(data, window_size)
|
||||
|
||||
assert sliding_window.get_next_window() == [1, 2, 3]
|
||||
assert sliding_window.get_next_window() is None
|
||||
|
||||
def test_window_larger_than_sequence(self):
|
||||
"""Test when window size is larger than the sequence length."""
|
||||
data = [1, 2, 3]
|
||||
window_size = 4
|
||||
sliding_window = SlidingWindow(data, window_size)
|
||||
|
||||
assert sliding_window.get_next_window() is None
|
||||
|
||||
def test_empty_sequence(self):
|
||||
"""Test with an empty sequence."""
|
||||
data: List[Any] = []
|
||||
window_size = 2
|
||||
sliding_window = SlidingWindow(data, window_size)
|
||||
|
||||
assert sliding_window.get_next_window() is None
|
||||
|
||||
def test_invalid_window_size(self):
|
||||
"""Test with invalid window sizes."""
|
||||
data = [1, 2, 3, 4, 5]
|
||||
|
||||
# Test with zero window size
|
||||
with pytest.raises(ValueError):
|
||||
SlidingWindow(data, 0)
|
||||
|
||||
# Test with negative window size
|
||||
with pytest.raises(ValueError):
|
||||
SlidingWindow(data, -1)
|
||||
|
||||
def test_non_iterable_sequence(self):
|
||||
"""Test with a non-iterable object."""
|
||||
data = 123 # Integer is not iterable
|
||||
window_size = 2
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
SlidingWindow(data, window_size)
|
||||
|
||||
def test_reset_functionality(self):
|
||||
"""Test the reset functionality."""
|
||||
data = [1, 2, 3, 4]
|
||||
window_size = 2
|
||||
sliding_window = SlidingWindow(data, window_size)
|
||||
|
||||
# Get first two windows
|
||||
assert sliding_window.get_next_window() == [1, 2]
|
||||
assert sliding_window.get_next_window() == [2, 3]
|
||||
|
||||
# Reset and check if we get the first window again
|
||||
sliding_window.reset()
|
||||
assert sliding_window.get_next_window() == [1, 2]
|
||||
|
||||
def test_has_next(self):
|
||||
"""Test the has_next method."""
|
||||
data = [1, 2, 3]
|
||||
window_size = 2
|
||||
sliding_window = SlidingWindow(data, window_size)
|
||||
|
||||
assert sliding_window.has_next() is True
|
||||
sliding_window.get_next_window() # Get first window
|
||||
assert sliding_window.has_next() is True
|
||||
sliding_window.get_next_window() # Get second window
|
||||
assert sliding_window.has_next() is False
|
||||
|
||||
def test_iterator_protocol(self):
|
||||
"""Test the iterator protocol implementation."""
|
||||
data = [1, 2, 3, 4]
|
||||
window_size = 2
|
||||
sliding_window = SlidingWindow(data, window_size)
|
||||
|
||||
# Using the class in a for loop
|
||||
windows = []
|
||||
for window in sliding_window:
|
||||
windows.append(window)
|
||||
|
||||
assert windows == [[1, 2], [2, 3], [3, 4]]
|
||||
|
||||
# After iteration, the index should be at the end
|
||||
assert sliding_window.has_next() is False
|
||||
|
||||
# Test that reset works after iteration
|
||||
sliding_window.reset()
|
||||
assert sliding_window.has_next() is True
|
||||
assert sliding_window.get_next_window() == [1, 2]
|
||||
|
||||
def test_len_functionality(self):
|
||||
"""Test the __len__ method."""
|
||||
# Normal case
|
||||
data = [1, 2, 3, 4, 5]
|
||||
window_size = 2
|
||||
sliding_window = SlidingWindow(data, window_size)
|
||||
assert len(sliding_window) == 4
|
||||
|
||||
# Window size equals sequence length
|
||||
window_size = 5
|
||||
sliding_window = SlidingWindow(data, window_size)
|
||||
assert len(sliding_window) == 1
|
||||
|
||||
# Window size greater than sequence length
|
||||
window_size = 6
|
||||
sliding_window = SlidingWindow(data, window_size)
|
||||
assert len(sliding_window) == 0
|
||||
|
||||
# Empty sequence
|
||||
data = []
|
||||
window_size = 2
|
||||
sliding_window = SlidingWindow(data, window_size)
|
||||
assert len(sliding_window) == 0
|
||||
|
||||
def test_with_string_data(self):
|
||||
"""Test with string data to verify generic implementation."""
|
||||
data = "abcde"
|
||||
window_size = 3
|
||||
sliding_window = SlidingWindow(data, window_size)
|
||||
|
||||
assert sliding_window.get_next_window() == "abc"
|
||||
assert sliding_window.get_next_window() == "bcd"
|
||||
assert sliding_window.get_next_window() == "cde"
|
||||
assert sliding_window.get_next_window() is None
|
||||
|
||||
def test_multiple_data_types(self):
|
||||
"""Test with a list containing multiple data types."""
|
||||
data = [1, "two", 3.0, [4, 5], {"six": 6}]
|
||||
window_size = 2
|
||||
sliding_window = SlidingWindow(data, window_size)
|
||||
|
||||
assert sliding_window.get_next_window() == [1, "two"]
|
||||
assert sliding_window.get_next_window() == ["two", 3.0]
|
||||
assert sliding_window.get_next_window() == [3.0, [4, 5]]
|
||||
assert sliding_window.get_next_window() == [[4, 5], {"six": 6}]
|
||||
assert sliding_window.get_next_window() is None
|
||||
Reference in New Issue
Block a user