# Listing metadata from the source viewer: 315 lines, 10 KiB, Python
"""Rule-based group recommendation explainer module."""
|
|
|
|
from typing import Dict, List, Optional, Set, Union
|
|
import logging
|
|
|
|
from pygrex.data_reader.data_reader import DataReader
|
|
from pygrex.utils.association_rules import AssociationRules
|
|
|
|
# Type aliases for better readability.
# Item and member identifiers may be given as strings or integers.
ItemId = Union[str, int]
MemberId = Union[str, int]
# Maps each member ID to the set of item IDs that member has interacted with.
UserHistory = Dict[MemberId, Set[ItemId]]

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RuleBasedGroupRecExplainer:
    """
    Explain group recommendations using association rules.

    A recommended item counts as "explained" when at least one rule lists the
    item among its consequents and the group's interaction history matches the
    rule's antecedent. Two matching policies are provided:

    - ``find_explanation`` / ``get_explanation_details``: at least
      ``min_members_threshold`` distinct members must each have the whole
      antecedent in their history.
    - ``compute_group_fidelity_advanced``: every member must have seen at
      least one antecedent item, and every antecedent item must have been
      seen by at least one member.

    Item IDs are compared in string form on both sides (history and rule
    antecedent), so integer and string IDs are interchangeable.

    NOTE(review): ``rules`` is used like a pandas DataFrame (column indexing
    with "antecedents"/"consequents", ``.apply``, boolean-mask selection,
    ``.iterrows()``) -- confirm that ``AssociationRules`` supports this.
    """

    def __init__(
        self,
        rules: AssociationRules,
        data: DataReader,
        pool_recommendations: Optional[Union[List[ItemId], ItemId]] = None,
        members: Optional[List[MemberId]] = None,
        user_history: Optional[UserHistory] = None,
        min_members_threshold: int = 1,
    ) -> None:
        """
        Initialize the RuleBasedGroupRecExplainer.

        Args:
            rules: The association rules used for explanations.
            data: Data reader whose ``get_new_item_id`` is called to translate
                each recommended item ID before it is looked up in the rules.
            pool_recommendations: A list of item IDs to explain, or a single item ID.
            members: A list of member IDs in the group.
            user_history: A dictionary mapping member IDs to sets of item IDs
                they have interacted with.
            min_members_threshold: Minimum number of members that must satisfy
                the rule for it to be considered valid.

        Raises:
            ValueError: If min_members_threshold is less than 1.
        """
        if min_members_threshold < 1:
            raise ValueError("min_members_threshold must be at least 1")

        self.rules = rules
        self.data = data
        self.members = members or []
        self.user_history = user_history or {}
        self.min_members_threshold = min_members_threshold

        # Normalize pool_recommendations to always be a list
        self.pool_recommendations = self._normalize_recommendations(
            pool_recommendations
        )

    def _normalize_recommendations(
        self, recommendations: Optional[Union[List[ItemId], ItemId]]
    ) -> List[ItemId]:
        """
        Normalize recommendations input to a list format.

        Args:
            recommendations: Single item ID, list of item IDs, or None.

        Returns:
            A new list of item IDs (empty when the input is None).
        """
        if recommendations is None:
            return []

        # A str is iterable, so scalars must be wrapped before the generic case.
        if isinstance(recommendations, (str, int)):
            return [recommendations]

        # Copy so later changes to the caller's collection do not leak in, and
        # so tuples/sets are accepted while the return type stays a list.
        return list(recommendations)

    def _member_history_as_str(self, member: MemberId) -> Set[str]:
        """
        Return a member's interaction history with every item ID stringified.

        Args:
            member: The member ID to look up.

        Returns:
            Set of item IDs as strings; empty if the member has no history.
        """
        return {str(item) for item in self.user_history.get(member, set())}

    def _is_rule_satisfied_by_member(
        self, member: MemberId, antecedent: Set[ItemId]
    ) -> bool:
        """
        Check if a member satisfies the rule's antecedent.

        Args:
            member: The member ID to check.
            antecedent: The set of items that form the rule's antecedent.

        Returns:
            True if the member's history contains all items in the antecedent.
            An empty antecedent is satisfied by every member.
        """
        # Stringify both sides so int and str item IDs compare equal.
        required = {str(item) for item in antecedent}
        return required.issubset(self._member_history_as_str(member))

    def _count_satisfied_members(self, antecedent: Set[ItemId]) -> int:
        """
        Count how many distinct members satisfy the given antecedent.

        Args:
            antecedent: The set of items that form the rule's antecedent.

        Returns:
            Number of distinct members whose history satisfies the antecedent.
        """
        # De-duplicate so a member listed twice cannot satisfy the threshold
        # alone; the advanced fidelity path de-duplicates the same way.
        return sum(
            1
            for member in set(self.members)
            if self._is_rule_satisfied_by_member(member, antecedent)
        )

    def _find_applicable_rules(self, item_id: ItemId):
        """
        Find rules that have the given item in their consequents.

        Args:
            item_id: The item ID to find rules for.

        Returns:
            The subset of ``self.rules`` whose "consequents" contain the
            stringified, translated item ID.
        """
        # Translate once, outside the per-row predicate.
        target = str(self.data.get_new_item_id(item_id))  # type: ignore

        mask = self.rules["consequents"].apply(  # type: ignore
            lambda consequents: target in consequents
        )
        return self.rules[mask]  # type: ignore

    def _iter_fired_rules(self, item_id: ItemId):
        """
        Yield ``(rule, satisfied_count)`` for each applicable rule that fires.

        A rule fires when the number of distinct members satisfying its
        antecedent is at least ``min_members_threshold``.

        Args:
            item_id: The item ID whose applicable rules are examined.
        """
        for _, rule in self._find_applicable_rules(item_id).iterrows():
            satisfied_count = self._count_satisfied_members(rule["antecedents"])
            if satisfied_count >= self.min_members_threshold:
                yield rule, satisfied_count

    def find_explanation(self) -> float:
        """
        Generate explanations for the group recommendations based on the rules.

        Returns:
            The fidelity of the explanations, which is the ratio of explained
            recommendations to total recommendations in the pool. Returns 0.0
            when the pool is empty.
        """
        if not self.pool_recommendations:
            logger.warning("No recommendations to explain")
            return 0.0

        total_recommendations = len(self.pool_recommendations)
        explained_count = sum(
            1
            for item_id in self.pool_recommendations
            if self._can_explain_item(item_id)
        )

        fidelity = explained_count / total_recommendations
        logger.info(
            "Explained %d/%d recommendations (fidelity: %.3f)",
            explained_count,
            total_recommendations,
            fidelity,
        )

        return fidelity

    def _can_explain_item(self, item_id: ItemId) -> bool:
        """
        Check if an item can be explained by any rule.

        Args:
            item_id: The item ID to check.

        Returns:
            True if at least one rule can explain the item.
        """
        # Only the first firing rule is needed, so stop at the first hit.
        first_hit = next(self._iter_fired_rules(item_id), None)
        if first_hit is None:
            return False

        logger.debug("Rule fired for item %s", item_id)
        return True

    def get_explanation_details(self) -> Dict[ItemId, List[Dict]]:
        """
        Get detailed explanations for each recommendation.

        Returns:
            Dictionary mapping item IDs to lists of applicable rule details.
            Each detail holds the rule's antecedent and consequent, the number
            of members satisfying it, and its confidence and support ("N/A"
            when the rule row lacks that field). Items with no firing rule map
            to an empty list.
        """
        explanations = {}

        for item_id in self.pool_recommendations:
            explanations[item_id] = [
                {
                    "antecedent": rule["antecedents"],
                    "consequent": rule["consequents"],
                    "satisfied_members": satisfied_count,
                    "confidence": rule.get("confidence", "N/A"),
                    "support": rule.get("support", "N/A"),
                }
                for rule, satisfied_count in self._iter_fired_rules(item_id)
            ]

        return explanations

    def compute_group_fidelity_advanced(self) -> float:
        """
        Compute group fidelity using advanced conditions.

        This method implements a more sophisticated fidelity calculation where:
        - Condition 1: Each member of the group must have seen at least one item from the antecedent
        - Condition 2: Each item in the antecedent must have been seen by at least one member

        Returns:
            The fidelity score as a float between 0 and 1. Returns 0.0 when
            there are no recommendations or no group members.
        """
        if not self.pool_recommendations:
            logger.warning("No recommendations to explain")
            return 0.0

        if not self.members:
            logger.warning("No group members defined")
            return 0.0

        total_recommendations = len(self.pool_recommendations)

        # Convert member IDs to set for faster lookup
        members_set = set(self.members)

        # Get all items seen by any group member (as strings, like the rules)
        all_seen_items: Set[str] = set()
        for member in members_set:
            all_seen_items.update(self._member_history_as_str(member))

        explained_count = sum(
            1
            for item_id in self.pool_recommendations
            if self._can_explain_item_advanced(item_id, members_set, all_seen_items)
        )

        fidelity = explained_count / total_recommendations
        logger.info(
            "Advanced explanation: %d/%d recommendations (fidelity: %.3f)",
            explained_count,
            total_recommendations,
            fidelity,
        )

        return fidelity

    def _can_explain_item_advanced(
        self, item_id: ItemId, members_set: Set[MemberId], all_seen_items: Set[str]
    ) -> bool:
        """
        Check if an item can be explained using advanced conditions.

        Args:
            item_id: The item ID to check.
            members_set: Set of group member IDs.
            all_seen_items: Set of all items (as strings) seen by any group member.

        Returns:
            True if the item can be explained by at least one rule satisfying both conditions.
        """
        for _, rule in self._find_applicable_rules(item_id).iterrows():
            antecedent = rule["antecedents"]

            # Condition 2: Each item in the antecedent must have been seen by
            # at least one member. Checked first because it is one set test.
            antecedent_str = {str(item) for item in antecedent}
            if not antecedent_str.issubset(all_seen_items):
                continue

            # Condition 1: Each member must have seen at least one item from the antecedent
            if all(
                self._member_has_antecedent_item(member, antecedent)
                for member in members_set
            ):
                logger.debug("Advanced rule fired for item %s", item_id)
                return True

        return False

    def _member_has_antecedent_item(
        self, member: MemberId, antecedent: Set[ItemId]
    ) -> bool:
        """
        Check if a member has seen at least one item from the antecedent.

        Args:
            member: The member ID to check.
            antecedent: The set of items in the rule's antecedent.

        Returns:
            True if the member has seen at least one item from the antecedent.
            Always False for an empty antecedent.
        """
        # Stringify both sides so int and str item IDs compare equal.
        wanted = {str(item) for item in antecedent}
        return not wanted.isdisjoint(self._member_history_as_str(member))
|