Files
py-grex/pygrex/explain/groups/rule_based_group_rec_explainer.py
2026-05-22 10:02:10 +02:00

315 lines
10 KiB
Python

"""Rule-based group recommendation explainer module."""
from typing import Dict, List, Optional, Set, Union
import logging
from pygrex.data_reader.data_reader import DataReader
from pygrex.utils.association_rules import AssociationRules
logger = logging.getLogger(__name__)

# Identifier aliases: items and group members may be keyed by str or int.
MemberId = Union[str, int]
ItemId = Union[str, int]

# Per-member interaction history: member ID -> set of item IDs the member
# has interacted with.
UserHistory = Dict[MemberId, Set[ItemId]]
class RuleBasedGroupRecExplainer:
    """
    A class to explain group recommendations using rule-based methods.

    An item of the recommendation pool counts as explained when at least one
    association rule whose consequents contain that item is supported by the
    group's interaction history. Item IDs coming from the rules and from the
    user histories are compared as strings, so ``1`` and ``"1"`` are treated
    as the same item.
    """

    def __init__(
        self,
        rules: AssociationRules,
        data: DataReader,
        pool_recommendations: Optional[Union[List[ItemId], ItemId]] = None,
        members: Optional[List[MemberId]] = None,
        user_history: Optional[UserHistory] = None,
        min_members_threshold: int = 1,
    ) -> None:
        """
        Initialize the RuleBasedGroupRecExplainer.

        Args:
            rules: Association rules used for explanations. Used like a
                pandas DataFrame with ``antecedents`` and ``consequents``
                columns (``confidence`` / ``support`` are read if present).
            data: Data reader whose ``get_new_item_id`` is applied to each
                recommended item before looking it up in the rules.
            pool_recommendations: A list of item IDs to explain, or a single
                item ID.
            members: A list of member IDs in the group.
            user_history: A dictionary mapping member IDs to sets of item IDs
                they have interacted with.
            min_members_threshold: Minimum number of members whose history
                must contain a rule's full antecedent for the rule to fire.

        Raises:
            ValueError: If min_members_threshold is less than 1.
        """
        if min_members_threshold < 1:
            raise ValueError("min_members_threshold must be at least 1")
        self.rules = rules
        self.members = members or []
        self.min_members_threshold = min_members_threshold
        self.user_history = user_history or {}
        self.data = data
        # Normalize pool_recommendations to always be a list
        self.pool_recommendations = self._normalize_recommendations(
            pool_recommendations
        )

    def _normalize_recommendations(
        self, recommendations: Optional[Union[List[ItemId], ItemId]]
    ) -> List[ItemId]:
        """
        Normalize recommendations input to a list format.

        Args:
            recommendations: Single item ID, sequence of item IDs, or None.

        Returns:
            A new list of item IDs (empty for None).
        """
        if recommendations is None:
            return []
        if isinstance(recommendations, (str, int)):
            return [recommendations]
        # Copy: later mutation of the caller's sequence must not change the
        # pool, and tuples/other iterables should become a real list.
        return list(recommendations)

    @staticmethod
    def _to_str_set(items) -> Set[str]:
        """
        Return the elements of an iterable of item IDs as a set of strings.

        Rule itemsets and user histories may store the same ID as ``str`` or
        ``int``; normalizing both sides to ``str`` makes them comparable.
        """
        return {str(item) for item in items}

    def _member_history_str(self, member: MemberId) -> Set[str]:
        """
        Return a member's interaction history as a set of strings.

        Members missing from ``user_history`` get an empty set.
        """
        return self._to_str_set(self.user_history.get(member, ()))

    def _string_histories(self) -> Dict[MemberId, Set[str]]:
        """
        Map every distinct group member to their stringified history.

        Built once per public call so rule checks do not re-convert every
        member's history for every rule. Duplicate member IDs collapse into a
        single entry, so a member is never counted twice.
        """
        return {member: self._member_history_str(member) for member in self.members}

    def _is_rule_satisfied_by_member(
        self, member: MemberId, antecedent: Set[ItemId]
    ) -> bool:
        """
        Check if a member satisfies the rule's antecedent.

        Args:
            member: The member ID to check.
            antecedent: The set of items that form the rule's antecedent.

        Returns:
            True if the member's history contains all items in the antecedent
            (compared as strings).
        """
        return self._to_str_set(antecedent) <= self._member_history_str(member)

    def _count_satisfied_members(
        self,
        antecedent: Set[ItemId],
        histories: Optional[Dict[MemberId, Set[str]]] = None,
    ) -> int:
        """
        Count how many distinct members satisfy the given antecedent.

        Args:
            antecedent: The set of items that form the rule's antecedent.
            histories: Optional precomputed output of ``_string_histories``;
                computed on demand when omitted.

        Returns:
            Number of distinct members whose history contains every
            antecedent item.
        """
        if histories is None:
            histories = self._string_histories()
        required = self._to_str_set(antecedent)
        return sum(1 for history in histories.values() if required <= history)

    def _find_applicable_rules(self, item_id: ItemId):
        """
        Find rules that have the given item in their consequents.

        The item is first passed through ``self.data.get_new_item_id``
        (presumably mapping a raw ID to the ID space used by the rules —
        confirm against DataReader) and then compared, as a string, with
        every element of each rule's ``consequents``.

        NOTE(review): assumes each ``consequents`` entry is an iterable of
        item IDs such as a frozenset.

        Args:
            item_id: The item ID to find rules for.

        Returns:
            The rows of ``self.rules`` whose consequents contain the item.
        """
        target = str(self.data.get_new_item_id(item_id))  # type: ignore
        mask = self.rules["consequents"].apply(  # type: ignore
            lambda consequents: target in self._to_str_set(consequents)
        )
        return self.rules[mask]  # type: ignore

    def _iter_fired_rules(
        self,
        item_id: ItemId,
        histories: Optional[Dict[MemberId, Set[str]]] = None,
    ):
        """
        Yield ``(rule, satisfied_count)`` for every rule that explains an item.

        A rule qualifies when its consequents contain the item and its whole
        antecedent is contained in the histories of at least
        ``min_members_threshold`` members. Results are produced lazily, so a
        caller that only needs one fired rule can stop early.

        Args:
            item_id: The recommended item to explain.
            histories: Optional precomputed output of ``_string_histories``.
        """
        if histories is None:
            histories = self._string_histories()
        for _, rule in self._find_applicable_rules(item_id).iterrows():
            satisfied = self._count_satisfied_members(rule["antecedents"], histories)
            if satisfied >= self.min_members_threshold:
                yield rule, satisfied

    def find_explanation(self) -> float:
        """
        Generate explanations for the group recommendations based on the rules.

        Returns:
            The fidelity of the explanations, which is the ratio of explained
            recommendations to total recommendations in the pool (0.0 for an
            empty pool).
        """
        if not self.pool_recommendations:
            logger.warning("No recommendations to explain")
            return 0.0
        histories = self._string_histories()
        total_recommendations = len(self.pool_recommendations)
        explained_count = sum(
            1
            for item_id in self.pool_recommendations
            if self._can_explain_item(item_id, histories)
        )
        fidelity = explained_count / total_recommendations
        logger.info(
            "Explained %d/%d recommendations (fidelity: %.3f)",
            explained_count,
            total_recommendations,
            fidelity,
        )
        return fidelity

    def _can_explain_item(
        self,
        item_id: ItemId,
        histories: Optional[Dict[MemberId, Set[str]]] = None,
    ) -> bool:
        """
        Check if an item can be explained by any rule.

        Args:
            item_id: The item ID to check.
            histories: Optional precomputed output of ``_string_histories``.

        Returns:
            True if at least one rule can explain the item.
        """
        if next(self._iter_fired_rules(item_id, histories), None) is None:
            return False
        logger.debug("Rule fired for item %s", item_id)
        return True

    def get_explanation_details(self) -> Dict[ItemId, List[Dict]]:
        """
        Get detailed explanations for each recommendation.

        Returns:
            Dictionary mapping each pooled item ID to a list of fired-rule
            details (antecedent, consequent, number of satisfied members,
            and confidence/support or ``"N/A"`` when the rule lacks them).
        """
        histories = self._string_histories()
        explanations: Dict[ItemId, List[Dict]] = {}
        for item_id in self.pool_recommendations:
            explanations[item_id] = [
                {
                    "antecedent": rule["antecedents"],
                    "consequent": rule["consequents"],
                    "satisfied_members": satisfied_count,
                    "confidence": rule.get("confidence", "N/A"),
                    "support": rule.get("support", "N/A"),
                }
                for rule, satisfied_count in self._iter_fired_rules(item_id, histories)
            ]
        return explanations

    def compute_group_fidelity_advanced(self) -> float:
        """
        Compute group fidelity using advanced conditions.

        A rule explains an item only when both hold:
        - Condition 1: each member of the group has seen at least one item
          from the antecedent;
        - Condition 2: each item in the antecedent has been seen by at least
          one member.

        Returns:
            The fidelity score as a float between 0 and 1 (0.0 when the pool
            or the member list is empty).
        """
        if not self.pool_recommendations:
            logger.warning("No recommendations to explain")
            return 0.0
        if not self.members:
            logger.warning("No group members defined")
            return 0.0
        total_recommendations = len(self.pool_recommendations)
        # Convert member IDs to set for faster lookup and de-duplication
        members_set = set(self.members)
        # Union of every member's history, as strings (rule items are strings)
        all_seen_items: Set[str] = set()
        for member in members_set:
            all_seen_items |= self._member_history_str(member)
        explained_count = sum(
            1
            for item_id in self.pool_recommendations
            if self._can_explain_item_advanced(item_id, members_set, all_seen_items)
        )
        fidelity = explained_count / total_recommendations
        logger.info(
            "Advanced explanation: %d/%d recommendations (fidelity: %.3f)",
            explained_count,
            total_recommendations,
            fidelity,
        )
        return fidelity

    def _can_explain_item_advanced(
        self, item_id: ItemId, members_set: Set[MemberId], all_seen_items: Set[str]
    ) -> bool:
        """
        Check if an item can be explained using advanced conditions.

        Args:
            item_id: The item ID to check.
            members_set: Set of group member IDs.
            all_seen_items: Set of all items (as strings) seen by any member.

        Returns:
            True if the item can be explained by at least one rule satisfying
            both conditions.
        """
        # Stringify each member's history once per item rather than per rule.
        histories = [self._member_history_str(member) for member in members_set]
        for _, rule in self._find_applicable_rules(item_id).iterrows():
            antecedent = self._to_str_set(rule["antecedents"])
            # Condition 1: each member must have seen at least one antecedent item
            cond1 = all(not antecedent.isdisjoint(history) for history in histories)
            # Condition 2: each antecedent item must have been seen by some member
            cond2 = antecedent <= all_seen_items
            if cond1 and cond2:
                logger.debug("Advanced rule fired for item %s", item_id)
                return True
        return False

    def _member_has_antecedent_item(
        self, member: MemberId, antecedent: Set[ItemId]
    ) -> bool:
        """
        Check if a member has seen at least one item from the antecedent.

        Args:
            member: The member ID to check.
            antecedent: The set of items in the rule's antecedent.

        Returns:
            True if the member has seen at least one item from the antecedent
            (compared as strings).
        """
        return not self._to_str_set(antecedent).isdisjoint(
            self._member_history_str(member)
        )