public code v1
This commit is contained in:
@@ -0,0 +1,11 @@
|
||||
from .data_reader import DataReader
|
||||
from .user_item_rating_dataset import UserItemRatingDataset
|
||||
from .group_interaction_handler import GroupInteractionHandler
|
||||
from .user_item_dict import UserItemDict
|
||||
|
||||
__all__ = [
|
||||
"DataReader",
|
||||
"UserItemRatingDataset",
|
||||
"GroupInteractionHandler",
|
||||
"UserItemDict",
|
||||
]
|
||||
@@ -0,0 +1,416 @@
|
||||
from typing import List, Optional, Union, cast
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import warnings
|
||||
|
||||
|
||||
class DataReader:
|
||||
def __init__(
|
||||
self,
|
||||
filepath_or_buffer: Optional[str] = None,
|
||||
sep: Optional[str] = None,
|
||||
names: Optional[List[str]] = None,
|
||||
skiprows: int = 0,
|
||||
dataframe: Optional[pd.DataFrame] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the DataReader with either a DataFrame or file parameters.
|
||||
|
||||
Args:
|
||||
filepath_or_buffer (Optional[str]): Path to the CSV file or buffer.
|
||||
sep (Optional[str]): Separator used in the CSV file.
|
||||
names (Optional[List[str]]): List of column names for the CSV file.
|
||||
skiprows (int, optional): Number of rows to skip in the CSV file. Defaults to 0.
|
||||
dataframe (Optional[pd.DataFrame], optional): A DataFrame to use directly. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: If neither `dataframe` nor valid file parameters are provided.
|
||||
FileNotFoundError: If the file cannot be found when loading from file.
|
||||
pd.errors.ParserError: If the CSV file cannot be parsed when loading from file.
|
||||
|
||||
Note:
|
||||
If `dataframe` is provided, it takes precedence, and file-related parameters
|
||||
are ignored but stored for reference. A warning is issued in this case.
|
||||
The DataFrame must contain columns: 'userId', 'itemId', 'rating', 'timestamp'.
|
||||
"""
|
||||
if dataframe is None and (not filepath_or_buffer or not sep or not names):
|
||||
raise ValueError(
|
||||
"Must provide either a DataFrame or valid file parameters."
|
||||
)
|
||||
|
||||
self.filepath_or_buffer = filepath_or_buffer
|
||||
self.sep = sep
|
||||
self.names = names
|
||||
self.skiprows = skiprows
|
||||
self._dataset = None
|
||||
self._raw_dataset = None
|
||||
self._num_user: Optional[int] = None
|
||||
self._num_item: Optional[int] = None
|
||||
self.original_user_id: Optional[pd.DataFrame] = None
|
||||
self.original_item_id: Optional[pd.DataFrame] = None
|
||||
self.new_user_id: Optional[pd.DataFrame] = None
|
||||
self.new_item_id: Optional[pd.DataFrame] = None
|
||||
|
||||
if dataframe is not None:
|
||||
if any(param is not None for param in [filepath_or_buffer, sep, names]):
|
||||
warnings.warn(
|
||||
"DataFrame provided; file parameters (filepath_or_buffer, sep, names) are ignored.",
|
||||
UserWarning,
|
||||
)
|
||||
self.dataset = dataframe
|
||||
|
||||
elif filepath_or_buffer and sep and names:
|
||||
# Eagerly load data if file parameters are provided
|
||||
try:
|
||||
assert self.filepath_or_buffer is not None
|
||||
|
||||
loaded_df = pd.read_csv(
|
||||
filepath_or_buffer=self.filepath_or_buffer,
|
||||
sep=self.sep,
|
||||
names=self.names,
|
||||
skiprows=self.skiprows,
|
||||
engine="python",
|
||||
)
|
||||
self._raw_dataset = loaded_df.copy()
|
||||
# Use the setter to handle dataset validation and setting _num_user/_num_item
|
||||
self.dataset = loaded_df
|
||||
except FileNotFoundError:
|
||||
raise FileNotFoundError(f"File not found: {self.filepath_or_buffer}")
|
||||
except pd.errors.ParserError as e:
|
||||
raise pd.errors.ParserError(f"Failed to parse CSV: {str(e)}")
|
||||
else:
|
||||
raise ValueError(
|
||||
"Must provide either a DataFrame or valid file parameters."
|
||||
)
|
||||
|
||||
@property
|
||||
def dataset(self) -> pd.DataFrame:
|
||||
"""
|
||||
Get the dataset DataFrame.
|
||||
"""
|
||||
if self._dataset is None:
|
||||
if self._dataset is None:
|
||||
# If it reach here and _dataset is None, it means initialization failed
|
||||
# or an empty DataFrame was set.
|
||||
# This state should ideally not be reached with eager loading if file params were valid.
|
||||
raise ValueError("Dataset is not loaded or is not valid.")
|
||||
return self._dataset
|
||||
|
||||
@dataset.setter
|
||||
def dataset(self, new_data: pd.DataFrame) -> None:
|
||||
"""
|
||||
Set the dataset and compute the number of unique users and items.
|
||||
|
||||
Args:
|
||||
new_data (pd.DataFrame): The new dataset to set.
|
||||
|
||||
Raises:
|
||||
ValueError: If the DataFrame is None, empty, lacks required columns,
|
||||
or contains invalid data types/missing values.
|
||||
"""
|
||||
if new_data is None:
|
||||
raise ValueError("DataFrame cannot be None")
|
||||
if new_data.empty:
|
||||
raise ValueError("DataFrame cannot be empty")
|
||||
|
||||
# Validate data types
|
||||
for col in ["userId", "itemId", "rating"]:
|
||||
if not pd.api.types.is_numeric_dtype(new_data[col]):
|
||||
warnings.warn(
|
||||
f"Column '{col}' is not numeric. Attempting conversion.",
|
||||
UserWarning,
|
||||
)
|
||||
try:
|
||||
new_data[col] = pd.to_numeric(new_data[col])
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Column '{col}' cannot be converted to a numeric type."
|
||||
)
|
||||
|
||||
# Check for missing values in essential columns
|
||||
if new_data[["userId", "itemId", "rating"]].isnull().any().any():
|
||||
raise ValueError(
|
||||
"DataFrame contains missing values in essential columns (userId, itemId, rating)."
|
||||
)
|
||||
|
||||
self._dataset = new_data
|
||||
self._raw_dataset = new_data.copy()
|
||||
self._num_user = int(self._dataset["userId"].nunique())
|
||||
self._num_item = int(self._dataset["itemId"].nunique())
|
||||
# Set the index to userId and itemId for easier access
|
||||
# Reset id mappings as they are now invalid for the new dataset
|
||||
self.original_user_id = None
|
||||
self.original_item_id = None
|
||||
self.new_user_id = None
|
||||
self.new_item_id = None
|
||||
|
||||
def get_raw_dataset(self) -> pd.DataFrame:
|
||||
"""
|
||||
Get the raw dataset as loaded from the file or initially set.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The raw dataset.
|
||||
|
||||
Raises:
|
||||
ValueError: If the raw dataset is not set.
|
||||
"""
|
||||
if self._raw_dataset is None:
|
||||
raise ValueError(
|
||||
"Raw dataset is not set. Load data from file or set a DataFrame first."
|
||||
)
|
||||
return self._raw_dataset
|
||||
|
||||
@staticmethod
|
||||
def _create_id_mapping(column: pd.Series, new_column_name: str) -> pd.DataFrame:
|
||||
"""
|
||||
Create a mapping for consecutive IDs.
|
||||
|
||||
Args:
|
||||
column (pd.Series): The column to map.
|
||||
new_column_name (str): The name of the new column for consecutive IDs.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: A DataFrame with the original and mapped IDs.
|
||||
|
||||
Raises:
|
||||
ValueError: If the column is empty.
|
||||
"""
|
||||
if column.empty:
|
||||
raise ValueError("Cannot create ID mapping for an empty column")
|
||||
unique_values = column.drop_duplicates().reset_index(drop=True)
|
||||
mapping = pd.DataFrame(
|
||||
{column.name: unique_values, new_column_name: np.arange(len(unique_values))}
|
||||
)
|
||||
return mapping
|
||||
|
||||
def make_consecutive_ids_in_dataset(self) -> None:
|
||||
"""
|
||||
Map user and item IDs to consecutive integers starting from 0 in a deterministic way.
|
||||
Modifies the dataset in-place and stores mappings for original and new IDs.
|
||||
"""
|
||||
if self._dataset is None:
|
||||
raise ValueError("Dataset must be loaded or set before mapping IDs")
|
||||
|
||||
dataset = self.dataset.copy()
|
||||
|
||||
# Get unique IDs and SORT them to ensure the mapping is identical every time.
|
||||
sorted_unique_users = sorted(dataset["userId"].unique())
|
||||
sorted_unique_items = sorted(dataset["itemId"].unique())
|
||||
|
||||
# Create user ID mapping from the sorted list
|
||||
user_id_mapping = pd.DataFrame(
|
||||
{
|
||||
"userId": sorted_unique_users,
|
||||
"new_userId": range(len(sorted_unique_users)),
|
||||
}
|
||||
)
|
||||
dataset["userId"] = dataset["userId"].map(
|
||||
user_id_mapping.set_index("userId")["new_userId"]
|
||||
)
|
||||
|
||||
# Create item ID mapping from the sorted list
|
||||
item_id_mapping = pd.DataFrame(
|
||||
{
|
||||
"itemId": sorted_unique_items,
|
||||
"new_itemId": range(len(sorted_unique_items)),
|
||||
}
|
||||
)
|
||||
dataset["itemId"] = dataset["itemId"].map(
|
||||
item_id_mapping.set_index("itemId")["new_itemId"]
|
||||
)
|
||||
|
||||
# Store mappings for lookups
|
||||
self.original_user_id = user_id_mapping.set_index("new_userId")
|
||||
self.original_item_id = item_id_mapping.set_index("new_itemId")
|
||||
self.new_user_id = user_id_mapping.set_index("userId")
|
||||
self.new_item_id = item_id_mapping.set_index("itemId")
|
||||
|
||||
# Update the internal dataset
|
||||
dataset["userId"] = dataset["userId"].astype(int)
|
||||
dataset["itemId"] = dataset["itemId"].astype(int)
|
||||
self._dataset = dataset
|
||||
|
||||
self._num_user = self._dataset["userId"].max() + 1
|
||||
self._num_item = self._dataset["itemId"].max() + 1
|
||||
|
||||
def binarize(
|
||||
self, binary_threshold: float = 1, inplace: bool = True
|
||||
) -> Optional[pd.DataFrame]:
|
||||
"""
|
||||
Binarize ratings into 0 or 1 based on a threshold (implicit feedback).
|
||||
|
||||
Args:
|
||||
binary_threshold (float, optional): Threshold for binarization. Defaults to 1.0.
|
||||
inplace (bool, optional): If True, modify the dataset in-place. If False, return a new DataFrame.
|
||||
Defaults to True.
|
||||
|
||||
Returns:
|
||||
Optional[pd.DataFrame]: The binarized dataset if inplace=False, else None.
|
||||
|
||||
Raises:
|
||||
ValueError: If the dataset is not set or binary_threshold is invalid.
|
||||
|
||||
Example:
|
||||
Ratings [0.5, 2.0, 3.0] with threshold=1.0 -> [0, 1, 1]
|
||||
"""
|
||||
if self._dataset is None:
|
||||
raise ValueError("Dataset must be loaded or set before binarization")
|
||||
if not isinstance(binary_threshold, (int, float)):
|
||||
raise ValueError("binary_threshold must be a number")
|
||||
|
||||
dataset = self._dataset if inplace else self._dataset.copy()
|
||||
dataset["rating"] = (dataset["rating"] > binary_threshold).astype(int)
|
||||
|
||||
if not inplace:
|
||||
return dataset
|
||||
self._dataset = dataset
|
||||
return None
|
||||
|
||||
@property
|
||||
def num_user(self) -> int:
|
||||
"""
|
||||
Get the number of unique users.
|
||||
|
||||
Returns:
|
||||
int: Number of unique users.
|
||||
|
||||
Raises:
|
||||
ValueError: If the dataset is not set.
|
||||
"""
|
||||
if self._num_user is None:
|
||||
raise ValueError("Dataset must be loaded or set to compute num_user")
|
||||
return self._num_user
|
||||
|
||||
@property
|
||||
def num_item(self) -> int:
|
||||
"""
|
||||
Get the number of unique items.
|
||||
|
||||
Returns:
|
||||
int: Number of unique items.
|
||||
|
||||
Raises:
|
||||
ValueError: If the dataset is not set.
|
||||
"""
|
||||
if self._num_item is None:
|
||||
raise ValueError("Dataset must be loaded or set to compute num_item")
|
||||
return self._num_item
|
||||
|
||||
def get_original_user_id(self, u: Union[int, List[int]]) -> Union[int, List[int]]:
|
||||
"""
|
||||
Get the original user ID(s) from the new (consecutive) ID(s).
|
||||
|
||||
Args:
|
||||
u (Union[int, List[int]]): New user ID(s).
|
||||
|
||||
Returns:
|
||||
Union[int, List[int]]: Original user ID(s).
|
||||
|
||||
Raises:
|
||||
ValueError: If ID mapping is not set or if any ID is not found.
|
||||
"""
|
||||
if self.original_user_id is None:
|
||||
raise ValueError(
|
||||
"ID mapping not set. Call make_consecutive_ids_in_dataset first"
|
||||
)
|
||||
try:
|
||||
if isinstance(u, (int, np.integer)):
|
||||
return int(self.original_user_id.loc[u, "userId"]) # type: ignore
|
||||
series = cast(pd.Series, self.original_user_id.loc[u, "userId"])
|
||||
return series.tolist()
|
||||
except KeyError as e:
|
||||
raise ValueError(f"User ID(s) not found: {e}")
|
||||
|
||||
def get_original_item_id(self, i: Union[int, List[int]]) -> Union[int, List[int]]:
|
||||
"""
|
||||
Get the original item ID(s) from the new (consecutive) ID(s).
|
||||
|
||||
Args:
|
||||
i (Union[int, List[int]]): New item ID(s).
|
||||
|
||||
Returns:
|
||||
Union[int, List[int]]: Original item ID(s).
|
||||
|
||||
Raises:
|
||||
ValueError: If ID mapping is not set or if any ID is not found.
|
||||
"""
|
||||
if self.original_item_id is None:
|
||||
raise ValueError(
|
||||
"ID mapping not set. Call make_consecutive_ids_in_dataset first"
|
||||
)
|
||||
try:
|
||||
if isinstance(i, (int, np.integer)):
|
||||
return int(self.original_item_id.loc[i, "itemId"]) # type: ignore
|
||||
|
||||
series = cast(pd.Series, self.original_item_id.loc[i, "itemId"])
|
||||
return series.tolist()
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Item ID(s) not found: {e}")
|
||||
|
||||
def get_new_user_id(
|
||||
self, u: Union[Union[str, int], List[Union[str, int]]]
|
||||
) -> Union[int, List[int]]:
|
||||
"""
|
||||
Get the new (consecutive) user ID(s) from the original ID(s).
|
||||
|
||||
Args:
|
||||
u: Original user ID(s).
|
||||
|
||||
Returns:
|
||||
New user ID(s).
|
||||
|
||||
Raises:
|
||||
ValueError: If ID mapping is not set or if any ID is not found.
|
||||
"""
|
||||
if self.new_user_id is None:
|
||||
raise ValueError(
|
||||
"ID mapping not set. Call make_consecutive_ids_in_dataset first"
|
||||
)
|
||||
try:
|
||||
if isinstance(u, str):
|
||||
u = int(u)
|
||||
return int(self.new_user_id.loc[u, "new_userId"]) # type: ignore
|
||||
if isinstance(u, list) and all(isinstance(x, str) for x in u):
|
||||
u = [int(x) for x in u]
|
||||
series = cast(pd.Series, self.new_user_id.loc[u, "new_userId"])
|
||||
return series.tolist()
|
||||
if isinstance(u, (int, np.integer)):
|
||||
return int(self.new_user_id.loc[u, "new_userId"]) # type: ignore
|
||||
series = cast(pd.Series, self.new_user_id.loc[u, "new_userId"])
|
||||
return series.tolist()
|
||||
except KeyError as e:
|
||||
raise ValueError(f"User ID(s) not found: {e}")
|
||||
|
||||
def get_new_item_id(
|
||||
self, i: Union[Union[str, int], List[Union[str, int]]]
|
||||
) -> Union[int, List[int]]:
|
||||
"""
|
||||
Get the new (consecutive) item ID(s) from the original ID(s).
|
||||
|
||||
Args:
|
||||
i: Original item ID(s).
|
||||
|
||||
Returns:
|
||||
New item ID(s).
|
||||
|
||||
Raises:
|
||||
ValueError: If ID mapping is not set or if any ID is not found.
|
||||
"""
|
||||
if self.new_item_id is None:
|
||||
raise ValueError(
|
||||
"ID mapping not set. Call make_consecutive_ids_in_dataset first"
|
||||
)
|
||||
try:
|
||||
if isinstance(i, str):
|
||||
i = int(i)
|
||||
return int(self.new_item_id.loc[i, "new_itemId"]) # type: ignore
|
||||
if isinstance(i, list) and all(isinstance(x, str) for x in i):
|
||||
i = [int(x) for x in i]
|
||||
series = cast(pd.Series, self.new_item_id.loc[i, "new_itemId"])
|
||||
return series.tolist()
|
||||
if isinstance(i, (int, np.integer)):
|
||||
return int(self.new_item_id.loc[i, "new_itemId"]) # type: ignore
|
||||
series = cast(pd.Series, self.new_item_id.loc[i, "new_itemId"])
|
||||
return series.tolist()
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Item ID(s) not found: {e}")
|
||||
@@ -0,0 +1,289 @@
|
||||
from typing import List, Optional, Union
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
|
||||
from pygrex.data_reader.data_reader import DataReader
|
||||
|
||||
|
||||
class GroupInteractionHandler:
|
||||
def __init__(self, filepath_or_buffer: Union[str, Path, List[Union[str, Path]]]):
|
||||
"""
|
||||
Initialize the GroupInteractionHandler.
|
||||
|
||||
Args:
|
||||
filepath_or_buffer: Path to directory containing group files or list of file paths
|
||||
"""
|
||||
# Convert to Path objects
|
||||
if isinstance(filepath_or_buffer, (str, Path)):
|
||||
path = Path(filepath_or_buffer)
|
||||
# If a single directory path is provided, get all files in it
|
||||
if path.is_dir():
|
||||
self.filepath_or_buffer = [
|
||||
str(file) for file in path.iterdir() if file.is_file()
|
||||
]
|
||||
else:
|
||||
self.filepath_or_buffer = [str(path)]
|
||||
else:
|
||||
# If a list of paths is provided, convert all to Path and then to strings
|
||||
self.filepath_or_buffer = [str(Path(p)) for p in filepath_or_buffer]
|
||||
|
||||
def _get_group_filepath(self, filename: str) -> str:
|
||||
"""
|
||||
Get a specific group file path by matching the filename.
|
||||
|
||||
Args:
|
||||
filename (str): The name of the file to search for.
|
||||
|
||||
Returns:
|
||||
str: The matched file path.
|
||||
|
||||
Raises:
|
||||
ValueError: Error: File does not exist
|
||||
ValueError: No file found containing '{filename}' in its name.
|
||||
"""
|
||||
for path_str in self.filepath_or_buffer:
|
||||
if filename in path_str: # Check if filename is part of the path
|
||||
path = Path(path_str).resolve()
|
||||
if path.exists():
|
||||
return str(path)
|
||||
else:
|
||||
raise ValueError(f"Error: File does not exist: {path}")
|
||||
|
||||
raise ValueError(f"Error: No file found containing '{filename}' in its name.")
|
||||
|
||||
def read_groups(self, filename: str) -> List[str]:
|
||||
"""
|
||||
Method to read group IDs from a specified file.
|
||||
|
||||
Args:
|
||||
filename (str): Name of the file containing group IDs.
|
||||
|
||||
Returns:
|
||||
List[str]: List of group IDs.
|
||||
|
||||
Raises:
|
||||
ValueError: If groups path is not specified in configuration
|
||||
"""
|
||||
if not filename:
|
||||
raise ValueError("Groups path not specified in configuration")
|
||||
|
||||
filepath = self._get_group_filepath(filename)
|
||||
|
||||
# Use Path for file reading
|
||||
path = Path(filepath)
|
||||
return [line.strip() for line in path.read_text().splitlines()]
|
||||
|
||||
def parse_group_members(self, group: str) -> List[int]:
|
||||
"""
|
||||
Parse group ID to get member IDs.
|
||||
|
||||
Args:
|
||||
group: Group ID string
|
||||
|
||||
Returns:
|
||||
List of member IDs
|
||||
"""
|
||||
group = group.strip()
|
||||
members = group.split("_")
|
||||
return [int(m) for m in members]
|
||||
|
||||
def get_group_members(self, group: Union[List[Union[int, str]], str]) -> List[int]:
|
||||
"""
|
||||
Get group members from a group ID string or list.
|
||||
|
||||
Args:
|
||||
group: Group ID string in format "id1_id2_id3" or list of IDs
|
||||
|
||||
Returns:
|
||||
List of member IDs as integers
|
||||
|
||||
Raises:
|
||||
ValueError: If any member ID cannot be converted to an integer
|
||||
TypeError: If group is neither a string nor a list
|
||||
"""
|
||||
|
||||
if isinstance(group, list):
|
||||
return [int(member) for member in group]
|
||||
|
||||
if not isinstance(group, str):
|
||||
raise TypeError(f"Expected string or list, got {type(group).__name__}")
|
||||
|
||||
group = group.strip()
|
||||
if not group:
|
||||
return []
|
||||
|
||||
try:
|
||||
return [int(member) for member in group.split("_")]
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid member ID in group: {str(e)}")
|
||||
|
||||
def create_modified_dataset(
|
||||
self,
|
||||
original_data: Union[pd.DataFrame, DataReader],
|
||||
group_ids: List[Union[int, str]],
|
||||
item_ids: List[Union[int, str]],
|
||||
data: Optional[DataReader] = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Creates a modified dataset by removing interactions between specified groups and items.
|
||||
|
||||
Args:
|
||||
original_data: Either a pandas DataFrame or a DataReader object containing the dataset
|
||||
group_ids: List of group IDs to consider for removal
|
||||
item_ids: List of item IDs to consider for removal
|
||||
data: Optional DataReader object if original_data is a DataFrame
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: A pandas DataFrame with the specified interactions removed
|
||||
|
||||
Raises:
|
||||
ValueError: If input data types are incorrect
|
||||
"""
|
||||
# Determine the data source and target dataset
|
||||
if isinstance(original_data, DataReader):
|
||||
data_reader = original_data
|
||||
dataset = original_data.dataset
|
||||
elif isinstance(original_data, pd.DataFrame) and isinstance(data, DataReader):
|
||||
data_reader = data
|
||||
dataset = original_data
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either original_data must be a DataReader or data must be provided as a DataReader"
|
||||
)
|
||||
|
||||
# Convert IDs to internal representation
|
||||
new_group_ids = [
|
||||
data_reader.get_new_user_id(
|
||||
int(g) if isinstance(g, (int, np.integer)) else g
|
||||
)
|
||||
for g in group_ids
|
||||
]
|
||||
|
||||
new_item_ids = [
|
||||
data_reader.get_new_item_id(
|
||||
int(i) if isinstance(i, (int, np.integer)) else i
|
||||
)
|
||||
for i in item_ids
|
||||
]
|
||||
|
||||
# Create mask for rows to keep (inverse of rows to drop)
|
||||
mask = ~(dataset.itemId.isin(new_item_ids) & dataset.userId.isin(new_group_ids))
|
||||
|
||||
return dataset[mask]
|
||||
|
||||
def get_rated_items_by_all_group_members(
|
||||
self, group: List[Union[int, str]], original_data: DataReader
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Get all items rated by any member of the group.
|
||||
|
||||
Args:
|
||||
group: List of user IDs
|
||||
original_data: Data object with mapping methods
|
||||
|
||||
Returns:
|
||||
np.ndarray: Array of original item IDs rated by any group member
|
||||
"""
|
||||
# Convert group members to new user IDs
|
||||
new_group = [
|
||||
original_data.get_new_user_id(
|
||||
int(g) if isinstance(g, (int, np.integer)) else g
|
||||
)
|
||||
for g in group
|
||||
]
|
||||
|
||||
# Get unique items rated by any group member
|
||||
group_items = original_data.dataset[
|
||||
original_data.dataset.userId.isin(new_group)
|
||||
]["itemId"].unique()
|
||||
|
||||
# Convert back to original item IDs
|
||||
original_ids = original_data.get_original_item_id(group_items.tolist())
|
||||
return np.array(original_ids)
|
||||
|
||||
def get_common_rated_items(
|
||||
self, group: List[Union[int, str]], original_data: DataReader
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Get items rated by all members of the group (intersection of rated items).
|
||||
|
||||
Args:
|
||||
group: List of user IDs
|
||||
original_data: DataReader object with mapping methods
|
||||
|
||||
Returns:
|
||||
np.ndarray: Array of original item IDs rated by all group members
|
||||
"""
|
||||
# Convert group members to new user IDs
|
||||
new_group = [
|
||||
original_data.get_new_user_id(
|
||||
int(g) if isinstance(g, (int, np.integer)) else g
|
||||
)
|
||||
for g in group
|
||||
]
|
||||
|
||||
# Get items rated by each group member
|
||||
rated_items_per_member = []
|
||||
for user_id in new_group:
|
||||
user_items = original_data.dataset[original_data.dataset.userId == user_id][
|
||||
"itemId"
|
||||
].unique()
|
||||
rated_items_per_member.append(set(user_items))
|
||||
|
||||
# Find intersection of all rated items
|
||||
if rated_items_per_member:
|
||||
common_items = set.intersection(*rated_items_per_member)
|
||||
common_items_array = np.array(list(common_items))
|
||||
# Convert back to original item IDs
|
||||
original_ids = original_data.get_original_item_id(
|
||||
common_items_array.tolist()
|
||||
)
|
||||
return np.array(original_ids)
|
||||
else:
|
||||
return np.array([])
|
||||
|
||||
def get_items_for_group_recommendation(
|
||||
self, data: pd.DataFrame, item_ids: np.ndarray, group: List[int]
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Get items for group recommendation (those not interacted with by any group member).
|
||||
|
||||
Args:
|
||||
data: DataFrame with interaction data
|
||||
item_ids: Array of all item IDs
|
||||
group: List of group member IDs
|
||||
|
||||
Returns:
|
||||
Array of item IDs not interacted with by any group member
|
||||
"""
|
||||
item_ids_group = data.loc[data.userId.isin(group), "itemId"]
|
||||
return np.setdiff1d(item_ids, item_ids_group)
|
||||
|
||||
def get_group_preferences(
|
||||
self, group: List[Union[int, str]], data_reader: DataReader
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Get all preferences (ratings) by all members of the group.
|
||||
|
||||
Args:
|
||||
group: List of user IDs
|
||||
data_reader: DataReader object with the dataset
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: DataFrame containing all preferences by group members
|
||||
"""
|
||||
# Convert group members to new user IDs
|
||||
new_group = [
|
||||
data_reader.get_new_user_id(
|
||||
int(g) if isinstance(g, (int, np.integer)) else g
|
||||
)
|
||||
for g in group
|
||||
]
|
||||
|
||||
# Get all interactions by group members
|
||||
group_preferences = data_reader.dataset[
|
||||
data_reader.dataset.userId.isin(new_group)
|
||||
].copy()
|
||||
|
||||
return group_preferences
|
||||
@@ -0,0 +1,36 @@
|
||||
from torch.utils.data import Dataset
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class UserItemDict(Dataset):
|
||||
"""Wrapper, convert <user, item, rating> Tensor into Pytorch Dataset"""
|
||||
|
||||
def __init__(self, data, expl_matrix, expl):
|
||||
"""
|
||||
args:
|
||||
|
||||
target_tensor: torch.Tensor, the corresponding rating for <user, item> pair
|
||||
"""
|
||||
|
||||
grp_data = data.groupby('userId')
|
||||
self.users_dict = dict()
|
||||
for userId, itemId_rating in grp_data:
|
||||
self.users_dict[userId] = {'items': list(itemId_rating.itemId),
|
||||
'rating': list(itemId_rating.rating)}
|
||||
self.n_items = data.itemId.nunique()
|
||||
self.n_users = data.userId.nunique()
|
||||
self.expl_matrix = expl_matrix
|
||||
self.expl = expl
|
||||
|
||||
def __getitem__(self, index):
|
||||
ratings = np.zeros(self.n_items)
|
||||
ratings[self.users_dict[index]['items']] = self.users_dict[index]['rating']
|
||||
if self.expl:
|
||||
return torch.tensor(ratings) + self.expl_matrix[index, :]
|
||||
else:
|
||||
return torch.tensor(ratings)
|
||||
|
||||
def __len__(self):
|
||||
return self.n_users
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class UserItemRatingDataset(Dataset):
|
||||
"""Wrapper, convert <user, item, rating> Tensor into Pytorch Dataset"""
|
||||
|
||||
def __init__(self, user_tensor, item_tensor, target_tensor):
|
||||
"""
|
||||
args:
|
||||
|
||||
target_tensor: torch.Tensor, the corresponding rating for <user, item> pair
|
||||
"""
|
||||
self.user_tensor = user_tensor
|
||||
self.item_tensor = item_tensor
|
||||
self.target_tensor = target_tensor
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.user_tensor[index], self.item_tensor[index], self.target_tensor[index]
|
||||
|
||||
def __len__(self):
|
||||
return self.user_tensor.size(0)
|
||||
Reference in New Issue
Block a user