public code v1
This commit is contained in:
@@ -0,0 +1,202 @@
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import tempfile
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
from pygrex.data_reader.data_reader import DataReader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def valid_df():
|
||||
"""Fixture providing a valid test DataFrame."""
|
||||
return pd.DataFrame(
|
||||
{
|
||||
"userId": [1, 2, 3, 1],
|
||||
"itemId": [100, 200, 300, 400],
|
||||
"rating": [4.5, 3.0, 5.0, 2.0],
|
||||
"timestamp": [1000, 2000, 3000, 4000],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_csv_path(valid_df):
|
||||
"""Fixture providing a temporary CSV file with valid data."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
test_csv_path = os.path.join(temp_dir, "test_data.csv")
|
||||
valid_df.to_csv(test_csv_path, sep=",", index=False)
|
||||
yield test_csv_path
|
||||
|
||||
|
||||
def test_init_with_dataframe(valid_df):
|
||||
"""Test initialization with a DataFrame."""
|
||||
reader = DataReader(dataframe=valid_df)
|
||||
pd.testing.assert_frame_equal(reader.dataset, valid_df)
|
||||
assert reader.num_user == 3
|
||||
assert reader.num_item == 4
|
||||
|
||||
|
||||
def test_init_with_filepath(test_csv_path):
|
||||
"""Test initialization with filepath."""
|
||||
reader = DataReader(
|
||||
filepath_or_buffer=test_csv_path,
|
||||
sep=",",
|
||||
names=["userId", "itemId", "rating", "timestamp"],
|
||||
skiprows=1, # Skip header
|
||||
)
|
||||
# Check that dataset was loaded correctly
|
||||
assert reader.num_user == 3
|
||||
assert reader.num_item == 4
|
||||
|
||||
|
||||
def test_dataset_loading_file_not_found():
|
||||
"""Test exception when file is not found."""
|
||||
with pytest.raises(FileNotFoundError):
|
||||
DataReader(
|
||||
filepath_or_buffer="nonexistent.csv",
|
||||
sep=",",
|
||||
names=["userId", "itemId", "rating", "timestamp"],
|
||||
)
|
||||
|
||||
|
||||
def test_dataset_validation_invalid_columns():
|
||||
"""Test validation for invalid columns."""
|
||||
# Create a DataFrame with non-numeric userId
|
||||
invalid_df = pd.DataFrame(
|
||||
{
|
||||
"userId": ["user1", "user2", "user3"],
|
||||
"itemId": [100, 200, 300],
|
||||
"rating": [4.5, 3.0, 5.0],
|
||||
"timestamp": [1000, 2000, 3000],
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
DataReader(dataframe=invalid_df)
|
||||
assert "cannot be converted to a numeric type" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_dataset_validation_empty():
|
||||
"""Test validation for empty DataFrame."""
|
||||
empty_df = pd.DataFrame(columns=["userId", "itemId", "rating", "timestamp"])
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
DataReader(dataframe=empty_df)
|
||||
assert "DataFrame cannot be empty" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_dataset_validation_invalid_types():
|
||||
"""Test validation for invalid data types."""
|
||||
# Create DataFrame with invalid rating (string)
|
||||
invalid_types_df = pd.DataFrame(
|
||||
{
|
||||
"userId": [1, 2, 3],
|
||||
"itemId": [100, 200, 300],
|
||||
"rating": ["high", "medium", "low"],
|
||||
"timestamp": [1000, 2000, 3000],
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
DataReader(dataframe=invalid_types_df)
|
||||
assert "cannot be converted to a numeric type" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_dataset_validation_missing_values():
|
||||
"""Test validation for missing values."""
|
||||
# Create DataFrame with NaN values
|
||||
missing_values_df = pd.DataFrame(
|
||||
{
|
||||
"userId": [1, 2, np.nan],
|
||||
"itemId": [100, 200, 300],
|
||||
"rating": [4.5, 3.0, 5.0],
|
||||
"timestamp": [1000, 2000, 3000],
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
DataReader(dataframe=missing_values_df)
|
||||
assert "contains missing values" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_get_raw_dataset_not_set():
|
||||
"""Test get_raw_dataset when not set."""
|
||||
# Mock pandas.read_csv to raise FileNotFoundError
|
||||
with patch(
|
||||
"pandas.read_csv", side_effect=FileNotFoundError("File not found: dummy.csv")
|
||||
):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
reader = DataReader(
|
||||
filepath_or_buffer="dummy.csv",
|
||||
sep=",",
|
||||
names=["userId", "itemId", "rating", "timestamp"],
|
||||
)
|
||||
|
||||
|
||||
def test_make_consecutive_ids_in_dataset(valid_df):
|
||||
"""Test making IDs consecutive."""
|
||||
# Create a DataFrame with non-consecutive IDs
|
||||
df = pd.DataFrame(
|
||||
{
|
||||
"userId": [100, 200, 300, 100],
|
||||
"itemId": [500, 600, 700, 800],
|
||||
"rating": [4.5, 3.0, 5.0, 2.0],
|
||||
"timestamp": [1000, 2000, 3000, 4000],
|
||||
}
|
||||
)
|
||||
|
||||
reader = DataReader(dataframe=df)
|
||||
reader.make_consecutive_ids_in_dataset()
|
||||
|
||||
# Check that IDs are now consecutive
|
||||
assert set(reader.dataset["userId"]) == {0, 1, 2}
|
||||
assert set(reader.dataset["itemId"]) == {0, 1, 2, 3}
|
||||
|
||||
# Check mapping
|
||||
assert reader.get_original_user_id(0) == 100
|
||||
assert reader.get_original_item_id(0) == 500
|
||||
|
||||
|
||||
def test_make_consecutive_ids_not_set():
|
||||
"""Test make_consecutive_ids_in_dataset when dataset is not set."""
|
||||
with patch(
|
||||
"pandas.read_csv", side_effect=FileNotFoundError("File not found: dummy.csv")
|
||||
):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
reader = DataReader(
|
||||
filepath_or_buffer="dummy.csv",
|
||||
sep=",",
|
||||
names=["userId", "itemId", "rating", "timestamp"],
|
||||
)
|
||||
|
||||
|
||||
def test_binarize(valid_df):
|
||||
"""Test binarization of ratings."""
|
||||
reader = DataReader(dataframe=valid_df)
|
||||
reader.binarize(binary_threshold=3.5)
|
||||
|
||||
# Check that ratings are binarized
|
||||
expected = [1, 0, 1, 0] # 4.5 > 3.5, 3.0 <= 3.5, 5.0 > 3.5, 2.0 <= 3.5
|
||||
assert list(reader.dataset["rating"]) == expected
|
||||
|
||||
|
||||
def test_binarize_not_set():
|
||||
"""Test binarize when dataset is not set."""
|
||||
with patch(
|
||||
"pandas.read_csv", side_effect=FileNotFoundError("File not found: dummy.csv")
|
||||
):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
reader = DataReader(
|
||||
filepath_or_buffer="dummy.csv",
|
||||
sep=",",
|
||||
names=["userId", "itemId", "rating", "timestamp"],
|
||||
)
|
||||
|
||||
|
||||
def test_get_original_user_id_not_set(valid_df):
|
||||
"""Test get_original_user_id when mapping is not set."""
|
||||
# Create reader without calling make_consecutive_ids_in_dataset
|
||||
reader = DataReader(dataframe=valid_df)
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
reader.get_original_user_id(0)
|
||||
assert "ID mapping not set" in str(excinfo.value)
|
||||
Reference in New Issue
Block a user