public code v1
This commit is contained in:
@@ -0,0 +1,18 @@
|
||||
from .individual.model_based_emf import EMFExplainer
|
||||
from .individual.model_based_als_explain import ALSExplainer
|
||||
from .individual.post_hoc_association_rules import ARPostHocExplainer
|
||||
from .individual.post_hoc_knn import KNNPostHocExplainer
|
||||
from .groups.rule_based_group_rec_explainer import RuleBasedGroupRecExplainer
|
||||
from .groups.sliding_window_explainer import SlidingWindowExplainer
|
||||
from .groups.lore4groups_explainer import LORE4GroupsExplainer
|
||||
|
||||
|
||||
__all__ = [
|
||||
"EMFExplainer",
|
||||
"ALSExplainer",
|
||||
"ARPostHocExplainer",
|
||||
"KNNPostHocExplainer",
|
||||
"RuleBasedGroupRecExplainer",
|
||||
"SlidingWindowExplainer",
|
||||
"LORE4GroupsExplainer",
|
||||
]
|
||||
@@ -0,0 +1,10 @@
|
||||
from .rule_based_group_rec_explainer import RuleBasedGroupRecExplainer
|
||||
from .sliding_window_explainer import SlidingWindowExplainer
|
||||
from .lore4groups_explainer import LORE4GroupsExplainer
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RuleBasedGroupRecExplainer",
|
||||
"SlidingWindowExplainer",
|
||||
"LORE4GroupsExplainer",
|
||||
]
|
||||
@@ -0,0 +1,731 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import re
|
||||
import logging
|
||||
import traceback
|
||||
from collections import Counter
|
||||
from typing import Dict, Set, List, Optional, Any, Tuple, Union
|
||||
from sklearn.tree import DecisionTreeClassifier, _tree
|
||||
|
||||
ItemId = Union[str, int]
|
||||
UserId = Union[str, int]
|
||||
FactualRule = List[str]
|
||||
CounterfactualSet = List[List[str]]
|
||||
Explanation = Tuple[Optional[FactualRule], Optional[CounterfactualSet]]
|
||||
|
||||
|
||||
class LORE4GroupsExplainer:
|
||||
"""
|
||||
Enhanced LORE4Groups explainer that incorporates genre information
|
||||
and stores decision trees for visualization
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
item_profiles: Dict[str, Set[str]],
|
||||
item_label_matrix: pd.DataFrame,
|
||||
config: Dict,
|
||||
genre_profiles: Optional[Dict[str, Set[str]]] = None,
|
||||
):
|
||||
self.item_profiles = {str(k): v for k, v in item_profiles.items()}
|
||||
self.item_label_matrix = item_label_matrix
|
||||
self.params = config["explainer"]["lore4groups"]
|
||||
|
||||
# NEW: Store genre information
|
||||
self.genre_profiles = (
|
||||
{str(k): v for k, v in genre_profiles.items()} if genre_profiles else {}
|
||||
)
|
||||
|
||||
all_columns = item_label_matrix.columns.tolist()
|
||||
self.all_labels = [col for col in all_columns if col != "like"]
|
||||
|
||||
# Add 'like' back for target variable access (but not as feature)
|
||||
if "like" in all_columns:
|
||||
self.all_labels.append("like")
|
||||
|
||||
def _enhanced_jaccard_similarity(self, item1_id: ItemId, item2_id: ItemId) -> float:
|
||||
"""Enhanced Jaccard similarity that considers both tags and genres"""
|
||||
# Get regular tags
|
||||
tags1 = self.item_profiles.get(str(item1_id), set())
|
||||
tags2 = self.item_profiles.get(str(item2_id), set())
|
||||
|
||||
# Get genres and add them as features
|
||||
genres1 = self.genre_profiles.get(str(item1_id), set())
|
||||
genres2 = self.genre_profiles.get(str(item2_id), set())
|
||||
|
||||
# Combine tags and genres for enhanced similarity
|
||||
features1 = tags1.union({f"genre_{g.lower()}" for g in genres1})
|
||||
features2 = tags2.union({f"genre_{g.lower()}" for g in genres2})
|
||||
|
||||
if not features1 or not features2:
|
||||
return 0.0
|
||||
|
||||
union_len = len(features1.union(features2))
|
||||
intersection_len = len(features1.intersection(features2))
|
||||
|
||||
return intersection_len / union_len if union_len > 0 else 0.0
|
||||
|
||||
def _jaccard_similarity(self, item1_id: ItemId, item2_id: ItemId) -> float:
|
||||
"""Original jaccard similarity (kept for compatibility)"""
|
||||
tags1 = self.item_profiles.get(str(item1_id), set())
|
||||
tags2 = self.item_profiles.get(str(item2_id), set())
|
||||
if not tags1 or not tags2:
|
||||
return 0.0
|
||||
union_len = len(tags1.union(tags2))
|
||||
return len(tags1.intersection(tags2)) / union_len if union_len > 0 else 0.0
|
||||
|
||||
def _get_enhanced_similar_examples(
|
||||
self,
|
||||
user_id_consecutive: UserId,
|
||||
target_item_id: ItemId,
|
||||
user_hist: Set[ItemId],
|
||||
dataset: pd.DataFrame,
|
||||
model=None,
|
||||
data_reader=None,
|
||||
) -> Tuple[pd.DataFrame, Dict[str, Any]]:
|
||||
"""Enhanced version that returns both DataFrame and metadata for visualization"""
|
||||
|
||||
# 1. Find all similar items using enhanced similarity
|
||||
similarities = [
|
||||
(seen_id, self._enhanced_jaccard_similarity(target_item_id, seen_id))
|
||||
for seen_id in user_hist
|
||||
]
|
||||
similarities = sorted(similarities, key=lambda x: x[1], reverse=True)
|
||||
|
||||
sim_th = self.params.get("similarity_threshold", 0.0)
|
||||
top_similar_items_str = {
|
||||
item[0]
|
||||
for item in similarities[: self.params["n_similar_for_tree"]]
|
||||
if item[1] >= sim_th
|
||||
}
|
||||
|
||||
if not top_similar_items_str:
|
||||
return pd.DataFrame(), {}
|
||||
|
||||
# 2. Build the local dataset
|
||||
top_similar_items_int = [int(i) for i in top_similar_items_str]
|
||||
|
||||
# Get existing ratings for similar items
|
||||
local_df = dataset[
|
||||
(dataset["userId"] == user_id_consecutive)
|
||||
& (dataset["itemId"].isin(top_similar_items_int))
|
||||
].copy()
|
||||
|
||||
rated_items = set(local_df["itemId"])
|
||||
items_to_predict = [
|
||||
item for item in top_similar_items_int if item not in rated_items
|
||||
]
|
||||
|
||||
# Add predictions for unrated items
|
||||
if model and data_reader and items_to_predict:
|
||||
try:
|
||||
orig_user_id = data_reader.get_original_user_id(
|
||||
int(user_id_consecutive)
|
||||
)
|
||||
predicted_ratings = []
|
||||
|
||||
for item_id_consecutive in items_to_predict:
|
||||
orig_item_id = data_reader.get_original_item_id(
|
||||
int(item_id_consecutive)
|
||||
)
|
||||
pred = model.predict(orig_user_id, orig_item_id)
|
||||
predicted_ratings.append(
|
||||
{
|
||||
"userId": user_id_consecutive,
|
||||
"itemId": item_id_consecutive,
|
||||
"rating": float(pred),
|
||||
}
|
||||
)
|
||||
|
||||
if predicted_ratings:
|
||||
pred_df = pd.DataFrame(predicted_ratings)
|
||||
local_df = pd.concat([local_df, pred_df], ignore_index=True)
|
||||
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
# Check minimum samples requirement
|
||||
if len(local_df) < 2:
|
||||
return pd.DataFrame(), {}
|
||||
|
||||
# 3. Apply thresholding with fallbacks
|
||||
rating_threshold = self.params["rating_threshold_for_like"]
|
||||
|
||||
threshold_info = {
|
||||
"was_overridden": False,
|
||||
"original_threshold": rating_threshold,
|
||||
"final_threshold": rating_threshold,
|
||||
}
|
||||
|
||||
local_df["like"] = (local_df["rating"] >= rating_threshold).astype(int)
|
||||
|
||||
# Apply fallback thresholds if needed
|
||||
like_counts = local_df["like"].value_counts()
|
||||
|
||||
if len(like_counts) < 2:
|
||||
# Try mean-based threshold
|
||||
mean_rating = local_df["rating"].mean()
|
||||
local_df["like"] = (local_df["rating"] >= mean_rating).astype(int)
|
||||
threshold_info["was_overridden"] = True
|
||||
threshold_info["final_threshold"] = mean_rating
|
||||
like_counts = local_df["like"].value_counts()
|
||||
if len(like_counts) < 2:
|
||||
return pd.DataFrame(), {}
|
||||
|
||||
# Check for severe imbalance (>90% one class)
|
||||
min_class_ratio = like_counts.min() / len(local_df)
|
||||
if min_class_ratio < 0.1:
|
||||
if like_counts.min() < 2:
|
||||
return pd.DataFrame(), {}
|
||||
|
||||
# 4. Construct the enhanced feature matrix (including genres)
|
||||
feature_labels = [label for label in self.all_labels if label != "like"]
|
||||
|
||||
examples = []
|
||||
genre_features_used = set()
|
||||
|
||||
for idx, row in local_df.iterrows():
|
||||
item_id = str(int(row["itemId"]))
|
||||
tags = self.item_profiles.get(item_id, set())
|
||||
genres = self.genre_profiles.get(item_id, set())
|
||||
|
||||
# Create base example with target variables
|
||||
example = {
|
||||
"movie_id": item_id,
|
||||
"rating": row["rating"],
|
||||
"like": int(row["like"]),
|
||||
}
|
||||
|
||||
# Add tag features (excluding 'like')
|
||||
for label in feature_labels:
|
||||
example[label] = 1 if label in tags else 0
|
||||
|
||||
# Add genre features dynamically
|
||||
for genre in genres:
|
||||
genre_feature = f"genre_{genre.lower()}"
|
||||
example[genre_feature] = 1
|
||||
genre_features_used.add(genre_feature)
|
||||
|
||||
# Also add to feature_labels if not already there
|
||||
if genre_feature not in feature_labels:
|
||||
feature_labels.append(genre_feature)
|
||||
|
||||
examples.append(example)
|
||||
|
||||
# Ensure all examples have all genre features
|
||||
for example in examples:
|
||||
for genre_feature in genre_features_used:
|
||||
if genre_feature not in example:
|
||||
example[genre_feature] = 0
|
||||
|
||||
final_df = pd.DataFrame(examples)
|
||||
|
||||
# Final validation
|
||||
if final_df["like"].nunique() < 2:
|
||||
return pd.DataFrame(), {}
|
||||
|
||||
# Prepare metadata for visualization
|
||||
metadata = {
|
||||
"feature_labels": [label for label in feature_labels if label != "like"],
|
||||
"genre_features": list(genre_features_used),
|
||||
"similarity_scores": dict(similarities[:5]), # Top 5 similarities
|
||||
"target_item_genres": self.genre_profiles.get(str(target_item_id), set()),
|
||||
"rating_threshold": threshold_info["final_threshold"],
|
||||
"threshold_info": threshold_info,
|
||||
}
|
||||
|
||||
return final_df, metadata
|
||||
|
||||
def _get_factual_path_for_item(
|
||||
self,
|
||||
clf: DecisionTreeClassifier,
|
||||
x_item: pd.DataFrame,
|
||||
metadata: Dict[str, Any],
|
||||
) -> Optional[List[str]]:
|
||||
"""
|
||||
Traces the specific path an item takes through the decision tree
|
||||
and returns the corresponding factual rule set.
|
||||
"""
|
||||
feature_labels = metadata.get("feature_labels", [])
|
||||
if not feature_labels:
|
||||
return None
|
||||
|
||||
# 1. Get the sequence of nodes the item travels through
|
||||
node_indicator = clf.decision_path(x_item)
|
||||
node_index = node_indicator.indices[ # type: ignore
|
||||
node_indicator.indptr[0] : node_indicator.indptr[ # type: ignore
|
||||
1
|
||||
]
|
||||
]
|
||||
|
||||
rules = []
|
||||
tree = clf.tree_
|
||||
|
||||
# 2. Iterate through the path to build the rules
|
||||
# We stop at the second to last node because the last one is the leaf
|
||||
for i in range(len(node_index) - 1):
|
||||
node_id = node_index[i]
|
||||
child_node_id = node_index[i + 1]
|
||||
|
||||
# Ensure this is not a leaf node
|
||||
if tree.feature[node_id] != _tree.TREE_UNDEFINED: # type: ignore
|
||||
feature_name = feature_labels[tree.feature[node_id]] # type: ignore
|
||||
threshold = tree.threshold[node_id] # type: ignore
|
||||
|
||||
# 3. Determine if the path went left or right to form the rule
|
||||
if child_node_id == tree.children_left[node_id]: # type: ignore
|
||||
# Path went left (True condition for <= threshold)
|
||||
rule = f"{feature_name} <= {threshold:.2f}"
|
||||
else:
|
||||
# Path went right (False condition for <= threshold)
|
||||
rule = f"{feature_name} > {threshold:.2f}"
|
||||
|
||||
# Use the same enhanced formatting as before for consistency
|
||||
if feature_name.startswith("genre_"):
|
||||
genre_name = feature_name.replace("genre_", "").title()
|
||||
if child_node_id == tree.children_left[node_id]: # type: ignore
|
||||
rules.append(f"Does NOT have genre: `{genre_name}`")
|
||||
else:
|
||||
rules.append(f"Has genre: `{genre_name}`")
|
||||
else:
|
||||
rules.append(rule)
|
||||
|
||||
return rules if rules else None
|
||||
|
||||
def _train_enhanced_decision_tree(
|
||||
self,
|
||||
user_id_consecutive: UserId,
|
||||
item_id: ItemId,
|
||||
user_hist: Set[ItemId],
|
||||
dataset: pd.DataFrame,
|
||||
model=None,
|
||||
data_reader=None,
|
||||
) -> Tuple[Optional[DecisionTreeClassifier], Dict[str, Any]]:
|
||||
"""Enhanced tree training that returns both classifier and metadata"""
|
||||
|
||||
df_examples, metadata = self._get_enhanced_similar_examples(
|
||||
user_id_consecutive, item_id, user_hist, dataset, model, data_reader
|
||||
)
|
||||
|
||||
if df_examples.empty:
|
||||
return None, {}
|
||||
|
||||
like_counts = df_examples["like"].value_counts()
|
||||
|
||||
if len(like_counts) < 2 or like_counts.min() < 2:
|
||||
return None, {}
|
||||
|
||||
feature_labels = metadata.get("feature_labels", [])
|
||||
X = df_examples[feature_labels]
|
||||
y = df_examples["like"]
|
||||
|
||||
# Verify feature matrix has variance
|
||||
feature_variances = X.var()
|
||||
if (feature_variances == 0).all():
|
||||
return None, {}
|
||||
|
||||
clf = DecisionTreeClassifier(
|
||||
max_depth=5, # Slightly deeper to accommodate genre features
|
||||
min_samples_split=max(4, len(df_examples) // 4),
|
||||
min_samples_leaf=2,
|
||||
random_state=42,
|
||||
class_weight="balanced",
|
||||
)
|
||||
|
||||
try:
|
||||
clf.fit(X, y)
|
||||
|
||||
# Enhanced feature importance analysis
|
||||
feature_importance = list(zip(feature_labels, clf.feature_importances_))
|
||||
important_features = [
|
||||
(f, imp) for f, imp in feature_importance if imp > 0.001
|
||||
]
|
||||
genre_important_features = [
|
||||
(f, imp) for f, imp in important_features if f.startswith("genre_")
|
||||
]
|
||||
|
||||
# Add classifier and feature info to metadata
|
||||
metadata.update(
|
||||
{
|
||||
"classifier": clf,
|
||||
"feature_importance": dict(feature_importance),
|
||||
"important_features": important_features,
|
||||
"genre_important_features": genre_important_features,
|
||||
"training_data_size": len(df_examples),
|
||||
"class_distribution": like_counts.to_dict(),
|
||||
}
|
||||
)
|
||||
|
||||
return clf, metadata
|
||||
|
||||
except Exception as _:
|
||||
return None, {}
|
||||
|
||||
def _get_enhanced_explanation_path(
|
||||
self,
|
||||
clf: DecisionTreeClassifier,
|
||||
x_item: pd.DataFrame,
|
||||
metadata: Dict[str, Any],
|
||||
) -> Optional[List[str]]:
|
||||
"""Enhanced explanation path that provides better rule descriptions"""
|
||||
|
||||
if 1 not in clf.classes_:
|
||||
return None
|
||||
|
||||
leaf_id = clf.apply(x_item)[0] # type: ignore
|
||||
class_index = np.where(clf.classes_ == 1)[0]
|
||||
if not class_index.size or clf.tree_.value[leaf_id][0][class_index[0]] == 0: # type: ignore
|
||||
return None
|
||||
|
||||
node_indicator = clf.decision_path(x_item)
|
||||
node_index = node_indicator.indices[ # type: ignore
|
||||
node_indicator.indptr[0] : node_indicator.indptr[ # type: ignore
|
||||
1
|
||||
]
|
||||
]
|
||||
|
||||
rules = []
|
||||
feature_labels = metadata.get("feature_labels", [])
|
||||
|
||||
for i in range(len(node_index) - 1): # Exclude leaf node
|
||||
node_id = node_index[i]
|
||||
next_node_id = node_index[i + 1]
|
||||
|
||||
if clf.tree_.feature[node_id] != _tree.TREE_UNDEFINED: # type: ignore
|
||||
feature_name = feature_labels[clf.tree_.feature[node_id]] # type: ignore
|
||||
threshold = clf.tree_.threshold[node_id] # type: ignore
|
||||
|
||||
# Enhanced rule formatting based on feature type
|
||||
if feature_name.startswith("genre_"):
|
||||
genre_name = feature_name.replace("genre_", "").title()
|
||||
if next_node_id == clf.tree_.children_left[node_id]: # type: ignore
|
||||
rules.append(f"Does NOT have genre: `{genre_name}`")
|
||||
else:
|
||||
rules.append(f"Has genre: `{genre_name}`")
|
||||
else:
|
||||
# Regular tag features
|
||||
if next_node_id == clf.tree_.children_left[node_id]: # type: ignore
|
||||
rules.append(f"{feature_name} <= {threshold}")
|
||||
else:
|
||||
rules.append(f"{feature_name} > {threshold}")
|
||||
|
||||
return rules
|
||||
|
||||
def _generate_enhanced_individual_explanation(
|
||||
self, clf: DecisionTreeClassifier, item_id: ItemId, metadata: Dict[str, Any]
|
||||
) -> Optional[Explanation]:
|
||||
"""Enhanced individual explanation generation"""
|
||||
|
||||
if str(item_id) not in self.item_label_matrix.index:
|
||||
return None
|
||||
|
||||
x_item_full = self.item_label_matrix.loc[[str(item_id)]]
|
||||
feature_labels = metadata.get("feature_labels", [])
|
||||
|
||||
try:
|
||||
# For genre features, we need to dynamically add them to the item
|
||||
item_genres = self.genre_profiles.get(str(item_id), set())
|
||||
|
||||
# Create enhanced item representation
|
||||
enhanced_item_data = x_item_full.copy()
|
||||
|
||||
# Add genre features
|
||||
for genre in item_genres:
|
||||
genre_feature = f"genre_{genre.lower()}"
|
||||
if genre_feature in feature_labels:
|
||||
enhanced_item_data[genre_feature] = 1
|
||||
|
||||
# Ensure all genre features exist (set to 0 if not present)
|
||||
for feature in feature_labels:
|
||||
if (
|
||||
feature.startswith("genre_")
|
||||
and feature not in enhanced_item_data.columns
|
||||
):
|
||||
enhanced_item_data[feature] = 0
|
||||
|
||||
# Select only the features used in training
|
||||
x_item = enhanced_item_data[feature_labels]
|
||||
|
||||
except KeyError as _:
|
||||
return None
|
||||
# Get enhanced factual rule
|
||||
# factual_rule = self._get_enhanced_explanation_path(clf, x_item, metadata)
|
||||
factual_rule = self._get_factual_path_for_item(clf, x_item, metadata)
|
||||
|
||||
if not factual_rule:
|
||||
return None
|
||||
|
||||
# Get counterfactuals (reuse existing method)
|
||||
counterfactual_set = self._get_counterfactual_paths(clf, x_item)
|
||||
if not counterfactual_set:
|
||||
return None
|
||||
|
||||
return (factual_rule, counterfactual_set)
|
||||
|
||||
def _get_counterfactual_paths(
|
||||
self, clf: DecisionTreeClassifier, x_item: pd.DataFrame
|
||||
) -> Optional[CounterfactualSet]:
|
||||
"""Original counterfactual path method (kept for compatibility)"""
|
||||
tree = clf.tree_
|
||||
paths = []
|
||||
|
||||
def find_paths(node_id, current_path):
|
||||
if tree.feature[node_id] == _tree.TREE_UNDEFINED: # type: ignore
|
||||
class_index = np.where(clf.classes_ == 0)[0]
|
||||
if class_index.size and tree.value[node_id][0][class_index[0]] > 0:
|
||||
paths.append(list(current_path))
|
||||
return
|
||||
feature_idx = tree.feature[node_id] # type: ignore
|
||||
threshold = tree.threshold[node_id] # type: ignore
|
||||
current_path.append((feature_idx, "<=", threshold))
|
||||
find_paths(tree.children_left[node_id], current_path) # type: ignore
|
||||
current_path.pop()
|
||||
current_path.append((feature_idx, ">", threshold))
|
||||
find_paths(tree.children_right[node_id], current_path) # type: ignore
|
||||
current_path.pop()
|
||||
|
||||
find_paths(0, [])
|
||||
if not paths:
|
||||
return None
|
||||
|
||||
min_nf = float("inf")
|
||||
counterfactuals = []
|
||||
for path in paths:
|
||||
nf = 0
|
||||
for feature_idx, op, threshold in path:
|
||||
if feature_idx < len(x_item.columns):
|
||||
item_val = x_item.iloc[0, feature_idx]
|
||||
if not (
|
||||
(op == "<=" and item_val <= threshold)
|
||||
or (op == ">" and item_val > threshold)
|
||||
):
|
||||
nf += 1
|
||||
if nf < min_nf:
|
||||
min_nf = nf
|
||||
counterfactuals = [path]
|
||||
elif nf == min_nf:
|
||||
counterfactuals.append(path)
|
||||
|
||||
# Enhanced counterfactual formatting
|
||||
formatted_counterfactuals = []
|
||||
for cf_path in counterfactuals:
|
||||
formatted_path = []
|
||||
for idx, op, _ in cf_path:
|
||||
if idx < len(x_item.columns):
|
||||
feature_name = x_item.columns[idx]
|
||||
if feature_name.startswith("genre_"):
|
||||
genre_name = feature_name.replace("genre_", "").title()
|
||||
if op == "<=":
|
||||
formatted_path.append(
|
||||
f"Does NOT have genre: `{genre_name}`"
|
||||
)
|
||||
else:
|
||||
formatted_path.append(f"Has genre: `{genre_name}`")
|
||||
else:
|
||||
formatted_path.append(f"{feature_name} {op} 0.5")
|
||||
if formatted_path:
|
||||
formatted_counterfactuals.append(formatted_path)
|
||||
|
||||
return formatted_counterfactuals if formatted_counterfactuals else None
|
||||
|
||||
def _aggregate_factual_rules(
|
||||
self, individual_explanations: Dict[UserId, List[str]], total_group_size: int
|
||||
) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Aggregates individual factual rules into a group consensus by finding
|
||||
the rules supported by a majority of members.
|
||||
"""
|
||||
|
||||
# Flatten the list of all rules from all users into a single list
|
||||
all_rules_flat = [
|
||||
rule
|
||||
for rules_list in individual_explanations.values()
|
||||
for rule in rules_list
|
||||
]
|
||||
|
||||
if not all_rules_flat:
|
||||
return {"unanimous": [], "majority": [], "minority": []}
|
||||
|
||||
# Count the occurrences of each rule
|
||||
rule_counts = Counter(all_rules_flat)
|
||||
|
||||
majority_threshold = (total_group_size // 2) + 1 if total_group_size > 1 else 1
|
||||
minority_threshold = 1
|
||||
cleaned_rules_set = self._clean_contradictory_rules(set(rule_counts.keys()))
|
||||
categorized_rules = {"unanimous": [], "majority": [], "minority": []}
|
||||
|
||||
for rule in sorted(list(cleaned_rules_set)):
|
||||
count = rule_counts[rule]
|
||||
rule_with_support = f"{rule} ({count}/{total_group_size} members)"
|
||||
|
||||
if count == total_group_size:
|
||||
categorized_rules["unanimous"].append(rule_with_support)
|
||||
elif count >= majority_threshold:
|
||||
categorized_rules["majority"].append(rule_with_support)
|
||||
elif count >= minority_threshold:
|
||||
categorized_rules["minority"].append(rule_with_support)
|
||||
|
||||
return categorized_rules
|
||||
|
||||
def _clean_contradictory_rules(self, rules_set: Set[str]) -> Set[str]:
|
||||
"""Enhanced contradiction cleaning that handles genre rules"""
|
||||
conditions_by_attr = {}
|
||||
|
||||
for rule in rules_set:
|
||||
# Handle genre rules
|
||||
if "Has genre:" in rule or "Does NOT have genre:" in rule:
|
||||
genre_match = re.search(r"`([^`]+)`", rule)
|
||||
if genre_match:
|
||||
genre = genre_match.group(1)
|
||||
attr = f"genre_{genre}"
|
||||
op = "has" if "Has genre:" in rule else "not_has"
|
||||
conditions_by_attr.setdefault(attr, set()).add(op)
|
||||
else:
|
||||
# Handle regular rules
|
||||
match = re.match(r"(.+?)\s*([<>]=?)\s*(\d+\.?\d*)", rule)
|
||||
if match:
|
||||
attr, op, _ = match.groups()
|
||||
conditions_by_attr.setdefault(attr.strip(), set()).add(op)
|
||||
|
||||
# Find contradictory attributes
|
||||
invalid_attrs = set()
|
||||
for attr, ops in conditions_by_attr.items():
|
||||
if attr.startswith("genre_"):
|
||||
# Genre contradiction: has and not_has same genre
|
||||
if "has" in ops and "not_has" in ops:
|
||||
invalid_attrs.add(attr)
|
||||
else:
|
||||
# Numerical contradiction: <= and >
|
||||
if any(op in ops for op in ["<=", "<"]) and any(
|
||||
op in ops for op in [">", ">="]
|
||||
):
|
||||
invalid_attrs.add(attr)
|
||||
|
||||
# Remove contradictory rules
|
||||
clean_rules = set()
|
||||
for rule in rules_set:
|
||||
is_invalid = False
|
||||
for invalid_attr in invalid_attrs:
|
||||
if invalid_attr.startswith("genre_"):
|
||||
genre = invalid_attr.replace("genre_", "")
|
||||
if f"`{genre}`" in rule:
|
||||
is_invalid = True
|
||||
break
|
||||
else:
|
||||
if invalid_attr in rule:
|
||||
is_invalid = True
|
||||
break
|
||||
|
||||
if not is_invalid:
|
||||
clean_rules.add(rule)
|
||||
|
||||
return clean_rules
|
||||
|
||||
def find_explanation(
|
||||
self,
|
||||
recommended_items: List[ItemId],
|
||||
members: List[UserId],
|
||||
user_hist: Dict[UserId, Set[ItemId]],
|
||||
dataset: pd.DataFrame,
|
||||
model=None,
|
||||
data_reader=None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Enhanced explanation finding with tree storage for visualization"""
|
||||
if data_reader is None:
|
||||
raise ValueError(
|
||||
"A 'data_reader' object must be provided to find explanations."
|
||||
)
|
||||
|
||||
detailed_explanations = {}
|
||||
explainable_count = 0
|
||||
|
||||
if not recommended_items:
|
||||
return {"fidelity": 0.0, "details": {}}
|
||||
|
||||
for item_id in recommended_items:
|
||||
all_individual_rules = {}
|
||||
all_counterfactuals = {}
|
||||
stored_classifiers = {} # Store classifiers for visualization
|
||||
stored_metadata = {} # Store metadata for visualization
|
||||
representative_decision_path = None
|
||||
threshold_info_for_item = None
|
||||
|
||||
for user_id in members:
|
||||
user_id_consecutive = data_reader.get_new_user_id(user_id)
|
||||
clf, metadata = self._train_enhanced_decision_tree(
|
||||
user_id_consecutive,
|
||||
item_id,
|
||||
user_hist.get(user_id, set()),
|
||||
dataset,
|
||||
model,
|
||||
data_reader,
|
||||
)
|
||||
|
||||
if clf and metadata:
|
||||
if threshold_info_for_item is None and "threshold_info" in metadata:
|
||||
threshold_info_for_item = metadata["threshold_info"]
|
||||
|
||||
explanation = self._generate_enhanced_individual_explanation(
|
||||
clf, item_id, metadata
|
||||
)
|
||||
|
||||
if explanation:
|
||||
r, phi = explanation
|
||||
all_individual_rules[user_id] = r
|
||||
all_counterfactuals[user_id] = phi
|
||||
|
||||
if representative_decision_path is None:
|
||||
representative_decision_path = r
|
||||
# Store for visualization (use first successful classifier)
|
||||
if not stored_classifiers:
|
||||
stored_classifiers[user_id] = clf
|
||||
stored_metadata[user_id] = metadata
|
||||
|
||||
total_members_in_group = len(members)
|
||||
factual_set = self._aggregate_factual_rules(
|
||||
all_individual_rules, total_members_in_group
|
||||
)
|
||||
|
||||
if representative_decision_path and factual_set:
|
||||
explainable_count += 1
|
||||
|
||||
# Enhanced detailed explanations with visualization data
|
||||
item_explanation = {
|
||||
"decision_path": representative_decision_path,
|
||||
"group_factual_rule": factual_set,
|
||||
"individual_counterfactuals": all_counterfactuals,
|
||||
}
|
||||
|
||||
if threshold_info_for_item:
|
||||
item_explanation["threshold_info"] = threshold_info_for_item
|
||||
|
||||
# Add visualization data if available
|
||||
if stored_classifiers:
|
||||
user_id_for_viz = list(stored_classifiers.keys())[0]
|
||||
item_explanation.update(
|
||||
{
|
||||
"decision_tree": stored_classifiers[user_id_for_viz],
|
||||
"feature_names": stored_metadata[user_id_for_viz].get(
|
||||
"feature_labels", []
|
||||
),
|
||||
"tree_metadata": stored_metadata[user_id_for_viz],
|
||||
"item_genres": self.genre_profiles.get(str(item_id), set()),
|
||||
}
|
||||
)
|
||||
|
||||
detailed_explanations[item_id] = item_explanation
|
||||
|
||||
fidelity = (
|
||||
explainable_count / len(recommended_items) if recommended_items else 0.0
|
||||
)
|
||||
|
||||
group_explanations = {
|
||||
"fidelity": fidelity,
|
||||
"details": detailed_explanations,
|
||||
}
|
||||
|
||||
logging.info(
|
||||
f"Enhanced fidelity for {members}: {fidelity:.3f} ({explainable_count}/{len(recommended_items)})"
|
||||
)
|
||||
|
||||
return group_explanations
|
||||
@@ -0,0 +1,314 @@
|
||||
"""Rule-based group recommendation explainer module."""
|
||||
|
||||
from typing import Dict, List, Optional, Set, Union
|
||||
import logging
|
||||
|
||||
from pygrex.data_reader.data_reader import DataReader
|
||||
from pygrex.utils.association_rules import AssociationRules
|
||||
|
||||
# Type aliases for better readability
|
||||
ItemId = Union[str, int]
|
||||
MemberId = Union[str, int]
|
||||
UserHistory = Dict[MemberId, Set[ItemId]]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RuleBasedGroupRecExplainer:
|
||||
"""
|
||||
A class to explain group recommendations using rule-based methods.
|
||||
|
||||
This class provides methods to generate explanations for group recommendations
|
||||
based on association rules and user interaction history.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rules: AssociationRules,
|
||||
data: DataReader,
|
||||
pool_recommendations: Optional[Union[List[ItemId], ItemId]] = None,
|
||||
members: Optional[List[MemberId]] = None,
|
||||
user_history: Optional[UserHistory] = None,
|
||||
min_members_threshold: int = 1,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the RuleBasedGroupRecExplainer.
|
||||
|
||||
Args:
|
||||
rules: An instance of AssociationRules containing the rules for explanations.
|
||||
pool_recommendations: A list of item IDs to explain, or a single item ID.
|
||||
members: A list of member IDs in the group.
|
||||
user_history: A dictionary mapping member IDs to sets of item IDs
|
||||
they have interacted with.
|
||||
min_members_threshold: Minimum number of members that must satisfy
|
||||
the rule for it to be considered valid.
|
||||
|
||||
Raises:
|
||||
ValueError: If min_members_threshold is less than 1.
|
||||
"""
|
||||
if min_members_threshold < 1:
|
||||
raise ValueError("min_members_threshold must be at least 1")
|
||||
|
||||
self.rules = rules
|
||||
self.members = members or []
|
||||
self.min_members_threshold = min_members_threshold
|
||||
self.user_history = user_history or {}
|
||||
self.data = data
|
||||
|
||||
# Normalize pool_recommendations to always be a list
|
||||
self.pool_recommendations = self._normalize_recommendations(
|
||||
pool_recommendations
|
||||
)
|
||||
|
||||
def _normalize_recommendations(
|
||||
self, recommendations: Optional[Union[List[ItemId], ItemId]]
|
||||
) -> List[ItemId]:
|
||||
"""
|
||||
Normalize recommendations input to a list format.
|
||||
|
||||
Args:
|
||||
recommendations: Single item ID, list of item IDs, or None.
|
||||
|
||||
Returns:
|
||||
List of item IDs.
|
||||
"""
|
||||
if recommendations is None:
|
||||
return []
|
||||
|
||||
if isinstance(recommendations, (str, int)):
|
||||
return [recommendations]
|
||||
|
||||
return recommendations
|
||||
|
||||
def _is_rule_satisfied_by_member(
|
||||
self, member: MemberId, antecedent: Set[ItemId]
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a member satisfies the rule's antecedent.
|
||||
|
||||
Args:
|
||||
member: The member ID to check.
|
||||
antecedent: The set of items that form the rule's antecedent.
|
||||
|
||||
Returns:
|
||||
True if the member's history contains all items in the antecedent.
|
||||
"""
|
||||
|
||||
member_history = self.user_history.get(member, set())
|
||||
member_history_str = {str(item) for item in member_history}
|
||||
|
||||
x = member_history_str.issuperset(antecedent)
|
||||
return x
|
||||
|
||||
def _count_satisfied_members(self, antecedent: Set[ItemId]) -> int:
|
||||
"""
|
||||
Count how many members satisfy the given antecedent.
|
||||
|
||||
Args:
|
||||
antecedent: The set of items that form the rule's antecedent.
|
||||
|
||||
Returns:
|
||||
Number of members whose history satisfies the antecedent.
|
||||
"""
|
||||
return sum(
|
||||
1
|
||||
for member in self.members
|
||||
if self._is_rule_satisfied_by_member(member, antecedent)
|
||||
)
|
||||
|
||||
def _find_applicable_rules(self, item_id: ItemId):
|
||||
"""
|
||||
Find rules that have the given item in their consequents.
|
||||
|
||||
Args:
|
||||
item_id: The item ID to find rules for.
|
||||
|
||||
Returns:
|
||||
DataFrame containing applicable rules.
|
||||
"""
|
||||
item_id = self.data.get_new_item_id(item_id) # type: ignore
|
||||
|
||||
applicable_rules = self.rules[ # type: ignore
|
||||
self.rules["consequents"].apply(lambda x: str(item_id) in x) # type: ignore
|
||||
]
|
||||
|
||||
return applicable_rules
|
||||
|
||||
def find_explanation(self) -> float:
|
||||
"""
|
||||
Generate explanations for the group recommendations based on the rules.
|
||||
|
||||
Returns:
|
||||
The fidelity of the explanations, which is the ratio of explained
|
||||
recommendations to total recommendations in the pool.
|
||||
"""
|
||||
if not self.pool_recommendations:
|
||||
logger.warning("No recommendations to explain")
|
||||
return 0.0
|
||||
|
||||
explained_count = 0
|
||||
total_recommendations = len(self.pool_recommendations)
|
||||
|
||||
for item_id in self.pool_recommendations:
|
||||
if self._can_explain_item(item_id):
|
||||
explained_count += 1
|
||||
|
||||
fidelity = explained_count / total_recommendations
|
||||
logger.info(
|
||||
f"Explained {explained_count}/{total_recommendations} recommendations "
|
||||
f"(fidelity: {fidelity:.3f})"
|
||||
)
|
||||
|
||||
return fidelity
|
||||
|
||||
def _can_explain_item(self, item_id: ItemId) -> bool:
|
||||
"""
|
||||
Check if an item can be explained by any rule.
|
||||
|
||||
Args:
|
||||
item_id: The item ID to check.
|
||||
|
||||
Returns:
|
||||
True if at least one rule can explain the item.
|
||||
"""
|
||||
applicable_rules = self._find_applicable_rules(item_id)
|
||||
|
||||
for _, rule in applicable_rules.iterrows():
|
||||
antecedent = rule["antecedents"]
|
||||
satisfied_count = self._count_satisfied_members(antecedent)
|
||||
|
||||
if satisfied_count >= self.min_members_threshold:
|
||||
logger.debug(f"Rule fired for item {item_id}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_explanation_details(self) -> Dict[ItemId, List[Dict]]:
|
||||
"""
|
||||
Get detailed explanations for each recommendation.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping item IDs to lists of applicable rule details.
|
||||
"""
|
||||
explanations = {}
|
||||
|
||||
for item_id in self.pool_recommendations:
|
||||
item_explanations = []
|
||||
applicable_rules = self._find_applicable_rules(item_id)
|
||||
|
||||
for _, rule in applicable_rules.iterrows():
|
||||
antecedent = rule["antecedents"]
|
||||
satisfied_count = self._count_satisfied_members(antecedent)
|
||||
|
||||
if satisfied_count >= self.min_members_threshold:
|
||||
item_explanations.append(
|
||||
{
|
||||
"antecedent": antecedent,
|
||||
"consequent": rule["consequents"],
|
||||
"satisfied_members": satisfied_count,
|
||||
"confidence": rule.get("confidence", "N/A"),
|
||||
"support": rule.get("support", "N/A"),
|
||||
}
|
||||
)
|
||||
|
||||
explanations[item_id] = item_explanations
|
||||
|
||||
return explanations
|
||||
|
||||
def compute_group_fidelity_advanced(self) -> float:
|
||||
"""
|
||||
Compute group fidelity using advanced conditions.
|
||||
|
||||
This method implements a more sophisticated fidelity calculation where:
|
||||
- Condition 1: Each member of the group must have seen at least one item from the antecedent
|
||||
- Condition 2: Each item in the antecedent must have been seen by at least one member
|
||||
|
||||
Returns:
|
||||
The fidelity score as a float between 0 and 1.
|
||||
"""
|
||||
if not self.pool_recommendations:
|
||||
logger.warning("No recommendations to explain")
|
||||
return 0.0
|
||||
|
||||
if not self.members:
|
||||
logger.warning("No group members defined")
|
||||
return 0.0
|
||||
|
||||
explained_count = 0
|
||||
total_recommendations = len(self.pool_recommendations)
|
||||
|
||||
# Convert member IDs to set for faster lookup
|
||||
members_set = set(self.members)
|
||||
|
||||
# Get all items seen by any group member
|
||||
all_seen_items = set()
|
||||
for member in members_set:
|
||||
member_history = self.user_history.get(member, set())
|
||||
# Convert to strings for consistency with rules
|
||||
member_history_str = {str(item) for item in member_history}
|
||||
all_seen_items.update(member_history_str)
|
||||
|
||||
for item_id in self.pool_recommendations:
|
||||
if self._can_explain_item_advanced(item_id, members_set, all_seen_items):
|
||||
explained_count += 1
|
||||
|
||||
fidelity = explained_count / total_recommendations
|
||||
logger.info(
|
||||
f"Advanced explanation: {explained_count}/{total_recommendations} recommendations "
|
||||
f"(fidelity: {fidelity:.3f})"
|
||||
)
|
||||
|
||||
return fidelity
|
||||
|
||||
def _can_explain_item_advanced(
|
||||
self, item_id: ItemId, members_set: Set[MemberId], all_seen_items: Set[str]
|
||||
) -> bool:
|
||||
"""
|
||||
Check if an item can be explained using advanced conditions.
|
||||
|
||||
Args:
|
||||
item_id: The item ID to check.
|
||||
members_set: Set of group member IDs.
|
||||
all_seen_items: Set of all items seen by any group member.
|
||||
|
||||
Returns:
|
||||
True if the item can be explained by at least one rule satisfying both conditions.
|
||||
"""
|
||||
applicable_rules = self._find_applicable_rules(item_id)
|
||||
|
||||
for _, rule in applicable_rules.iterrows():
|
||||
antecedent = rule["antecedents"]
|
||||
|
||||
# Condition 1: Each member must have seen at least one item from the antecedent
|
||||
cond1 = all(
|
||||
self._member_has_antecedent_item(member, antecedent)
|
||||
for member in members_set
|
||||
)
|
||||
|
||||
# Condition 2: Each item in the antecedent must have been seen by at least one member
|
||||
cond2 = antecedent.issubset(all_seen_items)
|
||||
|
||||
if cond1 and cond2:
|
||||
logger.debug(f"Advanced rule fired for item {item_id}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _member_has_antecedent_item(
|
||||
self, member: MemberId, antecedent: Set[ItemId]
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a member has seen at least one item from the antecedent.
|
||||
|
||||
Args:
|
||||
member: The member ID to check.
|
||||
antecedent: The set of items in the rule's antecedent.
|
||||
|
||||
Returns:
|
||||
True if the member has seen at least one item from the antecedent.
|
||||
"""
|
||||
member_history = self.user_history.get(member, set())
|
||||
member_history_str = {str(item) for item in member_history}
|
||||
|
||||
# Check if there's any intersection between member history and antecedent
|
||||
return len(antecedent.intersection(member_history_str)) > 0
|
||||
@@ -0,0 +1,434 @@
|
||||
import itertools
|
||||
from typing import Dict, List, Sequence, Union
|
||||
|
||||
from pygrex.data_reader import DataReader, GroupInteractionHandler
|
||||
from pygrex.models import RecommenderModel
|
||||
from pygrex.recommender import GroupRecommender
|
||||
from pygrex.utils import SlidingWindowRanker, SlidingWindow, AggregationStrategy
|
||||
|
||||
|
||||
class SlidingWindowExplainer:
|
||||
"""
|
||||
Stratigi, M., Bikakis, N., Stefanidis, K.: Counterfactual explanations for group
|
||||
recommendations. In: Proceedings of the 27th International Workshop on Design,
|
||||
Optimization, Languages and Analytical Processing of Big Data (DOLAP 2025).
|
||||
|
||||
A class that uses a sliding window approach to find counterfactual explanations
|
||||
for group recommendation systems.
|
||||
|
||||
This class helps identify which items, if removed from the group's interaction history,
|
||||
would cause a specific target item to no longer appear in the group recommendations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
data: DataReader,
|
||||
group_handler: GroupInteractionHandler,
|
||||
members: List[Union[str, int]],
|
||||
target_item: Union[str, int],
|
||||
model: RecommenderModel,
|
||||
aggregation_strategy: AggregationStrategy = AggregationStrategy.AVG_PREDICTIONS,
|
||||
window_size=3,
|
||||
):
|
||||
"""
|
||||
Initialize the SlidingWindowExplainer.
|
||||
|
||||
Args:
|
||||
config: Configuration object with model parameters
|
||||
data: DataReader object containing the dataset
|
||||
group_handler: Object that handles group data modifications
|
||||
members: List of user IDs in the group
|
||||
target_item: The item ID for which explanation is sought
|
||||
model: Recommender model to use for predictions,
|
||||
aggregation_strategy: Strategy to aggregate individual recommendations,
|
||||
window_size: Size of the sliding window
|
||||
"""
|
||||
self.cfg = config
|
||||
self.data = data
|
||||
self.group_handler = group_handler
|
||||
self.members = members
|
||||
self.target_item = target_item
|
||||
self.model = model
|
||||
self.aggregation_strategy = aggregation_strategy
|
||||
self.window_size = window_size
|
||||
|
||||
# Results tracking
|
||||
self.explanations_found: Dict[int, Dict] = {}
|
||||
self.calls = 0
|
||||
self.max_calls = 1000
|
||||
self.item_metrics = {}
|
||||
|
||||
def set_sliding_window(self, sliding_window):
|
||||
"""Set the sliding window object if not provided during initialization."""
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
def set_item_metrics(self, metrics: Dict[Union[str, int], Dict[str, float]]):
|
||||
"""Store the pre-calculated metric scores for all items."""
|
||||
self.item_metrics = metrics
|
||||
|
||||
def find_explanation(
|
||||
self,
|
||||
items_rated_by_group: List[Union[str, int]],
|
||||
group_predictions: Dict,
|
||||
top_recommendation: Union[str, int],
|
||||
ranking_weights: Dict[str, float],
|
||||
) -> Dict[int, Dict]:
|
||||
"""
|
||||
Find counterfactual explanations using the full, encapsulated process.
|
||||
|
||||
Args:
|
||||
items_rated_by_group: All items rated by any member of the group.
|
||||
group_predictions: The original individual predictions from the recommender.
|
||||
top_recommendation: The original top recommended item.
|
||||
ranking_weights: The weights from the UI for each ranking component.
|
||||
|
||||
Returns:
|
||||
A dictionary of found explanations, including their justification metrics.
|
||||
"""
|
||||
|
||||
self.calls = 0
|
||||
ranker = SlidingWindowRanker(config={})
|
||||
ranker.set_group_recommender_values(group_predictions, top_recommendation)
|
||||
ranked_items, self.item_metrics = ranker.generate_ranked_items(
|
||||
all_rated_items=items_rated_by_group,
|
||||
data=self.data,
|
||||
group_members=self.members,
|
||||
component_weights=ranking_weights,
|
||||
)
|
||||
|
||||
sliding_window = SlidingWindow(
|
||||
sequence=ranked_items, window_size=self.window_size
|
||||
)
|
||||
|
||||
found = 0
|
||||
while True:
|
||||
# Get the sliding window
|
||||
big_window = sliding_window.get_next_window()
|
||||
|
||||
# Check exit conditions
|
||||
if big_window is None or found > 0 or self.calls >= self.max_calls:
|
||||
break
|
||||
|
||||
# Count calls and windows
|
||||
self.calls += 1
|
||||
|
||||
# Test if removing this window affects recommendations
|
||||
if self._test_window_removal(big_window, self.target_item):
|
||||
# A counterfactual explanation has been found
|
||||
found += 1
|
||||
# Look for minimal subsets within this window
|
||||
self._find_minimal_subset(big_window, self.target_item)
|
||||
|
||||
if found == 0:
|
||||
print("Explanation could not be found")
|
||||
|
||||
return self.explanations_found
|
||||
|
||||
def _test_window_removal(
|
||||
self, item_ids: List[Union[str, int]], original_group_rec: Union[str, int]
|
||||
) -> bool:
|
||||
"""
|
||||
Test if removing the given items affects the group recommendation.
|
||||
|
||||
Args:
|
||||
item_ids: List of item IDs to remove from group interactions
|
||||
original_group_rec: The original recommendation to compare against
|
||||
|
||||
Returns:
|
||||
bool: True if removing these items changes recommendations, False otherwise
|
||||
"""
|
||||
|
||||
# Get new recommendations after removing items
|
||||
group_recommendation = self._get_recommendations_after_removal(item_ids)
|
||||
|
||||
# Check if target item is still in recommendations
|
||||
|
||||
return original_group_rec not in group_recommendation
|
||||
|
||||
def _get_recommendations_after_removal(
|
||||
self, item_ids: List[Union[str, int]], top_n: int = 10
|
||||
) -> Sequence[Union[str, int]]:
|
||||
"""
|
||||
Get group recommendations after removing specified items from interaction history.
|
||||
|
||||
Args:
|
||||
item_ids: List of item IDs to remove from group interactions
|
||||
top_n: Number of top recommendations to return
|
||||
|
||||
Returns:
|
||||
List of recommended item IDs
|
||||
"""
|
||||
# Create modified dataset with items removed
|
||||
changed_data = self.group_handler.create_modified_dataset(
|
||||
original_data=self.data.dataset,
|
||||
group_ids=self.members,
|
||||
item_ids=item_ids,
|
||||
data=self.data,
|
||||
)
|
||||
|
||||
# Create new DataReader and retrain model
|
||||
data_retrained = self._create_data_reader_and_prepare(changed_data)
|
||||
model_retrained = self._retrain_model(data_retrained)
|
||||
|
||||
# Set up recommender with new model and data
|
||||
group_recommender = GroupRecommender(data_retrained)
|
||||
group_recommender.setup_recommendation(
|
||||
model_retrained,
|
||||
self.members,
|
||||
data_retrained,
|
||||
aggregation_strategy=self.aggregation_strategy,
|
||||
)
|
||||
recommendations = group_recommender.get_group_recommendations(top_n)
|
||||
|
||||
if not isinstance(recommendations, list):
|
||||
return []
|
||||
|
||||
return recommendations
|
||||
|
||||
def _create_data_reader_and_prepare(self, changed_data):
|
||||
"""
|
||||
Create and prepare a new DataReader with modified data.
|
||||
|
||||
Args:
|
||||
changed_data: DataFrame with modified dataset
|
||||
|
||||
Returns:
|
||||
DataReader: A new DataReader object with the modified dataset
|
||||
"""
|
||||
data_retrained = DataReader(
|
||||
filepath_or_buffer=None,
|
||||
sep=None,
|
||||
names=None,
|
||||
skiprows=0,
|
||||
dataframe=changed_data,
|
||||
)
|
||||
|
||||
# Fix for potential dataset issue in original code
|
||||
# data_retrained.dataset = data_retrained.dataset.iloc[1:].reset_index(drop=True)
|
||||
|
||||
# Prepare data
|
||||
data_retrained.make_consecutive_ids_in_dataset()
|
||||
data_retrained.binarize(binary_threshold=1)
|
||||
|
||||
return data_retrained
|
||||
|
||||
def _retrain_model(self, data):
|
||||
"""
|
||||
Retrain the recommendation model with modified data.
|
||||
|
||||
Args:
|
||||
data: Prepared DataReader object with modified dataset
|
||||
|
||||
Returns:
|
||||
Retrained model
|
||||
"""
|
||||
self.model.fit(data)
|
||||
return self.model
|
||||
|
||||
def _find_minimal_subset(
|
||||
self, big_window: List[Union[str, int]], original_group_rec: Union[str, int]
|
||||
) -> None:
|
||||
"""
|
||||
Find minimal subset of items that act as counterfactual explanation.
|
||||
|
||||
Args:
|
||||
big_window: List of item IDs to search within
|
||||
original_group_rec: The original recommendation to compare against
|
||||
|
||||
"""
|
||||
found_subset = 0
|
||||
|
||||
# Try combinations of different lengths
|
||||
for length in range(1, len(big_window) + 1):
|
||||
if found_subset > 0 or self.calls > self.max_calls:
|
||||
break
|
||||
|
||||
combinations = itertools.combinations(big_window, length)
|
||||
for item_combo in combinations:
|
||||
if found_subset > 0 or self.calls > self.max_calls:
|
||||
break
|
||||
|
||||
subset_items = list(item_combo)
|
||||
self.calls += 1
|
||||
|
||||
# Get recommendations after removing this subset
|
||||
new_recommendations = self._get_recommendations_after_removal(
|
||||
subset_items
|
||||
)
|
||||
|
||||
# Check if this is a counterfactual explanation
|
||||
if original_group_rec not in new_recommendations:
|
||||
found_subset += 1
|
||||
self._record_explanation(
|
||||
subset_items, original_group_rec, new_recommendations[0]
|
||||
)
|
||||
|
||||
def _record_explanation(
|
||||
self,
|
||||
explanation_items: List[Union[str, int]],
|
||||
original_rec: Union[str, int],
|
||||
new_rec: Union[str, int],
|
||||
) -> None:
|
||||
"""
|
||||
Record and display found explanation.
|
||||
|
||||
Args:
|
||||
explanation_items: Items that form the counterfactual explanation
|
||||
original_rec: Original recommendation
|
||||
new_rec: New top recommendation after removing explanation items
|
||||
"""
|
||||
print(
|
||||
f"If the group had not interacted with these items {explanation_items},\n"
|
||||
f"the item of interest {original_rec} would not have appeared on the recommendation list;\n"
|
||||
f"instead, {new_rec} would have been recommended."
|
||||
)
|
||||
# print("")
|
||||
# print(f"Explanation: {explanation_items} : found at call: {self.calls}")
|
||||
|
||||
# Calculate metrics for the explanation
|
||||
item_intensity = self._calculate_item_intensity(explanation_items)
|
||||
user_intensity = self._calculate_user_intensity(explanation_items)
|
||||
explanation_metrics = {
|
||||
item: self.item_metrics.get(item, {}) for item in explanation_items
|
||||
}
|
||||
|
||||
self.explanations_found[self.calls] = {
|
||||
"items": explanation_items,
|
||||
"new_rec": new_rec,
|
||||
"metrics": explanation_metrics,
|
||||
}
|
||||
|
||||
exp_size = len(explanation_items)
|
||||
|
||||
# print(f"{exp_size}\t{self.calls}\t{item_intensity}\t{user_intensity}")
|
||||
|
||||
def _calculate_item_intensity(self, items: List[Union[str, int]]) -> List[float]:
|
||||
"""
|
||||
Calculate average item intensity for explanation items.
|
||||
|
||||
Args:
|
||||
items: List of item IDs in the explanation
|
||||
|
||||
Returns:
|
||||
List of average intensity scores for each item
|
||||
"""
|
||||
|
||||
return self._calculate_average_item_intensity_score(
|
||||
items, self.members, self.data
|
||||
)
|
||||
|
||||
def _calculate_user_intensity(self, items: List[Union[str, int]]) -> List[float]:
|
||||
"""
|
||||
Calculate user intensity score for explanation items.
|
||||
|
||||
Args:
|
||||
items: List of item IDs in the explanation
|
||||
|
||||
Returns:
|
||||
List of intensity scores for each user
|
||||
"""
|
||||
return self._calculate_user_intensity_score(items, self.members, self.data)
|
||||
|
||||
@staticmethod
|
||||
def _calculate_average_item_intensity_score(
|
||||
explanation: List[Union[str, int]],
|
||||
members: List[Union[str, int]],
|
||||
data: DataReader,
|
||||
) -> List[float]:
|
||||
"""
|
||||
Calculate the average item intensity for a counterfactual explanation.
|
||||
|
||||
Average item intensity is defined as the average number of interactions
|
||||
between group members and each item in the explanation.
|
||||
|
||||
Args:
|
||||
explanation: The counterfactual explanation items.
|
||||
members: User IDs of the group members.
|
||||
data: DataReader object containing the dataset and ID mapping methods.
|
||||
|
||||
Returns:
|
||||
list: Average intensity for each item in the explanation.
|
||||
"""
|
||||
internal_group_ids = []
|
||||
# Convert user IDs to internal representation
|
||||
for user_id in members:
|
||||
new_user_id = data.get_new_user_id(user_id)
|
||||
if isinstance(new_user_id, list):
|
||||
if new_user_id: # Check that the list is not empty
|
||||
internal_group_ids.append(int(new_user_id[0]))
|
||||
else:
|
||||
internal_group_ids.append(int(new_user_id))
|
||||
|
||||
group_size = len(members)
|
||||
item_intensities = []
|
||||
|
||||
for item_id in explanation:
|
||||
# Convert item ID to internal representation
|
||||
internal_item_id = data.get_new_item_id(item_id)
|
||||
|
||||
# Count interactions between this item and group members
|
||||
interactions_count = len(
|
||||
data.dataset[
|
||||
(data.dataset.itemId == internal_item_id)
|
||||
& (data.dataset.userId.isin(internal_group_ids))
|
||||
]
|
||||
)
|
||||
|
||||
# Calculate average intensity
|
||||
average_intensity = interactions_count / group_size
|
||||
item_intensities.append(average_intensity)
|
||||
|
||||
return item_intensities
|
||||
|
||||
@staticmethod
|
||||
def _calculate_user_intensity_score(
|
||||
explanation_items: List[Union[str, int]],
|
||||
members: List[Union[str, int]],
|
||||
data: DataReader,
|
||||
) -> List[float]:
|
||||
"""
|
||||
Calculate the interaction intensity for each user based on their interactions with items in an explanation.
|
||||
|
||||
Interaction intensity represents how much a user has interacted with the items in the explanation,
|
||||
normalized by the total number of explanation items.
|
||||
|
||||
Args
|
||||
explanation_items : List of item IDs in the explanation
|
||||
members : List of user IDs to calculate intensity for
|
||||
data : DataReader object containing the dataset and ID mapping methods
|
||||
|
||||
Returns
|
||||
List of interaction intensities for each user (same order as members)
|
||||
Values range from 0 to 1, where:
|
||||
- 0 means no interaction with any explanation item
|
||||
- 1 means interaction with all explanation items
|
||||
|
||||
Notes
|
||||
Intensity is calculated as: (number of user interactions with explanation items) / (number of explanation items)
|
||||
"""
|
||||
# Convert external item IDs to internal IDs
|
||||
internal_item_ids = [
|
||||
data.get_new_item_id(item_id) for item_id in explanation_items
|
||||
]
|
||||
|
||||
user_intensities = []
|
||||
num_explanation_items = len(explanation_items)
|
||||
|
||||
for member in members:
|
||||
# Convert external user ID to internal ID
|
||||
internal_user_id = data.get_new_user_id(member)
|
||||
|
||||
# Count interactions between this user and explanation items
|
||||
user_interactions_count = len(
|
||||
data.dataset[
|
||||
(data.dataset.itemId.isin(internal_item_ids))
|
||||
& (data.dataset.userId == internal_user_id)
|
||||
]
|
||||
)
|
||||
|
||||
# Calculate intensity as proportion of explanation items the user interacted with
|
||||
intensity = user_interactions_count / num_explanation_items
|
||||
user_intensities.append(intensity)
|
||||
|
||||
return user_intensities
|
||||
@@ -0,0 +1,11 @@
|
||||
from .model_based_emf import EMFExplainer
|
||||
from .model_based_als_explain import ALSExplainer
|
||||
from .post_hoc_association_rules import ARPostHocExplainer
|
||||
from .post_hoc_knn import KNNPostHocExplainer
|
||||
|
||||
__all__ = [
|
||||
"EMFExplainer",
|
||||
"ALSExplainer",
|
||||
"ARPostHocExplainer",
|
||||
"KNNPostHocExplainer",
|
||||
]
|
||||
@@ -0,0 +1,49 @@
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
class Explainer(ABC):
|
||||
def __init__(self, model, recommendations, data):
|
||||
self.model = model
|
||||
self.recommendations = recommendations
|
||||
self.dataset = data.dataset
|
||||
self.num_items = data.num_item
|
||||
self.num_users = data.num_user
|
||||
self.users = self.dataset.groupby(by="userId")
|
||||
|
||||
def explain_recommendations(self):
|
||||
explanations = []
|
||||
|
||||
with tqdm(
|
||||
total=self.recommendations.shape[0], desc="Computing explanations: "
|
||||
) as pbar:
|
||||
for _, row in self.recommendations.iterrows():
|
||||
explanations.append(
|
||||
self.explain_recommendation_to_user(
|
||||
int(row.userId), int(row.itemId)
|
||||
)
|
||||
)
|
||||
pbar.update()
|
||||
|
||||
self.recommendations["explanations"] = explanations
|
||||
return self.recommendations
|
||||
|
||||
def get_user_items(self, user_id):
|
||||
"""
|
||||
Items Ids rated by a user.
|
||||
:param user_id: the user
|
||||
:return: list
|
||||
"""
|
||||
return self.users.get_group(user_id).itemId.values
|
||||
|
||||
@abstractmethod
|
||||
def explain_recommendation_to_user(
|
||||
self, user_id: int, item_id: int
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generates an explanation for a single user-item recommendation.
|
||||
This method must be implemented by any subclass.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,51 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from .explainer import Explainer
|
||||
|
||||
|
||||
class ALSExplainer(Explainer):
|
||||
def __init__(self, model, recommendations, data, number_of_contributions=10):
|
||||
super(ALSExplainer, self).__init__(model, recommendations, data)
|
||||
self.number_of_contributions = number_of_contributions
|
||||
|
||||
def explain_recommendation_to_user(self, user_id: int, item_id: int):
|
||||
"""
|
||||
Measuring the contribution of each item to the recommendation.
|
||||
:param model:
|
||||
:param item_id:
|
||||
:param user_id:
|
||||
:return: returns a dataframe with the contribution to the recommendation of each previously interacted with item.
|
||||
"""
|
||||
|
||||
current_interactions = np.zeros(self.num_items)
|
||||
current_interactions[self.get_user_items(user_id)] = 1
|
||||
|
||||
c_u = np.diag(current_interactions)
|
||||
|
||||
y_t = self.model.item_embedding().transpose()
|
||||
temp = np.matmul(y_t, c_u)
|
||||
temp = np.matmul(temp, self.model.item_embedding())
|
||||
temp = temp + np.diag([self.model.reg_term] * self.model.latent_dim)
|
||||
|
||||
if len(self.get_user_items(user_id)) > 1:
|
||||
weight_mtr = np.linalg.inv(temp)
|
||||
else:
|
||||
weight_mtr = np.linalg.pinv(temp)
|
||||
|
||||
temp = np.matmul(self.model.item_embedding(), weight_mtr)
|
||||
|
||||
sim_to_rec_id = temp.dot(self.model.item_embedding()[item_id, :])
|
||||
|
||||
sim_to_rec_id = sim_to_rec_id[self.get_user_items(user_id)]
|
||||
|
||||
contribution = {
|
||||
"item": self.get_user_items(user_id),
|
||||
"contribution": sim_to_rec_id,
|
||||
}
|
||||
contribution = pd.DataFrame(contribution)
|
||||
contribution = contribution.sort_values(by=["contribution"], ascending=False)
|
||||
return {
|
||||
"item": contribution.item[: self.number_of_contributions],
|
||||
"contribution": contribution.contribution[: self.number_of_contributions],
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
from .explainer import Explainer
|
||||
|
||||
|
||||
class EMFExplainer(Explainer):
|
||||
def __init__(self, model, recommendations, data):
|
||||
super(EMFExplainer, self).__init__(model, recommendations, data)
|
||||
|
||||
def explain_recommendation_to_user(self, user_id: int, item_id: int):
|
||||
"""
|
||||
Measuring the contribution of each item to the recommendation.
|
||||
:param user_id:
|
||||
:param item_id: recommendation
|
||||
:return: returns a dataframe with the contribution to the recommendation of each previously interacted with item.
|
||||
"""
|
||||
|
||||
ratings_on_item = self.dataset[self.dataset.itemId == item_id]
|
||||
similar_users = self.model.sim_users[user_id]
|
||||
similar_users_ratings_on_item = ratings_on_item[
|
||||
ratings_on_item.userId.isin(similar_users)
|
||||
]
|
||||
|
||||
explanation_df = similar_users_ratings_on_item.groupby(by="rating").count()
|
||||
explanation = {}
|
||||
|
||||
for index, row in explanation_df.iterrows():
|
||||
explanation[index] = row[0]
|
||||
|
||||
return explanation
|
||||
@@ -0,0 +1,79 @@
|
||||
from typing import Any, Dict
|
||||
from mlxtend.preprocessing import TransactionEncoder
|
||||
from mlxtend.frequent_patterns import apriori, association_rules
|
||||
import pandas as pd
|
||||
|
||||
from .explainer import Explainer
|
||||
|
||||
|
||||
class ARPostHocExplainer(Explainer):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
recommendations,
|
||||
data,
|
||||
min_support=0.1,
|
||||
max_len=2,
|
||||
metric="lift",
|
||||
min_threshold=0.1,
|
||||
min_confidence=0.1,
|
||||
min_lift=0.1,
|
||||
):
|
||||
super(ARPostHocExplainer, self).__init__(model, recommendations, data)
|
||||
self.AR = None
|
||||
self.min_support = min_support
|
||||
self.max_len = max_len
|
||||
self.metric = metric
|
||||
self.min_threshold = min_threshold
|
||||
self.min_confidence = min_confidence
|
||||
self.min_lift = min_lift
|
||||
|
||||
self.rules: pd.DataFrame | None = None
|
||||
|
||||
def get_rules_for_getting(self, item_id: int) -> pd.DataFrame:
|
||||
if self.rules is None:
|
||||
self.compute_association_rules()
|
||||
|
||||
if self.rules is not None:
|
||||
return self.rules[self.rules.consequents == item_id]
|
||||
|
||||
return pd.DataFrame()
|
||||
|
||||
def compute_association_rules(self):
|
||||
item_sets = [
|
||||
[item for item in self.dataset[self.dataset.userId == user].itemId]
|
||||
for user in self.dataset.userId.unique()
|
||||
]
|
||||
|
||||
te = TransactionEncoder()
|
||||
te_ary = te.fit(item_sets).transform(item_sets)
|
||||
|
||||
# The te_ary object is a NumPy array, which is a valid input for a DataFrame.
|
||||
# Pylance may raise a false positive here due to incomplete type stubs for mlxtend.
|
||||
df = pd.DataFrame(te_ary.astype(bool), columns=te.columns_) # type: ignore
|
||||
|
||||
frequent_itemsets = apriori(
|
||||
df, min_support=self.min_support, use_colnames=True, max_len=self.max_len
|
||||
)
|
||||
|
||||
rules = association_rules(
|
||||
frequent_itemsets, metric="lift", min_threshold=self.min_threshold
|
||||
)
|
||||
rules = rules[
|
||||
(rules["confidence"] > self.min_confidence)
|
||||
& (rules["lift"] > self.min_lift)
|
||||
]
|
||||
|
||||
rules["consequents"] = rules["consequents"].apply(lambda x: list(x)[0])
|
||||
rules["antecedents"] = rules["antecedents"].apply(lambda x: list(x)[0])
|
||||
|
||||
self.rules = rules[["consequents", "antecedents", "confidence"]]
|
||||
|
||||
def explain_recommendation_to_user(
|
||||
self, user_id: int, item_id: int
|
||||
) -> Dict[str, Any]:
|
||||
user_ratings = self.get_user_items(user_id)
|
||||
rules = self.get_rules_for_getting(item_id)
|
||||
explanations = rules[rules.antecedents.isin(user_ratings)]
|
||||
|
||||
return {"antecedents": set(explanations.antecedents)}
|
||||
@@ -0,0 +1,46 @@
|
||||
from scipy import sparse
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
import numpy as np
|
||||
from typing import Dict, Any
|
||||
|
||||
from .explainer import Explainer
|
||||
|
||||
|
||||
class KNNPostHocExplainer(Explainer):
|
||||
def __init__(self, model, recommendations, data, knn=10):
|
||||
super(KNNPostHocExplainer, self).__init__(model, recommendations, data)
|
||||
|
||||
self.knn = knn
|
||||
# Initialize as an empty dictionary to prevent subscripting None
|
||||
self.knn_items_dict: Dict[int, np.ndarray] = {}
|
||||
|
||||
def get_nn_for_getting(self, item_id: int) -> np.ndarray:
|
||||
# Check if the KNN dictionary has been computed
|
||||
if not self.knn_items_dict:
|
||||
self.compute_knn_items_for_all_items()
|
||||
|
||||
# Return the neighbors for the item, or an empty array if not found
|
||||
return self.knn_items_dict.get(item_id, np.array([]))
|
||||
|
||||
def compute_knn_items_for_all_items(self):
|
||||
ds = np.zeros((self.num_items, self.num_users))
|
||||
# Assuming self.dataset has attributes itemId, userId, and rating
|
||||
ds[self.dataset.itemId, self.dataset.userId] = self.dataset.rating
|
||||
|
||||
ds = sparse.csr_matrix(ds)
|
||||
sim_matrix = cosine_similarity(ds)
|
||||
min_val = sim_matrix.min() - 1
|
||||
|
||||
for i in range(self.num_items):
|
||||
sim_matrix[i, i] = min_val
|
||||
knn_to_item_i = (-sim_matrix[i, :]).argsort()[: self.knn]
|
||||
self.knn_items_dict[i] = knn_to_item_i
|
||||
|
||||
def explain_recommendation_to_user(
|
||||
self, user_id: int, item_id: int
|
||||
) -> Dict[str, Any]:
|
||||
user_ratings = self.get_user_items(user_id)
|
||||
sim_items = self.get_nn_for_getting(item_id)
|
||||
explanations = set(sim_items) & set(user_ratings)
|
||||
|
||||
return {"explanations": explanations}
|
||||
Reference in New Issue
Block a user