public code v1

This commit is contained in:
2026-05-22 10:02:10 +02:00
commit 46a9ecf065
166 changed files with 6982454 additions and 0 deletions
+202
View File
@@ -0,0 +1,202 @@
import pytest
import pandas as pd
import numpy as np
import tempfile
import os
from unittest.mock import patch
from pygrex.data_reader.data_reader import DataReader
@pytest.fixture
def valid_df():
"""Fixture providing a valid test DataFrame."""
return pd.DataFrame(
{
"userId": [1, 2, 3, 1],
"itemId": [100, 200, 300, 400],
"rating": [4.5, 3.0, 5.0, 2.0],
"timestamp": [1000, 2000, 3000, 4000],
}
)
@pytest.fixture
def test_csv_path(valid_df):
"""Fixture providing a temporary CSV file with valid data."""
with tempfile.TemporaryDirectory() as temp_dir:
test_csv_path = os.path.join(temp_dir, "test_data.csv")
valid_df.to_csv(test_csv_path, sep=",", index=False)
yield test_csv_path
def test_init_with_dataframe(valid_df):
"""Test initialization with a DataFrame."""
reader = DataReader(dataframe=valid_df)
pd.testing.assert_frame_equal(reader.dataset, valid_df)
assert reader.num_user == 3
assert reader.num_item == 4
def test_init_with_filepath(test_csv_path):
"""Test initialization with filepath."""
reader = DataReader(
filepath_or_buffer=test_csv_path,
sep=",",
names=["userId", "itemId", "rating", "timestamp"],
skiprows=1, # Skip header
)
# Check that dataset was loaded correctly
assert reader.num_user == 3
assert reader.num_item == 4
def test_dataset_loading_file_not_found():
"""Test exception when file is not found."""
with pytest.raises(FileNotFoundError):
DataReader(
filepath_or_buffer="nonexistent.csv",
sep=",",
names=["userId", "itemId", "rating", "timestamp"],
)
def test_dataset_validation_invalid_columns():
"""Test validation for invalid columns."""
# Create a DataFrame with non-numeric userId
invalid_df = pd.DataFrame(
{
"userId": ["user1", "user2", "user3"],
"itemId": [100, 200, 300],
"rating": [4.5, 3.0, 5.0],
"timestamp": [1000, 2000, 3000],
}
)
with pytest.raises(ValueError) as excinfo:
DataReader(dataframe=invalid_df)
assert "cannot be converted to a numeric type" in str(excinfo.value)
def test_dataset_validation_empty():
"""Test validation for empty DataFrame."""
empty_df = pd.DataFrame(columns=["userId", "itemId", "rating", "timestamp"])
with pytest.raises(ValueError) as excinfo:
DataReader(dataframe=empty_df)
assert "DataFrame cannot be empty" in str(excinfo.value)
def test_dataset_validation_invalid_types():
"""Test validation for invalid data types."""
# Create DataFrame with invalid rating (string)
invalid_types_df = pd.DataFrame(
{
"userId": [1, 2, 3],
"itemId": [100, 200, 300],
"rating": ["high", "medium", "low"],
"timestamp": [1000, 2000, 3000],
}
)
with pytest.raises(ValueError) as excinfo:
DataReader(dataframe=invalid_types_df)
assert "cannot be converted to a numeric type" in str(excinfo.value)
def test_dataset_validation_missing_values():
"""Test validation for missing values."""
# Create DataFrame with NaN values
missing_values_df = pd.DataFrame(
{
"userId": [1, 2, np.nan],
"itemId": [100, 200, 300],
"rating": [4.5, 3.0, 5.0],
"timestamp": [1000, 2000, 3000],
}
)
with pytest.raises(ValueError) as excinfo:
DataReader(dataframe=missing_values_df)
assert "contains missing values" in str(excinfo.value)
def test_get_raw_dataset_not_set():
"""Test get_raw_dataset when not set."""
# Mock pandas.read_csv to raise FileNotFoundError
with patch(
"pandas.read_csv", side_effect=FileNotFoundError("File not found: dummy.csv")
):
with pytest.raises(FileNotFoundError):
reader = DataReader(
filepath_or_buffer="dummy.csv",
sep=",",
names=["userId", "itemId", "rating", "timestamp"],
)
def test_make_consecutive_ids_in_dataset(valid_df):
"""Test making IDs consecutive."""
# Create a DataFrame with non-consecutive IDs
df = pd.DataFrame(
{
"userId": [100, 200, 300, 100],
"itemId": [500, 600, 700, 800],
"rating": [4.5, 3.0, 5.0, 2.0],
"timestamp": [1000, 2000, 3000, 4000],
}
)
reader = DataReader(dataframe=df)
reader.make_consecutive_ids_in_dataset()
# Check that IDs are now consecutive
assert set(reader.dataset["userId"]) == {0, 1, 2}
assert set(reader.dataset["itemId"]) == {0, 1, 2, 3}
# Check mapping
assert reader.get_original_user_id(0) == 100
assert reader.get_original_item_id(0) == 500
def test_make_consecutive_ids_not_set():
"""Test make_consecutive_ids_in_dataset when dataset is not set."""
with patch(
"pandas.read_csv", side_effect=FileNotFoundError("File not found: dummy.csv")
):
with pytest.raises(FileNotFoundError):
reader = DataReader(
filepath_or_buffer="dummy.csv",
sep=",",
names=["userId", "itemId", "rating", "timestamp"],
)
def test_binarize(valid_df):
"""Test binarization of ratings."""
reader = DataReader(dataframe=valid_df)
reader.binarize(binary_threshold=3.5)
# Check that ratings are binarized
expected = [1, 0, 1, 0] # 4.5 > 3.5, 3.0 <= 3.5, 5.0 > 3.5, 2.0 <= 3.5
assert list(reader.dataset["rating"]) == expected
def test_binarize_not_set():
"""Test binarize when dataset is not set."""
with patch(
"pandas.read_csv", side_effect=FileNotFoundError("File not found: dummy.csv")
):
with pytest.raises(FileNotFoundError):
reader = DataReader(
filepath_or_buffer="dummy.csv",
sep=",",
names=["userId", "itemId", "rating", "timestamp"],
)
def test_get_original_user_id_not_set(valid_df):
"""Test get_original_user_id when mapping is not set."""
# Create reader without calling make_consecutive_ids_in_dataset
reader = DataReader(dataframe=valid_df)
with pytest.raises(ValueError) as excinfo:
reader.get_original_user_id(0)
assert "ID mapping not set" in str(excinfo.value)
@@ -0,0 +1,16 @@
import pytest
from pygrex.config import cfg
from pygrex.data_reader.data_reader import DataReader
@pytest.fixture
def data_reader():
return DataReader(**cfg.data.testdata)
def test_import(data_reader):
assert data_reader.num_user == 249
assert data_reader.num_item == 551
assert data_reader.dataset.shape[0] == 1000
assert data_reader.dataset.shape[1] == 4
@@ -0,0 +1,266 @@
import pytest
import pandas as pd
import numpy as np
from pathlib import Path
import tempfile
import shutil
from unittest.mock import patch
from pygrex.data_reader import DataReader, GroupInteractionHandler
@pytest.fixture
def test_environment():
"""Set up test environment before each test."""
# Create a temporary directory for test files
test_dir = Path(tempfile.mkdtemp())
# Create test group files
group1_file = test_dir / "group1.txt"
group2_file = test_dir / "group2.txt"
group1_file.write_text("1_2_3\n4_5_6\n7_8_9")
group2_file.write_text("10_11_12\n13_14_15")
# Create test data for DataReader mock
test_data = pd.DataFrame(
{
"userId": [1, 1, 2, 2, 3, 3, 4, 4, 5],
"itemId": [101, 102, 102, 103, 101, 104, 105, 106, 107],
"rating": [5.0, 4.0, 3.0, 5.0, 4.0, 3.0, 5.0, 4.0, 3.0],
}
)
# Initialize the handler
handler = GroupInteractionHandler(test_dir)
yield {
"test_dir": test_dir,
"group1_file": group1_file,
"group2_file": group2_file,
"test_data": test_data,
"handler": handler,
}
# Clean up test environment after each test
shutil.rmtree(test_dir)
def test_init_with_directory(test_environment):
"""Test initialization with a directory path."""
handler = GroupInteractionHandler(test_environment["test_dir"])
assert len(handler.filepath_or_buffer) == 2
file_paths = set(handler.filepath_or_buffer)
assert str(test_environment["group1_file"]) in file_paths
assert str(test_environment["group2_file"]) in file_paths
def test_init_with_file_list(test_environment):
"""Test initialization with a list of file paths."""
file_list = [test_environment["group1_file"], test_environment["group2_file"]]
handler = GroupInteractionHandler(file_list)
assert len(handler.filepath_or_buffer) == 2
assert str(test_environment["group1_file"]) in handler.filepath_or_buffer
assert str(test_environment["group2_file"]) in handler.filepath_or_buffer
def test_init_with_single_file(test_environment):
"""Test initialization with a single file path."""
handler = GroupInteractionHandler(test_environment["group1_file"])
assert len(handler.filepath_or_buffer) == 1
assert handler.filepath_or_buffer[0] == str(test_environment["group1_file"])
def test_get_group_filepath(test_environment):
"""Test retrieving a group file path."""
handler = test_environment["handler"]
path = handler._get_group_filepath("group1")
assert path == str(test_environment["group1_file"].resolve())
with pytest.raises(ValueError) as exc_info:
handler._get_group_filepath("nonexistent")
assert (
str(exc_info.value)
== "Error: No file found containing 'nonexistent' in its name."
)
def test_get_group_filepath_non_existent_file(test_environment):
"""Test _get_group_filepath with a non-existent file."""
non_existent_file = test_environment["test_dir"] / "missing.txt"
handler = GroupInteractionHandler([non_existent_file])
with pytest.raises(ValueError) as exc_info:
handler._get_group_filepath("missing")
assert (
str(exc_info.value)
== f"Error: File does not exist: {non_existent_file.resolve()}"
)
def test_read_groups(test_environment):
"""Test reading groups from a file."""
handler = test_environment["handler"]
groups = handler.read_groups("group1")
assert len(groups) == 3
assert groups == ["1_2_3", "4_5_6", "7_8_9"]
with pytest.raises(ValueError) as exc_info:
handler.read_groups("")
assert str(exc_info.value) == "Groups path not specified in configuration"
def test_parse_group_members(test_environment):
"""Test parsing group members from a group ID string."""
handler = test_environment["handler"]
members = handler.parse_group_members("1_2_3")
assert members == [1, 2, 3]
members = handler.parse_group_members(" 4_5_6 ")
assert members == [4, 5, 6]
def test_get_group_members(test_environment):
"""Test getting group members from string or list."""
handler = test_environment["handler"]
# Test with string
members = handler.get_group_members("1_2_3")
assert members == [1, 2, 3]
# Test with list
members = handler.get_group_members([1, 2, 3])
assert members == [1, 2, 3]
# Test empty string
members = handler.get_group_members("")
assert members == []
# Test invalid string
with pytest.raises(ValueError):
handler.get_group_members("1_a_3")
# Test invalid type
with pytest.raises(TypeError):
handler.get_group_members(123)
@patch("pygrex.data_reader.DataReader", spec=DataReader)
def test_create_modified_dataset(mock_data_reader, test_environment):
"""Test creating a modified dataset."""
mock_reader = mock_data_reader()
mock_reader.dataset = test_environment["test_data"].copy()
mock_reader.get_new_user_id.side_effect = lambda x: x
mock_reader.get_new_item_id.side_effect = lambda x: x
handler = test_environment["handler"]
# Test with DataReader as original_data
result_df = handler.create_modified_dataset(mock_reader, [1, 2], [101, 102])
assert len(result_df) == 6
assert not (
(result_df.userId.isin([1, 2])) & (result_df.itemId.isin([101, 102]))
).any()
# Test with DataFrame and DataReader
result_df = handler.create_modified_dataset(
test_environment["test_data"], [1, 2], [101, 102], mock_reader
)
assert len(result_df) == 6
assert not (
(result_df.userId.isin([1, 2])) & (result_df.itemId.isin([101, 102]))
).any()
# Test invalid input
with pytest.raises(ValueError):
handler.create_modified_dataset(
test_environment["test_data"], [1, 2], [101, 102]
)
@patch("pygrex.data_reader.DataReader", spec=DataReader)
def test_get_rated_items_by_all_group_members(mock_data_reader, test_environment):
"""Test getting items rated by any group member."""
mock_reader = mock_data_reader()
mock_reader.dataset = test_environment["test_data"].copy()
mock_reader.get_new_user_id.side_effect = lambda x: x
mock_reader.get_original_item_id.side_effect = lambda x: x
handler = test_environment["handler"]
group = [1, 2, 3]
rated_items = handler.get_rated_items_by_all_group_members(group, mock_reader)
expected_items = np.array([101, 102, 103, 104])
np.testing.assert_array_equal(np.sort(rated_items), np.sort(expected_items))
@patch("pygrex.data_reader.DataReader", spec=DataReader)
def test_get_common_rated_items(mock_data_reader, test_environment):
"""Test getting items rated by all group members."""
mock_reader = mock_data_reader()
mock_reader.dataset = test_environment["test_data"].copy()
mock_reader.get_new_user_id.side_effect = lambda x: x
mock_reader.get_original_item_id.side_effect = lambda x: x
handler = test_environment["handler"]
# Test group with common item
group = [1, 2]
common_items = handler.get_common_rated_items(group, mock_reader)
np.testing.assert_array_equal(common_items, np.array([102]))
# Test group with no common items
group = [1, 4]
common_items = handler.get_common_rated_items(group, mock_reader)
assert len(common_items) == 0
# Test empty group
group = []
common_items = handler.get_common_rated_items(group, mock_reader)
assert len(common_items) == 0
np.testing.assert_array_equal(common_items, np.array([]))
def test_get_items_for_group_recommendation(test_environment):
"""Test getting items for group recommendation."""
handler = test_environment["handler"]
data = pd.DataFrame(
{"userId": [1, 1, 2, 2, 3], "itemId": [101, 102, 103, 104, 105]}
)
all_items = np.array([101, 102, 103, 104, 105, 106, 107])
group = [1, 2]
result = handler.get_items_for_group_recommendation(data, all_items, group)
expected = np.array([105, 106, 107])
np.testing.assert_array_equal(np.sort(result), np.sort(expected))
@patch("pygrex.data_reader.DataReader", spec=DataReader)
def test_get_group_preferences(mock_data_reader, test_environment):
"""Test getting preferences for group members."""
mock_reader = mock_data_reader()
mock_reader.dataset = test_environment["test_data"].copy()
mock_reader.get_new_user_id.side_effect = lambda x: x
handler = test_environment["handler"]
group = [1, 2]
prefs = handler.get_group_preferences(group, mock_reader)
assert len(prefs) == 4
assert prefs.userId.isin([1, 2]).all()
@patch("pygrex.data_reader.DataReader", spec=DataReader)
def test_get_group_preferences_edge_cases(mock_data_reader, test_environment):
"""Test get_group_preferences with string IDs and empty group."""
mock_reader = mock_data_reader()
mock_reader.dataset = test_environment["test_data"].copy()
mock_reader.get_new_user_id.side_effect = (
lambda x: int(x) if isinstance(x, str) else x
)
handler = test_environment["handler"]
# Test with string IDs
group = ["1", "2"]
prefs = handler.get_group_preferences(group, mock_reader)
assert len(prefs) == 4
assert prefs.userId.isin([1, 2]).all()
# Test with empty group
group = []
prefs = handler.get_group_preferences(group, mock_reader)
assert len(prefs) == 0
assert prefs.empty
+36
View File
@@ -0,0 +1,36 @@
import pytest
import torch
from pygrex.config import cfg
from pygrex.data_reader.data_reader import DataReader
from pygrex.data_reader.user_item_dict import UserItemDict
@pytest.fixture
def setup_data():
data = DataReader(**cfg.data.testdata)
data.make_consecutive_ids_in_dataset()
return data
def test_user_item_matrix(setup_data):
data = setup_data
# Get dimensions for expl_matrix
n_users = data.dataset.userId.nunique()
n_items = data.dataset.itemId.nunique()
# Create a zero-filled expl_matrix
expl_matrix = torch.zeros((n_users, n_items))
# Set expl to False for this test
expl = False
# Initialize UserItemDict with all required parameters
user_dict = UserItemDict(data.dataset, expl_matrix, expl)
# Test that the first rating is correctly stored
x = data.dataset.userId[0]
y = data.dataset.itemId[0]
v = data.dataset.rating[0]
assert user_dict[x][y] == v