Files
py-grex/pygrex/explain/groups/sliding_window_explainer.py
2026-05-22 10:02:10 +02:00

435 lines
15 KiB
Python

import itertools
from typing import Dict, List, Sequence, Union
from pygrex.data_reader import DataReader, GroupInteractionHandler
from pygrex.models import RecommenderModel
from pygrex.recommender import GroupRecommender
from pygrex.utils import SlidingWindowRanker, SlidingWindow, AggregationStrategy
class SlidingWindowExplainer:
    """
    Find counterfactual explanations for group recommendations with a sliding window.

    The explainer identifies which items, if removed from the group's interaction
    history, would cause a specific target item to no longer appear in the group
    recommendations. Candidate items are ranked, scanned window by window, and the
    first window whose removal drops the target item is searched for a minimal
    subset with the same effect.

    Reference:
        Stratigi, M., Bikakis, N., Stefanidis, K.: Counterfactual explanations for
        group recommendations. In: Proceedings of the 27th International Workshop
        on Design, Optimization, Languages and Analytical Processing of Big Data
        (DOLAP 2025).
    """

    def __init__(
        self,
        config,
        data: DataReader,
        group_handler: GroupInteractionHandler,
        members: List[Union[str, int]],
        target_item: Union[str, int],
        model: RecommenderModel,
        aggregation_strategy: AggregationStrategy = AggregationStrategy.AVG_PREDICTIONS,
        window_size=3,
    ):
        """
        Initialize the SlidingWindowExplainer.

        Args:
            config: Configuration object with model parameters
            data: DataReader object containing the dataset
            group_handler: Object that handles group data modifications
            members: List of user IDs in the group
            target_item: The item ID for which explanation is sought
            model: Recommender model to use for predictions
            aggregation_strategy: Strategy to aggregate individual recommendations
            window_size: Size of the sliding window
        """
        self.cfg = config
        self.data = data
        self.group_handler = group_handler
        self.members = members
        self.target_item = target_item
        self.model = model
        self.aggregation_strategy = aggregation_strategy
        self.window_size = window_size

        # Defined up front so the attribute exists even if set_sliding_window()
        # is never called. find_explanation() builds its own local window.
        self.sliding_window = None

        # Results tracking
        # Maps the call number at which an explanation was found to its record.
        self.explanations_found: Dict[int, Dict] = {}
        # Number of retrain-and-recommend evaluations spent in the current run.
        self.calls = 0
        # Upper bound on evaluations per find_explanation() run.
        self.max_calls = 1000
        # Per-item ranking metrics, filled by find_explanation() or set_item_metrics().
        self.item_metrics: Dict[Union[str, int], Dict[str, float]] = {}

    def set_sliding_window(self, sliding_window):
        """Store a sliding window object on the explainer."""
        self.sliding_window = sliding_window

    def set_item_metrics(self, metrics: Dict[Union[str, int], Dict[str, float]]):
        """Store the pre-calculated metric scores for all items."""
        self.item_metrics = metrics

    def find_explanation(
        self,
        items_rated_by_group: List[Union[str, int]],
        group_predictions: Dict,
        top_recommendation: Union[str, int],
        ranking_weights: Dict[str, float],
    ) -> Dict[int, Dict]:
        """
        Find counterfactual explanations using the full, encapsulated process.

        Every run starts from a clean state: the call counter and the
        previously recorded explanations are both reset.

        Args:
            items_rated_by_group: All items rated by any member of the group.
            group_predictions: The original individual predictions from the recommender.
            top_recommendation: The original top recommended item.
            ranking_weights: The weights from the UI for each ranking component.

        Returns:
            A dictionary of found explanations, including their justification
            metrics. It is empty when no explanation was recorded.
        """
        self.calls = 0
        # Rebind instead of clearing in place so a dict returned by an earlier
        # run is not mutated behind the caller's back. Without the reset, stale
        # entries would be returned again (keys are call numbers, which restart
        # at zero and could also collide).
        self.explanations_found = {}

        ranker = SlidingWindowRanker(config={})
        ranker.set_group_recommender_values(group_predictions, top_recommendation)
        ranked_items, self.item_metrics = ranker.generate_ranked_items(
            all_rated_items=items_rated_by_group,
            data=self.data,
            group_members=self.members,
            component_weights=ranking_weights,
        )
        sliding_window = SlidingWindow(
            sequence=ranked_items, window_size=self.window_size
        )

        while self.calls < self.max_calls:
            big_window = sliding_window.get_next_window()
            if big_window is None:
                # The ranked sequence is exhausted.
                break

            self.calls += 1
            # Test if removing this window affects recommendations
            if self._test_window_removal(big_window, self.target_item):
                # A counterfactual window has been found: narrow it down to a
                # minimal subset and stop scanning further windows.
                self._find_minimal_subset(big_window, self.target_item)
                break

        # Report based on what was actually recorded: the subset search can be
        # cut short by the call budget even after a window tested positive.
        if not self.explanations_found:
            print("Explanation could not be found")
        return self.explanations_found

    def _test_window_removal(
        self, item_ids: List[Union[str, int]], original_group_rec: Union[str, int]
    ) -> bool:
        """
        Test if removing the given items affects the group recommendation.

        Args:
            item_ids: List of item IDs to remove from group interactions
            original_group_rec: The original recommendation to compare against

        Returns:
            bool: True if the original recommendation is absent from the
            recommendations obtained after the removal, False otherwise
        """
        group_recommendation = self._get_recommendations_after_removal(item_ids)
        return original_group_rec not in group_recommendation

    def _get_recommendations_after_removal(
        self, item_ids: List[Union[str, int]], top_n: int = 10
    ) -> Sequence[Union[str, int]]:
        """
        Get group recommendations after removing specified items from interaction history.

        Each call builds a modified dataset, refits the model on it and asks a
        fresh GroupRecommender for recommendations.

        Args:
            item_ids: List of item IDs to remove from group interactions
            top_n: Number of top recommendations to return

        Returns:
            List of recommended item IDs, or an empty list when the recommender
            does not return a list
        """
        # Create modified dataset with items removed
        changed_data = self.group_handler.create_modified_dataset(
            original_data=self.data.dataset,
            group_ids=self.members,
            item_ids=item_ids,
            data=self.data,
        )

        # Create new DataReader and retrain model
        data_retrained = self._create_data_reader_and_prepare(changed_data)
        model_retrained = self._retrain_model(data_retrained)

        # Set up recommender with new model and data
        group_recommender = GroupRecommender(data_retrained)
        group_recommender.setup_recommendation(
            model_retrained,
            self.members,
            data_retrained,
            aggregation_strategy=self.aggregation_strategy,
        )
        recommendations = group_recommender.get_group_recommendations(top_n)

        # Anything that is not a list is treated as "no recommendations".
        if not isinstance(recommendations, list):
            return []
        return recommendations

    def _create_data_reader_and_prepare(self, changed_data):
        """
        Create and prepare a new DataReader with modified data.

        Args:
            changed_data: DataFrame with modified dataset

        Returns:
            DataReader: A new DataReader object with the modified dataset
        """
        data_retrained = DataReader(
            filepath_or_buffer=None,
            sep=None,
            names=None,
            skiprows=0,
            dataframe=changed_data,
        )
        # Prepare data
        data_retrained.make_consecutive_ids_in_dataset()
        data_retrained.binarize(binary_threshold=1)
        return data_retrained

    def _retrain_model(self, data):
        """
        Retrain the recommendation model with modified data.

        Args:
            data: Prepared DataReader object with modified dataset

        Returns:
            The explainer's own model object after calling fit() on it
        """
        # NOTE(review): this fits and returns self.model itself rather than a
        # copy, so after an explanation run the explainer's model presumably
        # reflects the last modified dataset - confirm callers expect this.
        self.model.fit(data)
        return self.model

    def _find_minimal_subset(
        self, big_window: List[Union[str, int]], original_group_rec: Union[str, int]
    ) -> None:
        """
        Find minimal subset of items that act as counterfactual explanation.

        Subsets are tried in order of increasing size, so the first hit is a
        smallest counterfactual subset of the window. The search stops at the
        first hit or when the call budget is exhausted.

        Args:
            big_window: List of item IDs to search within
            original_group_rec: The original recommendation to compare against
        """
        for length in range(1, len(big_window) + 1):
            for item_combo in itertools.combinations(big_window, length):
                # Same budget rule as the window scan in find_explanation().
                if self.calls >= self.max_calls:
                    return

                subset_items = list(item_combo)
                self.calls += 1

                # Get recommendations after removing this subset
                new_recommendations = self._get_recommendations_after_removal(
                    subset_items
                )
                if original_group_rec in new_recommendations:
                    continue

                # The recommendation list may be empty; there is then no
                # replacement item to report.
                new_rec = new_recommendations[0] if new_recommendations else None
                self._record_explanation(subset_items, original_group_rec, new_rec)
                return

    def _record_explanation(
        self,
        explanation_items: List[Union[str, int]],
        original_rec: Union[str, int],
        new_rec: Union[str, int, None],
    ) -> None:
        """
        Record and display found explanation.

        The record is stored in ``explanations_found`` under the current call
        number and contains the explanation items, the replacement
        recommendation, the per-item ranking metrics and the item/user
        intensity scores.

        Args:
            explanation_items: Items that form the counterfactual explanation
            original_rec: Original recommendation
            new_rec: New top recommendation after removing explanation items,
                or None when no recommendation was produced
        """
        if new_rec is None:
            print(
                f"If the group had not interacted with these items {explanation_items},\n"
                f"the item of interest {original_rec} would not have appeared on the recommendation list."
            )
        else:
            print(
                f"If the group had not interacted with these items {explanation_items},\n"
                f"the item of interest {original_rec} would not have appeared on the recommendation list;\n"
                f"instead, {new_rec} would have been recommended."
            )

        explanation_metrics = {
            item: self.item_metrics.get(item, {}) for item in explanation_items
        }
        self.explanations_found[self.calls] = {
            "items": explanation_items,
            "new_rec": new_rec,
            "metrics": explanation_metrics,
            # Previously computed and thrown away; kept with the record so the
            # work is not wasted.
            "item_intensity": self._calculate_item_intensity(explanation_items),
            "user_intensity": self._calculate_user_intensity(explanation_items),
        }

    def _calculate_item_intensity(self, items: List[Union[str, int]]) -> List[float]:
        """
        Calculate average item intensity for explanation items.

        Args:
            items: List of item IDs in the explanation

        Returns:
            List of average intensity scores for each item
        """
        return self._calculate_average_item_intensity_score(
            items, self.members, self.data
        )

    def _calculate_user_intensity(self, items: List[Union[str, int]]) -> List[float]:
        """
        Calculate user intensity score for explanation items.

        Args:
            items: List of item IDs in the explanation

        Returns:
            List of intensity scores for each user
        """
        return self._calculate_user_intensity_score(items, self.members, self.data)

    @staticmethod
    def _unwrap_id(raw_id):
        """
        Return a scalar ID from an ID-mapping result that may be a list.

        Args:
            raw_id: Value returned by an ID-mapping method of the DataReader

        Returns:
            The first element when ``raw_id`` is a non-empty list or tuple,
            None when it is an empty one, and ``raw_id`` unchanged otherwise
        """
        if isinstance(raw_id, (list, tuple)):
            return raw_id[0] if raw_id else None
        return raw_id

    @staticmethod
    def _calculate_average_item_intensity_score(
        explanation: List[Union[str, int]],
        members: List[Union[str, int]],
        data: DataReader,
    ) -> List[float]:
        """
        Calculate the average item intensity for a counterfactual explanation.

        Average item intensity is defined as the average number of interactions
        between group members and each item in the explanation.

        Args:
            explanation: The counterfactual explanation items.
            members: User IDs of the group members.
            data: DataReader object containing the dataset and ID mapping methods.

        Returns:
            list: Average intensity for each item in the explanation (same
            order as ``explanation``). All zeros when ``members`` is empty.
        """
        group_size = len(members)
        if group_size == 0:
            # No members means no interactions; avoids dividing by zero.
            return [0.0 for _ in explanation]

        unwrap = SlidingWindowExplainer._unwrap_id

        # Convert user IDs to internal representation, skipping unmapped users
        internal_group_ids = []
        for user_id in members:
            internal_user_id = unwrap(data.get_new_user_id(user_id))
            if internal_user_id is not None:
                internal_group_ids.append(int(internal_user_id))

        dataset = data.dataset
        # Group membership does not depend on the item, so compute it once.
        by_group_member = dataset.userId.isin(internal_group_ids)

        item_intensities = []
        for item_id in explanation:
            internal_item_id = unwrap(data.get_new_item_id(item_id))
            if internal_item_id is None:
                item_intensities.append(0.0)
                continue
            # Count interactions between this item and group members
            interactions_count = len(
                dataset[(dataset.itemId == internal_item_id) & by_group_member]
            )
            item_intensities.append(interactions_count / group_size)
        return item_intensities

    @staticmethod
    def _calculate_user_intensity_score(
        explanation_items: List[Union[str, int]],
        members: List[Union[str, int]],
        data: DataReader,
    ) -> List[float]:
        """
        Calculate the interaction intensity of each user with the explanation items.

        Intensity is calculated as:
        (number of user interactions with explanation items) / (number of explanation items)

        Args:
            explanation_items: List of item IDs in the explanation
            members: List of user IDs to calculate intensity for
            data: DataReader object containing the dataset and ID mapping methods

        Returns:
            List of interaction intensities for each user (same order as
            ``members``). 0 means no interaction with any explanation item;
            assuming at most one dataset row per user-item pair, 1 means
            interaction with all explanation items. All zeros when
            ``explanation_items`` is empty.
        """
        num_explanation_items = len(explanation_items)
        if num_explanation_items == 0:
            # Nothing to interact with; avoids dividing by zero.
            return [0.0 for _ in members]

        unwrap = SlidingWindowExplainer._unwrap_id

        # Convert external item IDs to internal IDs, skipping unmapped items
        internal_item_ids = []
        for item_id in explanation_items:
            internal_item_id = unwrap(data.get_new_item_id(item_id))
            if internal_item_id is not None:
                internal_item_ids.append(internal_item_id)

        dataset = data.dataset
        # The explanation-item filter does not depend on the user; compute once.
        on_explanation_item = dataset.itemId.isin(internal_item_ids)

        user_intensities = []
        for member in members:
            # Same list-unwrapping and int conversion as the item-intensity score
            internal_user_id = unwrap(data.get_new_user_id(member))
            if internal_user_id is None:
                user_intensities.append(0.0)
                continue
            # Count interactions between this user and explanation items
            user_interactions_count = len(
                dataset[on_explanation_item & (dataset.userId == int(internal_user_id))]
            )
            user_intensities.append(user_interactions_count / num_explanation_items)
        return user_intensities