public code v1
This commit is contained in:
@@ -0,0 +1,434 @@
|
||||
import itertools
|
||||
from typing import Dict, List, Sequence, Union
|
||||
|
||||
from pygrex.data_reader import DataReader, GroupInteractionHandler
|
||||
from pygrex.models import RecommenderModel
|
||||
from pygrex.recommender import GroupRecommender
|
||||
from pygrex.utils import SlidingWindowRanker, SlidingWindow, AggregationStrategy
|
||||
|
||||
|
||||
class SlidingWindowExplainer:
|
||||
"""
|
||||
Stratigi, M., Bikakis, N., Stefanidis, K.: Counterfactual explanations for group
|
||||
recommendations. In: Proceedings of the 27th International Workshop on Design,
|
||||
Optimization, Languages and Analytical Processing of Big Data (DOLAP 2025).
|
||||
|
||||
A class that uses a sliding window approach to find counterfactual explanations
|
||||
for group recommendation systems.
|
||||
|
||||
This class helps identify which items, if removed from the group's interaction history,
|
||||
would cause a specific target item to no longer appear in the group recommendations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
data: DataReader,
|
||||
group_handler: GroupInteractionHandler,
|
||||
members: List[Union[str, int]],
|
||||
target_item: Union[str, int],
|
||||
model: RecommenderModel,
|
||||
aggregation_strategy: AggregationStrategy = AggregationStrategy.AVG_PREDICTIONS,
|
||||
window_size=3,
|
||||
):
|
||||
"""
|
||||
Initialize the SlidingWindowExplainer.
|
||||
|
||||
Args:
|
||||
config: Configuration object with model parameters
|
||||
data: DataReader object containing the dataset
|
||||
group_handler: Object that handles group data modifications
|
||||
members: List of user IDs in the group
|
||||
target_item: The item ID for which explanation is sought
|
||||
model: Recommender model to use for predictions,
|
||||
aggregation_strategy: Strategy to aggregate individual recommendations,
|
||||
window_size: Size of the sliding window
|
||||
"""
|
||||
self.cfg = config
|
||||
self.data = data
|
||||
self.group_handler = group_handler
|
||||
self.members = members
|
||||
self.target_item = target_item
|
||||
self.model = model
|
||||
self.aggregation_strategy = aggregation_strategy
|
||||
self.window_size = window_size
|
||||
|
||||
# Results tracking
|
||||
self.explanations_found: Dict[int, Dict] = {}
|
||||
self.calls = 0
|
||||
self.max_calls = 1000
|
||||
self.item_metrics = {}
|
||||
|
||||
def set_sliding_window(self, sliding_window):
|
||||
"""Set the sliding window object if not provided during initialization."""
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
def set_item_metrics(self, metrics: Dict[Union[str, int], Dict[str, float]]):
|
||||
"""Store the pre-calculated metric scores for all items."""
|
||||
self.item_metrics = metrics
|
||||
|
||||
def find_explanation(
|
||||
self,
|
||||
items_rated_by_group: List[Union[str, int]],
|
||||
group_predictions: Dict,
|
||||
top_recommendation: Union[str, int],
|
||||
ranking_weights: Dict[str, float],
|
||||
) -> Dict[int, Dict]:
|
||||
"""
|
||||
Find counterfactual explanations using the full, encapsulated process.
|
||||
|
||||
Args:
|
||||
items_rated_by_group: All items rated by any member of the group.
|
||||
group_predictions: The original individual predictions from the recommender.
|
||||
top_recommendation: The original top recommended item.
|
||||
ranking_weights: The weights from the UI for each ranking component.
|
||||
|
||||
Returns:
|
||||
A dictionary of found explanations, including their justification metrics.
|
||||
"""
|
||||
|
||||
self.calls = 0
|
||||
ranker = SlidingWindowRanker(config={})
|
||||
ranker.set_group_recommender_values(group_predictions, top_recommendation)
|
||||
ranked_items, self.item_metrics = ranker.generate_ranked_items(
|
||||
all_rated_items=items_rated_by_group,
|
||||
data=self.data,
|
||||
group_members=self.members,
|
||||
component_weights=ranking_weights,
|
||||
)
|
||||
|
||||
sliding_window = SlidingWindow(
|
||||
sequence=ranked_items, window_size=self.window_size
|
||||
)
|
||||
|
||||
found = 0
|
||||
while True:
|
||||
# Get the sliding window
|
||||
big_window = sliding_window.get_next_window()
|
||||
|
||||
# Check exit conditions
|
||||
if big_window is None or found > 0 or self.calls >= self.max_calls:
|
||||
break
|
||||
|
||||
# Count calls and windows
|
||||
self.calls += 1
|
||||
|
||||
# Test if removing this window affects recommendations
|
||||
if self._test_window_removal(big_window, self.target_item):
|
||||
# A counterfactual explanation has been found
|
||||
found += 1
|
||||
# Look for minimal subsets within this window
|
||||
self._find_minimal_subset(big_window, self.target_item)
|
||||
|
||||
if found == 0:
|
||||
print("Explanation could not be found")
|
||||
|
||||
return self.explanations_found
|
||||
|
||||
def _test_window_removal(
|
||||
self, item_ids: List[Union[str, int]], original_group_rec: Union[str, int]
|
||||
) -> bool:
|
||||
"""
|
||||
Test if removing the given items affects the group recommendation.
|
||||
|
||||
Args:
|
||||
item_ids: List of item IDs to remove from group interactions
|
||||
original_group_rec: The original recommendation to compare against
|
||||
|
||||
Returns:
|
||||
bool: True if removing these items changes recommendations, False otherwise
|
||||
"""
|
||||
|
||||
# Get new recommendations after removing items
|
||||
group_recommendation = self._get_recommendations_after_removal(item_ids)
|
||||
|
||||
# Check if target item is still in recommendations
|
||||
|
||||
return original_group_rec not in group_recommendation
|
||||
|
||||
def _get_recommendations_after_removal(
|
||||
self, item_ids: List[Union[str, int]], top_n: int = 10
|
||||
) -> Sequence[Union[str, int]]:
|
||||
"""
|
||||
Get group recommendations after removing specified items from interaction history.
|
||||
|
||||
Args:
|
||||
item_ids: List of item IDs to remove from group interactions
|
||||
top_n: Number of top recommendations to return
|
||||
|
||||
Returns:
|
||||
List of recommended item IDs
|
||||
"""
|
||||
# Create modified dataset with items removed
|
||||
changed_data = self.group_handler.create_modified_dataset(
|
||||
original_data=self.data.dataset,
|
||||
group_ids=self.members,
|
||||
item_ids=item_ids,
|
||||
data=self.data,
|
||||
)
|
||||
|
||||
# Create new DataReader and retrain model
|
||||
data_retrained = self._create_data_reader_and_prepare(changed_data)
|
||||
model_retrained = self._retrain_model(data_retrained)
|
||||
|
||||
# Set up recommender with new model and data
|
||||
group_recommender = GroupRecommender(data_retrained)
|
||||
group_recommender.setup_recommendation(
|
||||
model_retrained,
|
||||
self.members,
|
||||
data_retrained,
|
||||
aggregation_strategy=self.aggregation_strategy,
|
||||
)
|
||||
recommendations = group_recommender.get_group_recommendations(top_n)
|
||||
|
||||
if not isinstance(recommendations, list):
|
||||
return []
|
||||
|
||||
return recommendations
|
||||
|
||||
def _create_data_reader_and_prepare(self, changed_data):
|
||||
"""
|
||||
Create and prepare a new DataReader with modified data.
|
||||
|
||||
Args:
|
||||
changed_data: DataFrame with modified dataset
|
||||
|
||||
Returns:
|
||||
DataReader: A new DataReader object with the modified dataset
|
||||
"""
|
||||
data_retrained = DataReader(
|
||||
filepath_or_buffer=None,
|
||||
sep=None,
|
||||
names=None,
|
||||
skiprows=0,
|
||||
dataframe=changed_data,
|
||||
)
|
||||
|
||||
# Fix for potential dataset issue in original code
|
||||
# data_retrained.dataset = data_retrained.dataset.iloc[1:].reset_index(drop=True)
|
||||
|
||||
# Prepare data
|
||||
data_retrained.make_consecutive_ids_in_dataset()
|
||||
data_retrained.binarize(binary_threshold=1)
|
||||
|
||||
return data_retrained
|
||||
|
||||
def _retrain_model(self, data):
|
||||
"""
|
||||
Retrain the recommendation model with modified data.
|
||||
|
||||
Args:
|
||||
data: Prepared DataReader object with modified dataset
|
||||
|
||||
Returns:
|
||||
Retrained model
|
||||
"""
|
||||
self.model.fit(data)
|
||||
return self.model
|
||||
|
||||
def _find_minimal_subset(
|
||||
self, big_window: List[Union[str, int]], original_group_rec: Union[str, int]
|
||||
) -> None:
|
||||
"""
|
||||
Find minimal subset of items that act as counterfactual explanation.
|
||||
|
||||
Args:
|
||||
big_window: List of item IDs to search within
|
||||
original_group_rec: The original recommendation to compare against
|
||||
|
||||
"""
|
||||
found_subset = 0
|
||||
|
||||
# Try combinations of different lengths
|
||||
for length in range(1, len(big_window) + 1):
|
||||
if found_subset > 0 or self.calls > self.max_calls:
|
||||
break
|
||||
|
||||
combinations = itertools.combinations(big_window, length)
|
||||
for item_combo in combinations:
|
||||
if found_subset > 0 or self.calls > self.max_calls:
|
||||
break
|
||||
|
||||
subset_items = list(item_combo)
|
||||
self.calls += 1
|
||||
|
||||
# Get recommendations after removing this subset
|
||||
new_recommendations = self._get_recommendations_after_removal(
|
||||
subset_items
|
||||
)
|
||||
|
||||
# Check if this is a counterfactual explanation
|
||||
if original_group_rec not in new_recommendations:
|
||||
found_subset += 1
|
||||
self._record_explanation(
|
||||
subset_items, original_group_rec, new_recommendations[0]
|
||||
)
|
||||
|
||||
def _record_explanation(
|
||||
self,
|
||||
explanation_items: List[Union[str, int]],
|
||||
original_rec: Union[str, int],
|
||||
new_rec: Union[str, int],
|
||||
) -> None:
|
||||
"""
|
||||
Record and display found explanation.
|
||||
|
||||
Args:
|
||||
explanation_items: Items that form the counterfactual explanation
|
||||
original_rec: Original recommendation
|
||||
new_rec: New top recommendation after removing explanation items
|
||||
"""
|
||||
print(
|
||||
f"If the group had not interacted with these items {explanation_items},\n"
|
||||
f"the item of interest {original_rec} would not have appeared on the recommendation list;\n"
|
||||
f"instead, {new_rec} would have been recommended."
|
||||
)
|
||||
# print("")
|
||||
# print(f"Explanation: {explanation_items} : found at call: {self.calls}")
|
||||
|
||||
# Calculate metrics for the explanation
|
||||
item_intensity = self._calculate_item_intensity(explanation_items)
|
||||
user_intensity = self._calculate_user_intensity(explanation_items)
|
||||
explanation_metrics = {
|
||||
item: self.item_metrics.get(item, {}) for item in explanation_items
|
||||
}
|
||||
|
||||
self.explanations_found[self.calls] = {
|
||||
"items": explanation_items,
|
||||
"new_rec": new_rec,
|
||||
"metrics": explanation_metrics,
|
||||
}
|
||||
|
||||
exp_size = len(explanation_items)
|
||||
|
||||
# print(f"{exp_size}\t{self.calls}\t{item_intensity}\t{user_intensity}")
|
||||
|
||||
def _calculate_item_intensity(self, items: List[Union[str, int]]) -> List[float]:
|
||||
"""
|
||||
Calculate average item intensity for explanation items.
|
||||
|
||||
Args:
|
||||
items: List of item IDs in the explanation
|
||||
|
||||
Returns:
|
||||
List of average intensity scores for each item
|
||||
"""
|
||||
|
||||
return self._calculate_average_item_intensity_score(
|
||||
items, self.members, self.data
|
||||
)
|
||||
|
||||
def _calculate_user_intensity(self, items: List[Union[str, int]]) -> List[float]:
|
||||
"""
|
||||
Calculate user intensity score for explanation items.
|
||||
|
||||
Args:
|
||||
items: List of item IDs in the explanation
|
||||
|
||||
Returns:
|
||||
List of intensity scores for each user
|
||||
"""
|
||||
return self._calculate_user_intensity_score(items, self.members, self.data)
|
||||
|
||||
@staticmethod
|
||||
def _calculate_average_item_intensity_score(
|
||||
explanation: List[Union[str, int]],
|
||||
members: List[Union[str, int]],
|
||||
data: DataReader,
|
||||
) -> List[float]:
|
||||
"""
|
||||
Calculate the average item intensity for a counterfactual explanation.
|
||||
|
||||
Average item intensity is defined as the average number of interactions
|
||||
between group members and each item in the explanation.
|
||||
|
||||
Args:
|
||||
explanation: The counterfactual explanation items.
|
||||
members: User IDs of the group members.
|
||||
data: DataReader object containing the dataset and ID mapping methods.
|
||||
|
||||
Returns:
|
||||
list: Average intensity for each item in the explanation.
|
||||
"""
|
||||
internal_group_ids = []
|
||||
# Convert user IDs to internal representation
|
||||
for user_id in members:
|
||||
new_user_id = data.get_new_user_id(user_id)
|
||||
if isinstance(new_user_id, list):
|
||||
if new_user_id: # Check that the list is not empty
|
||||
internal_group_ids.append(int(new_user_id[0]))
|
||||
else:
|
||||
internal_group_ids.append(int(new_user_id))
|
||||
|
||||
group_size = len(members)
|
||||
item_intensities = []
|
||||
|
||||
for item_id in explanation:
|
||||
# Convert item ID to internal representation
|
||||
internal_item_id = data.get_new_item_id(item_id)
|
||||
|
||||
# Count interactions between this item and group members
|
||||
interactions_count = len(
|
||||
data.dataset[
|
||||
(data.dataset.itemId == internal_item_id)
|
||||
& (data.dataset.userId.isin(internal_group_ids))
|
||||
]
|
||||
)
|
||||
|
||||
# Calculate average intensity
|
||||
average_intensity = interactions_count / group_size
|
||||
item_intensities.append(average_intensity)
|
||||
|
||||
return item_intensities
|
||||
|
||||
@staticmethod
|
||||
def _calculate_user_intensity_score(
|
||||
explanation_items: List[Union[str, int]],
|
||||
members: List[Union[str, int]],
|
||||
data: DataReader,
|
||||
) -> List[float]:
|
||||
"""
|
||||
Calculate the interaction intensity for each user based on their interactions with items in an explanation.
|
||||
|
||||
Interaction intensity represents how much a user has interacted with the items in the explanation,
|
||||
normalized by the total number of explanation items.
|
||||
|
||||
Args
|
||||
explanation_items : List of item IDs in the explanation
|
||||
members : List of user IDs to calculate intensity for
|
||||
data : DataReader object containing the dataset and ID mapping methods
|
||||
|
||||
Returns
|
||||
List of interaction intensities for each user (same order as members)
|
||||
Values range from 0 to 1, where:
|
||||
- 0 means no interaction with any explanation item
|
||||
- 1 means interaction with all explanation items
|
||||
|
||||
Notes
|
||||
Intensity is calculated as: (number of user interactions with explanation items) / (number of explanation items)
|
||||
"""
|
||||
# Convert external item IDs to internal IDs
|
||||
internal_item_ids = [
|
||||
data.get_new_item_id(item_id) for item_id in explanation_items
|
||||
]
|
||||
|
||||
user_intensities = []
|
||||
num_explanation_items = len(explanation_items)
|
||||
|
||||
for member in members:
|
||||
# Convert external user ID to internal ID
|
||||
internal_user_id = data.get_new_user_id(member)
|
||||
|
||||
# Count interactions between this user and explanation items
|
||||
user_interactions_count = len(
|
||||
data.dataset[
|
||||
(data.dataset.itemId.isin(internal_item_ids))
|
||||
& (data.dataset.userId == internal_user_id)
|
||||
]
|
||||
)
|
||||
|
||||
# Calculate intensity as proportion of explanation items the user interacted with
|
||||
intensity = user_interactions_count / num_explanation_items
|
||||
user_intensities.append(intensity)
|
||||
|
||||
return user_intensities
|
||||
Reference in New Issue
Block a user