# File: py-grex/pygrex/explain/groups/lore4groups_explainer.py
# Snapshot: 2026-05-22 10:02:10 +02:00 (732 lines, 28 KiB, Python)

import pandas as pd
import numpy as np
import re
import logging
import traceback
from collections import Counter
from typing import Dict, Set, List, Optional, Any, Tuple, Union
from sklearn.tree import DecisionTreeClassifier, _tree
# Item and user identifiers are accepted either as strings or as integers;
# the explainer converts them with str()/int() where a specific form is needed.
ItemId = Union[str, int]
UserId = Union[str, int]
# A factual rule: the human-readable split conditions along one tree path.
FactualRule = List[str]
# A counterfactual set: one list of split conditions per alternative tree path.
CounterfactualSet = List[List[str]]
# (factual rule, counterfactual set) pair; the annotation allows either to be None.
Explanation = Tuple[Optional[FactualRule], Optional[CounterfactualSet]]
class LORE4GroupsExplainer:
    """
    Enhanced LORE4Groups explainer that incorporates genre information
    and stores decision trees for visualization.

    For every (group member, recommended item) pair a local decision tree is
    fitted on the items of the member's history that are most similar to the
    recommended item.  The path the recommended item takes through that tree
    is the member's factual rule; paths ending in leaves that contain
    "not liked" examples are the counterfactuals.  Factual rules are then
    aggregated over the group.

    Configuration is read from ``config["explainer"]["lore4groups"]``:

    * ``n_similar_for_tree`` -- how many top-ranked history items are used.
    * ``rating_threshold_for_like`` -- rating at or above which an item
      counts as liked.
    * ``similarity_threshold`` (optional, default ``0.0``) -- minimum
      similarity for a history item to be used.
    """

    # Extracts the back-ticked genre name from a formatted genre rule.
    _GENRE_NAME_RE = re.compile(r"`([^`]+)`")
    # Splits "<attribute> <op> <number>" rules into attribute, operator, number.
    _THRESHOLD_RULE_RE = re.compile(r"(.+?)\s*([<>]=?)\s*(\d+\.?\d*)")

    def __init__(
        self,
        item_profiles: Dict[str, Set[str]],
        item_label_matrix: pd.DataFrame,
        config: Dict,
        genre_profiles: Optional[Dict[str, Set[str]]] = None,
    ):
        """
        Args:
            item_profiles: Mapping item id -> set of tags. Keys are
                normalised to ``str``.
            item_label_matrix: One row per item, one column per label; it is
                looked up by ``str(item_id)`` when an item is explained.
            config: Configuration dictionary (see class docstring).
            genre_profiles: Optional mapping item id -> set of genre names.
                Keys are normalised to ``str``.
        """
        self.item_profiles = {str(k): v for k, v in item_profiles.items()}
        self.item_label_matrix = item_label_matrix
        self.params = config["explainer"]["lore4groups"]
        self.genre_profiles = (
            {str(k): v for k, v in genre_profiles.items()} if genre_profiles else {}
        )
        all_columns = item_label_matrix.columns.tolist()
        # Feature labels first; the 'like' target (if present) is kept last so
        # it stays reachable but is filtered out wherever features are built.
        self.all_labels = [col for col in all_columns if col != "like"]
        if "like" in all_columns:
            self.all_labels.append("like")

    # ------------------------------------------------------------------ #
    # Small shared helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _genre_feature(genre: str) -> str:
        """Return the feature-column name used for a genre."""
        return f"genre_{genre.lower()}"

    @staticmethod
    def _genre_rule(feature_name: str, went_left: bool) -> str:
        """Format a split on a genre feature; the left branch means 'absent'."""
        genre_name = feature_name.replace("genre_", "").title()
        if went_left:
            return f"Does NOT have genre: `{genre_name}`"
        return f"Has genre: `{genre_name}`"

    @staticmethod
    def _path_node_ids(clf: DecisionTreeClassifier, x_item: pd.DataFrame):
        """Return the node ids visited by the first row of ``x_item``."""
        node_indicator = clf.decision_path(x_item)
        start, stop = node_indicator.indptr[0], node_indicator.indptr[1]  # type: ignore
        return node_indicator.indices[start:stop]  # type: ignore

    def _rules_along_path(
        self,
        clf: DecisionTreeClassifier,
        x_item: pd.DataFrame,
        feature_labels: List[str],
        threshold_fmt: str,
    ) -> List[str]:
        """Describe every split on the item's path; ``threshold_fmt`` formats thresholds."""
        tree = clf.tree_
        node_ids = self._path_node_ids(clf, x_item)
        rules = []
        # The last node on the path is the leaf, so it is never a parent here.
        for node_id, child_id in zip(node_ids[:-1], node_ids[1:]):
            feature_idx = tree.feature[node_id]  # type: ignore
            if feature_idx == _tree.TREE_UNDEFINED:  # type: ignore
                continue
            feature_name = feature_labels[feature_idx]
            went_left = child_id == tree.children_left[node_id]  # type: ignore
            if feature_name.startswith("genre_"):
                rules.append(self._genre_rule(feature_name, went_left))
            else:
                op = "<=" if went_left else ">"
                threshold = threshold_fmt.format(tree.threshold[node_id])  # type: ignore
                rules.append(f"{feature_name} {op} {threshold}")
        return rules

    def _item_features(self, item_id: ItemId) -> Set[str]:
        """Return the item's tags together with its genre feature names."""
        key = str(item_id)
        tags = self.item_profiles.get(key, set())
        genres = self.genre_profiles.get(key, set())
        return tags.union(self._genre_feature(genre) for genre in genres)

    # ------------------------------------------------------------------ #
    # Similarity
    # ------------------------------------------------------------------ #
    def _enhanced_jaccard_similarity(self, item1_id: ItemId, item2_id: ItemId) -> float:
        """Jaccard similarity over tags plus genre features; 0.0 if either item has none."""
        features1 = self._item_features(item1_id)
        features2 = self._item_features(item2_id)
        if not features1 or not features2:
            return 0.0
        return len(features1 & features2) / len(features1 | features2)

    def _jaccard_similarity(self, item1_id: ItemId, item2_id: ItemId) -> float:
        """Original tag-only Jaccard similarity (kept for compatibility)."""
        tags1 = self.item_profiles.get(str(item1_id), set())
        tags2 = self.item_profiles.get(str(item2_id), set())
        if not tags1 or not tags2:
            return 0.0
        return len(tags1.intersection(tags2)) / len(tags1.union(tags2))

    # ------------------------------------------------------------------ #
    # Local training data
    # ------------------------------------------------------------------ #
    def _get_enhanced_similar_examples(
        self,
        user_id_consecutive: UserId,
        target_item_id: ItemId,
        user_hist: Set[ItemId],
        dataset: pd.DataFrame,
        model=None,
        data_reader=None,
    ) -> Tuple[pd.DataFrame, Dict[str, Any]]:
        """
        Build the local training set for one user and one target item.

        Args:
            user_id_consecutive: User id as stored in ``dataset["userId"]``.
            target_item_id: Item to be explained.
            user_hist: Ids of the items in the user's history.
            dataset: Frame with ``userId``, ``itemId`` and ``rating`` columns.
            model: Optional predictor; ``model.predict(user, item)`` is used to
                fill in ratings for similar items the user has no row for.
            data_reader: Optional id translator providing
                ``get_original_user_id`` and ``get_original_item_id``.

        Returns:
            ``(examples, metadata)``.  ``examples`` holds ``movie_id``,
            ``rating``, ``like`` and one 0/1 column per feature; ``metadata``
            holds the feature labels, genre features, the top-5 similarity
            scores, the target item's genres and the threshold information.
            ``(empty DataFrame, {})`` is returned whenever no usable
            two-class sample can be built.
        """
        # 1. Rank the history by similarity to the target item.
        ranked = sorted(
            (
                (seen_id, self._enhanced_jaccard_similarity(target_item_id, seen_id))
                for seen_id in user_hist
            ),
            key=lambda pair: pair[1],
            reverse=True,
        )
        sim_th = self.params.get("similarity_threshold", 0.0)
        top_n = self.params["n_similar_for_tree"]
        # dict.fromkeys drops duplicate ids while keeping the ranking order, so
        # the rows added below do not depend on set iteration order.
        top_similar_items_int = list(
            dict.fromkeys(
                int(seen_id) for seen_id, score in ranked[:top_n] if score >= sim_th
            )
        )
        if not top_similar_items_int:
            return pd.DataFrame(), {}

        # 2. Collect the user's existing ratings for those items.
        local_df = dataset[
            (dataset["userId"] == user_id_consecutive)
            & (dataset["itemId"].isin(top_similar_items_int))
        ].copy()
        rated_items = set(local_df["itemId"])
        items_to_predict = [
            item for item in top_similar_items_int if item not in rated_items
        ]

        # Fill the gaps with model predictions.  This is best effort and
        # all-or-nothing: if any prediction fails, none is added.
        if model is not None and data_reader is not None and items_to_predict:
            try:
                orig_user_id = data_reader.get_original_user_id(
                    int(user_id_consecutive)
                )
                predicted_ratings = [
                    {
                        "userId": user_id_consecutive,
                        "itemId": item_id_consecutive,
                        "rating": float(
                            model.predict(
                                orig_user_id,
                                data_reader.get_original_item_id(
                                    int(item_id_consecutive)
                                ),
                            )
                        ),
                    }
                    for item_id_consecutive in items_to_predict
                ]
                local_df = pd.concat(
                    [local_df, pd.DataFrame(predicted_ratings)], ignore_index=True
                )
            except Exception:
                logging.exception(
                    "Could not predict ratings for unrated similar items; "
                    "continuing with observed ratings only"
                )

        if len(local_df) < 2:
            return pd.DataFrame(), {}

        # 3. Binarise ratings, falling back to the mean rating when the
        #    configured threshold yields a single class.
        rating_threshold = self.params["rating_threshold_for_like"]
        threshold_info = {
            "was_overridden": False,
            "original_threshold": rating_threshold,
            "final_threshold": rating_threshold,
        }
        local_df["like"] = (local_df["rating"] >= rating_threshold).astype(int)
        like_counts = local_df["like"].value_counts()
        if len(like_counts) < 2:
            mean_rating = local_df["rating"].mean()
            local_df["like"] = (local_df["rating"] >= mean_rating).astype(int)
            threshold_info["was_overridden"] = True
            threshold_info["final_threshold"] = mean_rating
            like_counts = local_df["like"].value_counts()
            if len(like_counts) < 2:
                return pd.DataFrame(), {}
        # Severe imbalance (<10% minority) is only tolerated when the minority
        # class still has at least two examples.
        if like_counts.min() / len(local_df) < 0.1 and like_counts.min() < 2:
            return pd.DataFrame(), {}

        # 4. Build the feature matrix.  First pass: discover genre features.
        feature_labels = [label for label in self.all_labels if label != "like"]
        known_labels = set(feature_labels)
        genre_features_used: List[str] = []
        seen_genre_features: Set[str] = set()
        rows = []
        for _, row in local_df.iterrows():
            item_id = str(int(row["itemId"]))
            # Sorted so that the column order (and therefore the fitted tree)
            # does not depend on hash-randomised set ordering.
            item_genre_features = [
                self._genre_feature(genre)
                for genre in sorted(self.genre_profiles.get(item_id, set()))
            ]
            for genre_feature in item_genre_features:
                if genre_feature not in seen_genre_features:
                    seen_genre_features.add(genre_feature)
                    genre_features_used.append(genre_feature)
                if genre_feature not in known_labels:
                    known_labels.add(genre_feature)
                    feature_labels.append(genre_feature)
            rows.append(
                (item_id, row["rating"], int(row["like"]), set(item_genre_features))
            )

        # Second pass: every example gets every feature, encoded the same way.
        examples = []
        for item_id, rating, like, item_genre_features in rows:
            tags = self.item_profiles.get(item_id, set())
            example = {"movie_id": item_id, "rating": rating, "like": like}
            for label in feature_labels:
                present = label in tags or label in item_genre_features
                example[label] = 1 if present else 0
            examples.append(example)

        final_df = pd.DataFrame(examples)
        if final_df["like"].nunique() < 2:
            return pd.DataFrame(), {}

        metadata = {
            "feature_labels": [label for label in feature_labels if label != "like"],
            "genre_features": genre_features_used,
            "similarity_scores": dict(ranked[:5]),  # Top 5 similarities
            "target_item_genres": self.genre_profiles.get(str(target_item_id), set()),
            "rating_threshold": threshold_info["final_threshold"],
            "threshold_info": threshold_info,
        }
        return final_df, metadata

    # ------------------------------------------------------------------ #
    # Tree training and path extraction
    # ------------------------------------------------------------------ #
    def _get_factual_path_for_item(
        self,
        clf: DecisionTreeClassifier,
        x_item: pd.DataFrame,
        metadata: Dict[str, Any],
    ) -> Optional[List[str]]:
        """
        Trace the path an item takes through the decision tree and return the
        corresponding factual rules (thresholds shown with two decimals).

        Returns ``None`` when ``metadata`` has no feature labels or the path
        contains no split (the tree is a single leaf).
        """
        feature_labels = metadata.get("feature_labels", [])
        if not feature_labels:
            return None
        rules = self._rules_along_path(clf, x_item, feature_labels, "{:.2f}")
        return rules if rules else None

    def _train_enhanced_decision_tree(
        self,
        user_id_consecutive: UserId,
        item_id: ItemId,
        user_hist: Set[ItemId],
        dataset: pd.DataFrame,
        model=None,
        data_reader=None,
    ) -> Tuple[Optional[DecisionTreeClassifier], Dict[str, Any]]:
        """
        Fit the local decision tree for one user and one item.

        Returns:
            ``(classifier, metadata)`` on success, where ``metadata`` is the
            sample metadata extended with the classifier, feature importances,
            training size and class distribution.  ``(None, {})`` when no
            sample is available, a class has fewer than two examples, all
            features are constant, or fitting fails.
        """
        df_examples, metadata = self._get_enhanced_similar_examples(
            user_id_consecutive, item_id, user_hist, dataset, model, data_reader
        )
        if df_examples.empty:
            return None, {}
        like_counts = df_examples["like"].value_counts()
        if len(like_counts) < 2 or like_counts.min() < 2:
            return None, {}

        feature_labels = metadata.get("feature_labels", [])
        X = df_examples[feature_labels]
        y = df_examples["like"]
        # A matrix without any variance cannot produce a split.
        if (X.var() == 0).all():
            return None, {}

        clf = DecisionTreeClassifier(
            max_depth=5,  # Slightly deeper to accommodate genre features
            min_samples_split=max(4, len(df_examples) // 4),
            min_samples_leaf=2,
            random_state=42,
            class_weight="balanced",
        )
        try:
            clf.fit(X, y)
        except Exception:
            # Best effort: an unfittable tree means "no explanation for this
            # user", but the reason is logged instead of being discarded.
            logging.exception(
                "Decision tree fitting failed for user %s, item %s",
                user_id_consecutive,
                item_id,
            )
            return None, {}

        feature_importance = list(zip(feature_labels, clf.feature_importances_))
        important_features = [
            (name, imp) for name, imp in feature_importance if imp > 0.001
        ]
        metadata.update(
            {
                "classifier": clf,
                "feature_importance": dict(feature_importance),
                "important_features": important_features,
                "genre_important_features": [
                    (name, imp)
                    for name, imp in important_features
                    if name.startswith("genre_")
                ],
                "training_data_size": len(df_examples),
                "class_distribution": like_counts.to_dict(),
            }
        )
        return clf, metadata

    def _get_enhanced_explanation_path(
        self,
        clf: DecisionTreeClassifier,
        x_item: pd.DataFrame,
        metadata: Dict[str, Any],
    ) -> Optional[List[str]]:
        """
        Return the item's path rules, but only if its leaf contains "liked"
        examples; thresholds are printed unformatted.

        Returns ``None`` when class ``1`` is unknown to the tree or has no
        weight in the item's leaf.
        """
        if 1 not in clf.classes_:
            return None
        leaf_id = clf.apply(x_item)[0]  # type: ignore
        class_index = np.where(clf.classes_ == 1)[0]
        if not class_index.size or clf.tree_.value[leaf_id][0][class_index[0]] == 0:  # type: ignore
            return None
        feature_labels = metadata.get("feature_labels", [])
        return self._rules_along_path(clf, x_item, feature_labels, "{}")

    def _generate_enhanced_individual_explanation(
        self, clf: DecisionTreeClassifier, item_id: ItemId, metadata: Dict[str, Any]
    ) -> Optional[Explanation]:
        """
        Produce ``(factual rule, counterfactual set)`` for one item and tree.

        Returns ``None`` when the item is missing from the label matrix, a
        training feature cannot be found for it, or either part of the
        explanation is empty.
        """
        item_key = str(item_id)
        if item_key not in self.item_label_matrix.index:
            return None
        feature_labels = metadata.get("feature_labels", [])
        item_genre_features = {
            self._genre_feature(genre)
            for genre in self.genre_profiles.get(item_key, set())
        }

        # Genre columns are not part of the label matrix, so they are added
        # here: 1 for the item's own genres, 0 for the other training genres.
        enhanced_item_data = self.item_label_matrix.loc[[item_key]].copy()
        for feature in feature_labels:
            if not feature.startswith("genre_"):
                continue
            if feature in item_genre_features:
                enhanced_item_data[feature] = 1
            elif feature not in enhanced_item_data.columns:
                enhanced_item_data[feature] = 0
        try:
            # Same columns, same order as during training.
            x_item = enhanced_item_data[feature_labels]
        except KeyError:
            return None

        factual_rule = self._get_factual_path_for_item(clf, x_item, metadata)
        if not factual_rule:
            return None
        counterfactual_set = self._get_counterfactual_paths(clf, x_item)
        if not counterfactual_set:
            return None
        return (factual_rule, counterfactual_set)

    def _get_counterfactual_paths(
        self, clf: DecisionTreeClassifier, x_item: pd.DataFrame
    ) -> Optional[CounterfactualSet]:
        """
        Original counterfactual path method (kept for compatibility).

        Collects every root-to-leaf path whose leaf has non-zero weight for
        class ``0`` and keeps those that the item violates in the fewest
        conditions.  Returns ``None`` when there is no such path.
        """
        tree = clf.tree_
        class_index = np.where(clf.classes_ == 0)[0]
        n_columns = len(x_item.columns)
        paths = []

        def find_paths(node_id, current_path):
            feature_idx = tree.feature[node_id]  # type: ignore
            if feature_idx == _tree.TREE_UNDEFINED:  # type: ignore
                if class_index.size and tree.value[node_id][0][class_index[0]] > 0:
                    paths.append(list(current_path))
                return
            threshold = tree.threshold[node_id]  # type: ignore
            branches = (
                ("<=", tree.children_left[node_id]),  # type: ignore
                (">", tree.children_right[node_id]),  # type: ignore
            )
            for op, child_id in branches:
                current_path.append((feature_idx, op, threshold))
                find_paths(child_id, current_path)
                current_path.pop()

        find_paths(0, [])
        if not paths:
            return None

        def violated_conditions(path):
            count = 0
            for feature_idx, op, threshold in path:
                if feature_idx >= n_columns:
                    continue
                item_val = x_item.iloc[0, feature_idx]
                if op == "<=":
                    satisfied = item_val <= threshold
                else:
                    satisfied = item_val > threshold
                if not satisfied:
                    count += 1
            return count

        # NOTE(review): a leaf counts as soon as it holds any class-0 weight,
        # so the item's own path (zero violated conditions) can be selected
        # here -- confirm whether that is intended.
        scored = [(violated_conditions(path), path) for path in paths]
        fewest = min(count for count, _ in scored)
        counterfactuals = [path for count, path in scored if count == fewest]

        formatted_counterfactuals = []
        for cf_path in counterfactuals:
            formatted_path = []
            for idx, op, threshold in cf_path:
                if idx >= n_columns:
                    continue
                feature_name = x_item.columns[idx]
                if feature_name.startswith("genre_"):
                    formatted_path.append(self._genre_rule(feature_name, op == "<="))
                else:
                    # ':g' prints the real split value; 0.5 still reads "0.5".
                    formatted_path.append(f"{feature_name} {op} {threshold:g}")
            if formatted_path:
                formatted_counterfactuals.append(formatted_path)
        return formatted_counterfactuals if formatted_counterfactuals else None

    # ------------------------------------------------------------------ #
    # Group aggregation
    # ------------------------------------------------------------------ #
    def _aggregate_factual_rules(
        self, individual_explanations: Dict[UserId, List[str]], total_group_size: int
    ) -> Dict[str, List[str]]:
        """
        Aggregate individual factual rules into a group consensus.

        Each surviving rule is suffixed with ``(count/total members)`` and put
        in ``"unanimous"`` (all members), ``"majority"`` (more than half) or
        ``"minority"`` (at least one).  Contradictory rules are removed first.
        """
        # set(...) makes every member count at most once per rule, so support
        # can never exceed the group size.
        rule_counts = Counter(
            rule
            for rules_list in individual_explanations.values()
            for rule in set(rules_list)
        )
        categorized_rules = {"unanimous": [], "majority": [], "minority": []}
        if not rule_counts:
            return categorized_rules

        majority_threshold = (total_group_size // 2) + 1 if total_group_size > 1 else 1
        minority_threshold = 1
        for rule in sorted(self._clean_contradictory_rules(set(rule_counts))):
            count = rule_counts[rule]
            rule_with_support = f"{rule} ({count}/{total_group_size} members)"
            if count == total_group_size:
                categorized_rules["unanimous"].append(rule_with_support)
            elif count >= majority_threshold:
                categorized_rules["majority"].append(rule_with_support)
            elif count >= minority_threshold:
                categorized_rules["minority"].append(rule_with_support)
        return categorized_rules

    @classmethod
    def _parse_rule(cls, rule: str) -> Optional[Tuple[Tuple[str, str], str]]:
        """
        Split a rule into ``((kind, attribute), operator)``.

        Genre rules map to ``">"`` (has) / ``"<="`` (does not have).  Returns
        ``None`` for text that matches neither rule format.
        """
        if "Has genre:" in rule or "Does NOT have genre:" in rule:
            genre_match = cls._GENRE_NAME_RE.search(rule)
            if genre_match is None:
                return None
            op = ">" if "Has genre:" in rule else "<="
            return ("genre", genre_match.group(1)), op
        match = cls._THRESHOLD_RULE_RE.match(rule)
        if match is None:
            return None
        attr, op, _ = match.groups()
        return ("tag", attr.strip()), op

    def _clean_contradictory_rules(self, rules_set: Set[str]) -> Set[str]:
        """
        Drop every rule whose attribute appears with both an upper-bound
        (``<=``/``<``/"does not have") and a lower-bound (``>``/``>=``/"has")
        condition.  Rules that cannot be parsed are kept untouched.
        """
        parsed = {rule: self._parse_rule(rule) for rule in rules_set}

        ops_by_attr: Dict[Tuple[str, str], Set[str]] = {}
        for parsed_rule in parsed.values():
            if parsed_rule is not None:
                attr, op = parsed_rule
                ops_by_attr.setdefault(attr, set()).add(op)

        contradictory = {
            attr
            for attr, ops in ops_by_attr.items()
            if ops & {"<=", "<"} and ops & {">", ">="}
        }
        # Rules are matched on their exact attribute, not by substring, so a
        # conflict on "war" leaves rules about "warm" alone.
        return {
            rule
            for rule, parsed_rule in parsed.items()
            if parsed_rule is None or parsed_rule[0] not in contradictory
        }

    # ------------------------------------------------------------------ #
    # Public entry point
    # ------------------------------------------------------------------ #
    def find_explanation(
        self,
        recommended_items: List[ItemId],
        members: List[UserId],
        user_hist: Dict[UserId, Set[ItemId]],
        dataset: pd.DataFrame,
        model=None,
        data_reader=None,
    ) -> Dict[str, Any]:
        """
        Explain every recommended item for the group.

        Args:
            recommended_items: Items to explain.
            members: Group member ids (as understood by
                ``data_reader.get_new_user_id``).
            user_hist: Mapping member id -> ids of items in their history.
            dataset: Ratings frame passed on to the tree training.
            model: Optional rating predictor passed on to the tree training.
            data_reader: Required id translator.

        Returns:
            ``{"fidelity": float, "details": {item_id: explanation}}`` where
            fidelity is the share of recommended items that received an
            explanation.  The decision path, tree, tree metadata and threshold
            information of an item all come from the first member for whom a
            complete explanation was produced.

        Raises:
            ValueError: If ``data_reader`` is ``None``.
        """
        if data_reader is None:
            raise ValueError(
                "A 'data_reader' object must be provided to find explanations."
            )
        if not recommended_items:
            return {"fidelity": 0.0, "details": {}}

        detailed_explanations = {}
        explainable_count = 0
        for item_id in recommended_items:
            all_individual_rules = {}
            all_counterfactuals = {}
            representative = None  # (decision path, classifier, metadata)
            for user_id in members:
                clf, metadata = self._train_enhanced_decision_tree(
                    data_reader.get_new_user_id(user_id),
                    item_id,
                    user_hist.get(user_id, set()),
                    dataset,
                    model,
                    data_reader,
                )
                if clf is None or not metadata:
                    continue
                explanation = self._generate_enhanced_individual_explanation(
                    clf, item_id, metadata
                )
                if not explanation:
                    continue
                factual_rule, counterfactual_set = explanation
                all_individual_rules[user_id] = factual_rule
                all_counterfactuals[user_id] = counterfactual_set
                if representative is None:
                    # Keep path and tree together so the visualised tree is
                    # the one that actually produced the displayed path.
                    representative = (factual_rule, clf, metadata)

            if representative is None:
                continue
            explainable_count += 1
            decision_path, decision_tree, tree_metadata = representative

            item_explanation = {
                "decision_path": decision_path,
                "group_factual_rule": self._aggregate_factual_rules(
                    all_individual_rules, len(members)
                ),
                "individual_counterfactuals": all_counterfactuals,
            }
            threshold_info = tree_metadata.get("threshold_info")
            if threshold_info:
                item_explanation["threshold_info"] = threshold_info
            item_explanation.update(
                {
                    "decision_tree": decision_tree,
                    "feature_names": tree_metadata.get("feature_labels", []),
                    "tree_metadata": tree_metadata,
                    "item_genres": self.genre_profiles.get(str(item_id), set()),
                }
            )
            detailed_explanations[item_id] = item_explanation

        fidelity = explainable_count / len(recommended_items)
        logging.info(
            "Enhanced fidelity for %s: %.3f (%d/%d)",
            members,
            fidelity,
            explainable_count,
            len(recommended_items),
        )
        return {"fidelity": fidelity, "details": detailed_explanations}