Files
py-grex/pages/4_Explanation_Evaluation.py
T
2026-05-22 10:02:10 +02:00

2099 lines
81 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import re
import pickle
from io import StringIO
import streamlit as st
import pandas as pd
import numpy as np
from collections import Counter
from typing import Dict, Iterable, Optional, Set, Union
from contextlib import redirect_stdout, redirect_stderr
import plotly.graph_objects as go
# Library Imports
from pygrex.config import cfg
from pygrex.evaluator import ExplanationEvaluator
from pygrex.explain import (
LORE4GroupsExplainer,
SlidingWindowExplainer,
RuleBasedGroupRecExplainer,
)
# NOTE: the `cfg` import above is required to load the config for the explainer.
# Page and State Configuration
st.set_page_config(
    page_title="Explanation & Evaluation",
    page_icon="💬",
    layout="wide",
)
st.title("💬 Explanation & Evaluation")
# Session State Checks
# Everything read from st.session_state below is produced by the earlier
# pages. Show a friendly message and stop (instead of raising on a missing
# attribute) when any of it is absent or the recommendation list is empty.
# Emptiness is tested with `len` so it also works for array-like lists, whose
# truth value may be ambiguous.
_REQUIRED_STATE_KEYS = (
    "data_reader",
    "group_handler",
    "trained_model",
    "group_recommender",
    "recommended_items",
)
_missing_state_keys = [
    key for key in _REQUIRED_STATE_KEYS if key not in st.session_state
]
_stored_recommendations = st.session_state.get("recommended_items")
if (
    _missing_state_keys
    or _stored_recommendations is None
    or len(_stored_recommendations) == 0
):
    st.warning(
        "Please generate recommendations on the 'Group Recommendation' page first."
    )
    st.stop()
st.markdown("")
# Retrieve objects from session state
data_reader = st.session_state.data_reader
group_handler = st.session_state.group_handler
model = st.session_state.trained_model
group_recommender = st.session_state.group_recommender
recommended_items = st.session_state.recommended_items
group_members = group_recommender.get_group_members()
aggregation_strategy = group_recommender.get_aggregation_strategy()
# Caching for Expensive Functions
@st.cache_data
def load_cached_data_rules(min_support, min_confidence, rating_threshold):
    """
    Load pre-computed association rules from the ``cached_rules`` folder.

    The file name is
    ``rules_sup{min_support:.2f}_conf{min_confidence:.1f}_rating{rating_threshold:.0f}``.
    The extensions ``.pkl``, ``.pickle`` and ``.json`` are tried in that
    order. The first existing file that loads successfully is returned. A file
    that exists but fails to load is reported with ``st.error``, and the next
    extension is then tried.

    Returns:
        The deserialized file content, or ``None`` if no candidate file could
        be loaded. Callers on this page expect a dict with a ``"rules"`` entry.
    """
    import json  # only needed for the .json fallback

    # NOTE(review): `:.0f` rounds half to even, so thresholds 1.5 and 2.5 both
    # map to "rating2" (and 0.5 to "rating0"). The pattern has to match the
    # names of the existing cache files, so it is intentionally kept as is.
    filename = f"rules_sup{min_support:.2f}_conf{min_confidence:.1f}_rating{rating_threshold:.0f}"
    cached_rules_dir = "cached_rules"
    for ext in (".pkl", ".pickle", ".json"):
        filepath = os.path.join(cached_rules_dir, filename + ext)
        if not os.path.exists(filepath):
            continue
        try:
            if ext == ".json":
                # JSON text is UTF-8 by specification; do not depend on the
                # platform's locale encoding.
                with open(filepath, "r", encoding="utf-8") as f:
                    return json.load(f)
            # SECURITY: unpickling can execute arbitrary code. Only trusted,
            # locally generated files belong in `cached_rules/`.
            with open(filepath, "rb") as f:
                return pickle.load(f)
        except Exception as e:
            st.error(f"Error loading cached rules from {filepath}: {e}")
    return None
@st.cache_data
def _user_history_for_dataset(dataset, rating_threshold, _reader):
    """Cached worker for :func:`get_user_history`.

    ``dataset`` is an argument, so it is hashed into the cache key: a
    different dataset gets its own cache entry. ``_reader`` is used only for
    ID translation. Its leading underscore tells Streamlit not to hash it.
    """
    df_filtered = dataset[dataset["rating"] >= rating_threshold]
    # Group by the 'userId' column (which contains the new, consecutive IDs)
    history_by_new_id = df_filtered.groupby("userId")["itemId"].apply(set).to_dict()
    # Create the final dictionary mapping original user IDs to sets of new item IDs
    history_by_original_id = {}
    for new_id, item_set in history_by_new_id.items():
        try:
            original_id = _reader.get_original_user_id(int(new_id))
        except (ValueError, KeyError):
            continue  # users without an original-ID mapping are skipped
        # The explainer needs the item IDs to be strings to match the rules
        history_by_original_id[original_id] = {str(item) for item in item_set}
    return history_by_original_id


def get_user_history(rating_threshold):
    """
    Generates the user interaction history based only on the rating threshold.

    Only ratings ``>= rating_threshold`` from ``data_reader.dataset`` are kept.
    The keys of the returned dictionary are the ORIGINAL user IDs. The values
    are sets of the (new, consecutive) item IDs converted to ``str``.
    """
    # st.cache_data builds its key from the function's arguments and source
    # code, not from module globals. Passing the current dataset explicitly
    # keeps one session's (or one dataset's) history from being served to
    # another session with the same threshold.
    return _user_history_for_dataset(
        data_reader.dataset, rating_threshold, data_reader
    )


# Keep the `.clear()` helper the cached version of this function exposed.
get_user_history.clear = _user_history_for_dataset.clear
@st.cache_data
def summarize_explanation_rules(
    explained_rules_info: list, all_rules_df: pd.DataFrame
) -> pd.DataFrame:
    """
    Summarise the rules found for one item, with one row per distinct antecedent.

    - ``num_rules`` is how often the antecedent occurs in ``explained_rules_info``.
    - ``avg_confidence`` and ``avg_lift`` are the means over every row of
      ``all_rules_df`` whose ``antecedents`` value equals that antecedent.

    Antecedents absent from ``all_rules_df`` are dropped (inner merge). Rows
    are ordered by ``num_rules``, descending. An empty input produces an
    empty, column-less DataFrame.
    """
    if not explained_rules_info:
        return pd.DataFrame()
    # Frequency of each antecedent among the rules the explainer reported.
    frequency = Counter(entry["antecedent"] for entry in explained_rules_info)
    counted = pd.DataFrame(
        list(frequency.items()), columns=["antecedent", "num_rules"]
    )
    # Full rule rows sharing one of those antecedents, averaged per antecedent.
    matching_rules = all_rules_df.loc[
        all_rules_df["antecedents"].isin(counted["antecedent"])
    ]
    averages = (
        matching_rules.groupby("antecedents")
        .agg(avg_confidence=("confidence", "mean"), avg_lift=("lift", "mean"))
        .reset_index()
    )
    combined = counted.merge(averages, left_on="antecedent", right_on="antecedents")
    combined = combined.sort_values(by="num_rules", ascending=False)
    combined = combined.rename(columns={"antecedent": "because_they_interacted_with"})
    output_columns = [
        "because_they_interacted_with",
        "num_rules",
        "avg_confidence",
        "avg_lift",
    ]
    return combined[output_columns]
def generate_pills(id_list, max_display=10, pill_type="default"):
    """Generate styled HTML pills for a list of IDs.

    Args:
        id_list: IDs to display. Any iterable is accepted; each ID is rendered
            via ``str`` and HTML-escaped.
        max_display: Maximum number of pills rendered. Any remaining IDs are
            summarised by one grey "+N more" pill. Negative values count as 0.
        pill_type: ``"users"`` (blue) or ``"items"`` (orange). Any other value
            uses the default indigo palette.

    Returns:
        An HTML string (a ``<style>`` block followed by the pill container)
        meant for ``st.markdown(..., unsafe_allow_html=True)``.
    """
    import html  # stdlib; IDs are escaped before being embedded in raw HTML

    # Different styles for different pill types with better contrast
    if pill_type == "users":
        variant = "users"
        base_color = "#4285f4"  # Google Blue
        hover_color = "#5a95f5"
        shadow_color = "rgba(66, 133, 244, 0.3)"
    elif pill_type == "items":
        variant = "items"
        base_color = "#ff6b35"  # Vibrant Orange
        hover_color = "#ff8659"
        shadow_color = "rgba(255, 107, 53, 0.3)"
    else:
        variant = "default"
        base_color = "#6366f1"  # Indigo
        hover_color = "#8b5cf6"
        shadow_color = "rgba(99, 102, 241, 0.3)"
    pill_style = f"""
display: inline-block;
padding: 0.35rem 0.75rem;
font-size: 0.875rem;
font-weight: 600;
color: #FFFFFF;
background: linear-gradient(135deg, {base_color} 0%, {hover_color} 100%);
border-radius: 1.5rem;
margin: 0.25rem 0.15rem;
border: none;
box-shadow: 0 2px 8px {shadow_color};
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
cursor: pointer;
user-select: none;
"""
    # The <style> block applies to the whole page. Its rules are therefore
    # scoped to this pill type (`pill-{variant}`); unscoped, the last call on
    # the page would restyle every pill. `!important` is required because the
    # base look is an inline style, which otherwise beats any stylesheet rule.
    scope = f".pill-container.pill-{variant}"
    hover_style = f"""
<style>
.pill-container {{
line-height: 1.6;
margin: 0.5rem 0;
}}
{scope} span:not(.pill-more):hover {{
transform: translateY(-2px) scale(1.05);
box-shadow: 0 6px 16px {shadow_color} !important;
background: linear-gradient(135deg, {hover_color} 0%, {base_color} 100%) !important;
}}
{scope} span:active {{
transform: translateY(0) scale(1.02);
transition: all 0.1s ease !important;
}}
</style>
"""
    ids = list(id_list)
    limit = max(0, max_display)
    # Show only the specified number, add "+N more" if there are more
    pills_html = "".join(
        f'<span style="{pill_style}" title="ID: {html.escape(str(item_id))}">'
        f"{html.escape(str(item_id))}</span>"
        for item_id in ids[:limit]
    )
    if len(ids) > limit:
        more_count = len(ids) - limit
        more_style = (
            pill_style.replace(base_color, "#64748b")
            .replace(hover_color, "#94a3b8")
            .replace(shadow_color, "rgba(100, 116, 139, 0.2)")
        )
        pills_html += (
            f'<span class="pill-more" style="{more_style}" '
            f'title="View all {len(ids)} items">+{more_count} more</span>'
        )
    return hover_style + f'<div class="pill-container pill-{variant}">{pills_html}</div>'
# Display name for the trained model: simply its class name.
model_name = type(model).__name__
# Page-wide CSS: extra bottom margin for Streamlit element containers that
# sit inside an element with the "compact-list" class.
st.markdown(
    """
<style>
/* This style only applies to hr tags inside an element with the class "compact-list" */
.compact-list .stElementContainer {
margin-bottom: 1rem !important;
}
</style>
""",
    unsafe_allow_html=True,
)
# Context and Summary Section
with st.expander("🎯 **Recommendation Context & Results**", expanded=True):
    # Responsive rules for the metric row (mobile: stacked, desktop: grid).
    # NOTE(review): the `.metric-container` div is opened and closed by two
    # separate st.markdown calls, so it probably does not actually wrap the
    # st.columns rendered in between; the container rules may have no effect.
    st.markdown(
        """
<style>
/* Mobile tweaks */
@media (max-width: 768px) {
.metric-container {
display: flex !important;
flex-direction: column !important;
align-items: flex-start;
gap: 1.5rem !important;
margin: 0 !important;
padding: 0 !important;
}
.custom-metric {
margin: 0 !important;
}
/* Remove gap before the first metric */
.metric-container > .custom-metric:first-child {
margin-top: 0 !important;
}
/* Remove gap after the last metric */
.metric-container > .custom-metric:last-child {
margin-bottom: 0 !important;
}
.custom-metric-label {
font-size: 0.9rem !important;
margin-bottom: 0.3rem !important;
}
.custom-metric-value,
.model-badge,
.strategy-badge {
font-size: 1.4rem !important;
line-height: 1.3 !important;
}
}
/* Desktop (force grid) */
@media (min-width: 769px) {
.metric-container {
display: grid !important;
grid-template-columns: repeat(4, 1fr) !important;
gap: 0.5rem ;
margin: 0 !important;
padding: 0 !important;
}
}
</style>
""",
        unsafe_allow_html=True,
    )
    # Base styles for metric labels, values and badges.
    st.markdown(
        """
<style>
.metric-container {
padding: 1rem 0;
margin: 0.5rem 0;
}
.section-header {
font-size: 1rem;
font-weight: 600;
color: var(--text-color, #374151);
margin-bottom: 0.75rem;
display: flex;
align-items: center;
gap: 0.5rem;
}
.custom-metric {
display: flex;
flex-direction: column;
gap: 0.25rem;
}
.custom-metric-label {
font-size: 0.875rem;
font-weight: 600;
color: var(--text-color-secondary, #6b7280);
margin: 0;
padding: 0;
line-height: 1.25;
}
.custom-metric-value {
font-size: 2rem;
font-weight: 700;
margin: 0;
padding: 0;
line-height: 1;
}
.model-badge {
display: inline-flex;
align-items: center;
gap: 0.5rem;
font-weight: 600;
font-size: 2rem;
line-height: 1;
margin: 0;
padding: 0;
}
.strategy-badge {
display: inline-flex;
align-items: center;
gap: 0.5rem;
font-weight: 600;
font-size: 2rem;
line-height: 1;
margin: 0;
padding: 0;
}
.metric-description {
font-size: 0.75rem;
color: var(--text-color-secondary, #9ca3af);
margin-top: 0.25rem;
line-height: 1.25;
}
</style>
""",
        unsafe_allow_html=True,
    )
    st.markdown('<div class="metric-container">', unsafe_allow_html=True)
    col1, col2, col3, col4 = st.columns([1.2, 1.5, 1, 1])
    # Each metric label uses the `custom-metric-label` class styled above.
    with col1:
        st.markdown(
            f"""
<div class="custom-metric">
<div class="custom-metric-label">🧠 Model</div>
<div class="model-badge"> {model_name}</div>
</div>
""",
            unsafe_allow_html=True,
        )
    with col2:
        # e.g. an enum member named "SOME_NAME" is shown as "Some Name".
        strategy_display = aggregation_strategy.name.replace("_", " ").title()
        st.markdown(
            f"""
<div class="custom-metric">
<div class="custom-metric-label">📊 Aggregation Strategy</div>
<div class="strategy-badge">{strategy_display}</div>
</div>
""",
            unsafe_allow_html=True,
        )
    with col3:
        st.markdown(
            f"""
<div class="custom-metric">
<div class="custom-metric-label">👥 Group Size</div>
<div class="custom-metric-value">{len(group_members)}</div>
</div>
""",
            unsafe_allow_html=True,
        )
    with col4:
        st.markdown(
            f"""
<div class="custom-metric">
<div class="custom-metric-label">📝 Items Recommended</div>
<div class="custom-metric-value">{len(recommended_items)}</div>
</div>
""",
            unsafe_allow_html=True,
        )
    st.markdown("</div>", unsafe_allow_html=True)
    st.markdown("")
    # Group members and recommendations rendered as pill lists.
    list_col1, list_col2 = st.columns(2)
    with list_col1:
        st.markdown(
            '<div class="custom-metric">👥 Group Members</div>', unsafe_allow_html=True
        )
        pills_html = generate_pills(group_members, max_display=18, pill_type="users")
        st.markdown(pills_html, unsafe_allow_html=True)
    with list_col2:
        st.markdown(
            '<div class="custom-metric">🎯 Top Recommendations</div>',
            unsafe_allow_html=True,
        )
        pills_html = generate_pills(
            recommended_items, max_display=18, pill_type="items"
        )
        st.markdown(pills_html, unsafe_allow_html=True)
# Labels of the available explanation methods. The branches below select
# their content by comparing against these entries by position, so the
# order must stay stable.
explainer_options = [
    "Counterfactual Explanation (Sliding Window)",
    "Rule-Based Explanation (EXPGRS)",
    "Local Model-Agnostic (LORE4GroupRS)",
]
chosen_explainer = st.selectbox(
    label="**Select an explanation method:**",
    options=explainer_options,
    help="Choose the type of explanation you want to generate for the recommendation.",
)
st.markdown("")
# Content for "Counterfactual Explanation (Sliding Window)"
if chosen_explainer == explainer_options[0]:
st.header("Counterfactual Explanation (Sliding Window)")
st.markdown(
    """
This method answers the question: *"Which minimal set of items, if removed from the group's history, would cause the recommended item to disappear from the list?"* It first ranks the group's previously seen items by a composite score and then iteratively removes them to find an explanation.
"""
)
st.markdown("##### Configuration")
col1, col2 = st.columns(2)
with col1:
    target_item = st.selectbox(
        "Choose a recommended item to explain:",
        options=recommended_items,
        key="sw_target_item",
    )
with col2:
    window_size = st.slider("Sliding Window Size", 1, 10, 3, key="sw_window_size")
# Component Weights for Item Ranking
with st.expander("Configure Item Ranking Weights"):
    st.write(
        "These weights determine how to rank the items from the group's history before attempting to remove them."
    )
    weights = {}
    c1, c2, c3, c4, c5 = st.columns(5)
    with c1:
        weights["popularity"] = st.slider(
            "Popularity", 0.0, 2.0, 1.0, 0.1, key="w_pop"
        )
    with c2:
        weights["intensity"] = st.slider(
            "Intensity", 0.0, 2.0, 1.0, 0.1, key="w_int"
        )
    with c3:
        weights["rating"] = st.slider("Rating", 0.0, 2.0, 1.0, 0.1, key="w_rat")
    with c4:
        weights["relevance"] = st.slider(
            "Relevance", 0.0, 2.0, 1.0, 0.1, key="w_rel"
        )
    with c5:
        weights["trend"] = st.slider("Trend", 0.0, 2.0, 1.0, 0.1, key="w_tre")
if st.button("Generate Counterfactual Explanation", key="sw_button"):
    with st.spinner(
        "Ranking items and running Sliding Window explainer... This may take a while."
    ):
        try:
            # 1. Get all items rated by the group
            items_rated_by_group = (
                group_handler.get_rated_items_by_all_group_members(
                    group=group_members, original_data=data_reader
                )
            )
            # 2. Build the explainer for the selected target item
            explainer = SlidingWindowExplainer(
                config=None,
                data=data_reader,
                group_handler=group_handler,
                members=group_members,
                target_item=target_item,
                aggregation_strategy=aggregation_strategy,
                model=model,
                window_size=window_size,
            )
            # 3. Find explanations. Anything the explainer prints is captured
            #    in these buffers (and currently not displayed).
            stdout_buffer = StringIO()
            stderr_buffer = StringIO()
            with redirect_stdout(stdout_buffer), redirect_stderr(stderr_buffer):
                explanations = explainer.find_explanation(
                    items_rated_by_group=items_rated_by_group,
                    group_predictions=group_recommender.get_individual_predictions(),
                    top_recommendation=group_recommender.get_top_recommendation(),
                    ranking_weights=weights,
                )
            st.session_state.sw_explanations = explanations
            # Remember which item these results explain: the selectbox can be
            # changed on later reruns without regenerating the explanation.
            st.session_state.sw_explained_item = target_item
            st.success("✅ Counterfactual explanation process finished.")
        except Exception as e:
            st.error(f"An error occurred: {e}")
# Display Sliding Window Results
if "sw_explanations" in st.session_state:
    st.markdown("")
    st.header("Explanation Results")
    explanations = st.session_state.sw_explanations
    # Fall back to the current selection for results stored without the item.
    explained_item = st.session_state.get("sw_explained_item", target_item)
    st.markdown("")
    if not explanations:
        st.warning(
            "No counterfactual explanation could be found with the current configuration."
        )
    else:
        st.write(
            "Found the following minimal sets of items that act as explanations:"
        )
        metric_helpers = {
            "Popularity": "Measures the item's overall appeal across ALL users in the dataset. A high score indicates the item is globally popular.",
            "Intensity": "The proportion (percentage) of group members who have interacted with this item. A high score means it is a widely shared experience within the group.",
            "Rating": "The group's average rating for this item, normalized to a 0-1 scale. A high score indicates the group collectively enjoyed this item.",
            "Relevance": "Measures how influential this item was for the recommendation of the target item, based on the model's prediction scores for the group members.",
            "Trend": "The proportion of the group's interactions with this item that occurred during its detected 'hype periods' (i.e., peaks in its popularity over time).",
            "Composite Score": "The final weighted sum of all other metrics. This score is used to rank items to find the most likely counterfactual explanation.",
        }
        # Consistent display order of the metrics (same order as above).
        display_order = list(metric_helpers)
        for exp_data in explanations.values():
            exp_items = exp_data["items"]  # type: ignore
            new_rec = exp_data["new_rec"]  # type: ignore
            metrics = exp_data.get("metrics", {})
            st.info(
                f"If the group had NOT interacted with **{exp_items}**, the recommended item **{explained_item}** would have been removed from the list. In its place, item **{new_rec}** would have been recommended instead."
            )
            if metrics:
                with st.container(border=True):
                    st.markdown("##### 🔎 Contributing Item Metrics")
                    # Loop through each item in the explanation
                    for item_id, item_scores in metrics.items():
                        st.subheader(f"Item: `{item_id}`")
                        cols = st.columns(3)
                        available = [m for m in display_order if m in item_scores]
                        # Fill the three columns round-robin.
                        for position, metric_name in enumerate(available):
                            with cols[position % 3]:
                                st.metric(
                                    label=metric_name,
                                    value=f"{item_scores[metric_name]:.3f}",
                                    help=metric_helpers.get(metric_name),
                                )
            st.markdown("")  # Add a separator for the next explanation
# Content for "Rule-Based Explanation (EXPGRS)"
elif chosen_explainer == explainer_options[1]:
st.header("Rule-Based Explanation (EXPGRS)")
st.markdown(
    """
This method calculates the **Model Fidelity**: the percentage of the Top-N list that can be explained by pre-computed association rules from cached files.
"""
)
st.markdown("##### Configuration")
col1, col2, col3, col4 = st.columns(4)
with col1:
    min_support = st.slider("Minimum Support", 0.08, 0.30, 0.10, key="rb_support")
with col2:
    min_confidence = st.slider(
        "Minimum Confidence", 0.1, 1.0, 0.1, key="rb_confidence"
    )
with col3:
    rating_threshold = st.slider(
        "Rating Threshold", 0.5, 5.0, 1.0, step=0.5, key="rb_rating_threshold"
    )
with col4:
    # The default (2) must not exceed the group size. A group of at most one
    # member leaves nothing to choose, so the slider is skipped explicitly.
    max_members = max(1, len(group_members) if group_members is not None else 0)
    if max_members > 1:
        min_members = st.slider(
            "Minimum Satisfied Members",
            1,
            max_members,
            min(2, max_members),
            key="rb_members",
        )
    else:
        min_members = 1
# Show the expected filename pattern (same pattern as load_cached_data_rules)
expected_filename = f"rules_sup{min_support:.2f}_conf{min_confidence:.1f}_rating{rating_threshold:.0f}"
if st.button("Generate Rule-Based Explanation", key="rb_button"):
    with st.spinner("Loading cached association rules and finding explanations..."):
        try:
            cached_data_rules = load_cached_data_rules(
                min_support, min_confidence, rating_threshold
            )
            # The cache file is expected to hold a dict with a "rules" entry.
            cached_rules = (
                cached_data_rules.get("rules")
                if isinstance(cached_data_rules, dict)
                else None
            )
            if cached_rules is None:
                # Forget results computed for other settings so they are not
                # shown as if they belonged to the current configuration.
                st.session_state.pop("expgrs_results", None)
                st.session_state.pop("expgrs_rules", None)
                st.error(
                    f"⚠️ **Cached rules not found!**\n\n"
                    f"Could not find a cached rules file with the pattern: `{expected_filename}` "
                    f"in the `cached_rules/` folder.\n\n"
                    f"Please ensure the file exists with the correct naming pattern and try again."
                )
            else:
                st.success(
                    f"✅ Successfully loaded cached rules from: `{expected_filename}`"
                )
                # Get user history
                user_history = get_user_history(rating_threshold)
                # Create explainer with cached rules
                explainer = RuleBasedGroupRecExplainer(
                    rules=cached_rules,
                    data=data_reader,
                    pool_recommendations=recommended_items,
                    members=group_members,
                    user_history=user_history,
                    min_members_threshold=min_members,
                )
                fidelity_score = explainer.find_explanation()
                advanced_fidelity_score = explainer.compute_group_fidelity_advanced()
                explanation_details = explainer.get_explanation_details()
                st.session_state.expgrs_results = {
                    "fidelity": fidelity_score,
                    "advanced_fidelity": advanced_fidelity_score,
                    "details": explanation_details,
                }
                # Keep the rules for rendering on later reruns; a local set
                # under st.button() only exists during the run of the click.
                st.session_state.expgrs_rules = cached_rules
                st.success(
                    "✅ Rule-based explanations generated successfully using cached rules!"
                )
        except Exception as e:
            st.error(f"An error occurred while processing cached rules: {e}")
# Display Rule-Based Results
if "expgrs_results" in st.session_state:
    expgrs_results = st.session_state.expgrs_results
    st.markdown("")
    st.header("Explanation Results")
    evaluator = ExplanationEvaluator()
    metrics = evaluator.evaluate(expgrs_results, explainer_type="EXPGRS")
    # Display Metrics
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric(
            "Explanation Fidelity",
            f"{metrics['fidelity']:.2%}",
            help="The percentage of recommended items that could be explained by at least one rule satisfying the minimum member threshold.",
        )
    with col2:
        st.metric(
            "Advanced Explanation Fidelity",
            f"{expgrs_results['advanced_fidelity']:.2%}",
            help="A stricter fidelity where every group member must have seen at least one item from the rule's antecedent.",
        )
    with col3:
        st.metric(
            "Explanation Diversity (GILD)",
            f"{metrics['gild']:.4f}",
            help="Gaussian Intra-List Diversity. Measures how varied the explanation are. Higher is better.",
        )
    with st.expander(
        "🔍 View Detailed Explanations for the Group", expanded=True
    ):
        # The button handler above stores the details inside expgrs_results.
        # NOTE(review): assumes get_explanation_details() returns
        # {item_id: [rule dicts with an "antecedent" key]}; confirm.
        details = expgrs_results.get("details") or {}
        rules_df = st.session_state.get("expgrs_rules")
        explained_items_dict = {
            item: exp for item, exp in details.items() if exp
        }
        if not explained_items_dict:
            st.write(
                "No items in the list could be explained with the current settings."
            )
        else:
            st.write(
                "The following items were explained by summarizing the found rules:"
            )
            for item_id, rules_list in explained_items_dict.items():
                with st.container(border=True):
                    st.subheader(f"Item {item_id} was recommended because...")
                    # The summary needs the rules as a DataFrame
                    if isinstance(rules_df, pd.DataFrame):
                        summary_df = summarize_explanation_rules(
                            rules_list, rules_df
                        )
                        if not summary_df.empty:
                            summary_df["because_they_interacted_with"] = (
                                summary_df[
                                    "because_they_interacted_with"
                                ].apply(lambda fs: sorted(list(fs)))
                            )
                            # Display the clean summary table
                            st.dataframe(
                                summary_df,
                                use_container_width=True,
                                hide_index=True,
                                column_config={
                                    "because_they_interacted_with": st.column_config.ListColumn(
                                        "Interaction Pattern (Antecedent)",
                                        help="The set of items the group previously interacted with.",
                                    ),
                                    "num_rules": st.column_config.NumberColumn(
                                        "Pattern Frequency",
                                        help="How many rules were based on this specific pattern.",
                                    ),
                                    "avg_confidence": st.column_config.ProgressColumn(
                                        "Avg. Confidence",
                                        help="The average confidence that this pattern leads to the recommendation.",
                                        format="%.2f",
                                        min_value=0,
                                        max_value=1,
                                    ),
                                    "avg_lift": st.column_config.NumberColumn(
                                        "Avg. Lift",
                                        help="How much more likely this pattern is than random chance. (Lift > 1 is good).",
                                        format="%.2f",
                                    ),
                                },
                            )
                    else:
                        st.warning(
                            "Cached rules are not available as a DataFrame, so explanation summary cannot be displayed."
                        )
# Content for "Local Model-Agnostic (LORE4GroupRS)"
elif chosen_explainer == explainer_options[2]:
# Helper functions
def diagnose_and_create_aligned_profiles():
    """Create tag-based item profiles aligned with the ratings dataset's item IDs.

    Steps:
    1. Read the tag table at ``cfg.data.tags.tags_file``; it must have
       ``movieId`` and ``label`` columns.
    2. Keep the tags of items present in ``data_reader.dataset`` and re-key
       them from original to consecutive item IDs.
    3. Keep only the ``cfg.explainer.lore4groups.top_n_labels`` most frequent
       labels.

    Tag-coverage metrics are rendered on the page as a side effect.

    Returns:
        ``(item_profiles, item_label_matrix, tags_final)``:
        - ``item_profiles`` maps ``str(consecutive item ID)`` to a set of labels.
        - ``item_label_matrix`` is a 0/1 item x label pivot indexed by the same
          string IDs.
        - ``tags_final`` is the filtered tag table.
        Returns ``(None, None, None)`` when no tag belongs to a rated item.
    """
    with st.spinner("🔍 Diagnosing data alignment and creating item profiles..."):
        # Step 1: Load the raw tags (keyed by original item IDs in `movieId`)
        tags_df = pd.read_csv(cfg.data.tags.tags_file)
        # Step 2: Map original item IDs to the consecutive IDs of the ratings
        # dataset; items the reader cannot translate are skipped.
        consecutive_items = set(data_reader.dataset["itemId"].unique())
        original_to_consecutive = {}
        for consecutive_id in consecutive_items:
            try:
                original_id = data_reader.get_original_item_id(consecutive_id)
            except (ValueError, KeyError):
                continue
            original_to_consecutive[original_id] = consecutive_id
        # Step 3: Filter and process tags
        tags_df_filtered = tags_df[
            tags_df["movieId"].isin(list(original_to_consecutive))
        ]
        if tags_df_filtered.empty:
            st.error("❌ No tag data matches items in ratings dataset!")
            return None, None, None
        # Keep the full label as a single tag, just normalize case (a
        # column-wise form of str(label).lower().strip() for each row).
        tags_processed = pd.DataFrame(
            {
                "original_movieId": tags_df_filtered["movieId"].to_numpy(),
                "label": tags_df_filtered["label"]
                .map(lambda label: str(label).lower().strip())
                .to_numpy(),
            }
        )
        # Convert to consecutive IDs
        tags_processed["movieId"] = tags_processed["original_movieId"].map(
            original_to_consecutive
        )
        tags_processed = tags_processed.dropna(subset=["movieId"])
        tags_processed["movieId"] = tags_processed["movieId"].astype(int)
        # Step 4: Create item profiles from the most frequent labels
        top_n_labels = cfg.explainer.lore4groups.top_n_labels
        top_labels = (
            tags_processed["label"]
            .value_counts()
            .nlargest(top_n_labels)
            .index.tolist()
        )
        tags_final = tags_processed[tags_processed["label"].isin(top_labels)]
        # Item profiles use the consecutive IDs (as strings) as keys
        item_profiles = tags_final.groupby("movieId")["label"].apply(set).to_dict()
        item_profiles = {str(k): v for k, v in item_profiles.items()}
        # Create item-label matrix
        item_label_matrix = tags_final.assign(value=1).pivot_table(
            index="movieId", columns="label", values="value", fill_value=0
        )
        item_label_matrix.index = item_label_matrix.index.astype(str)
        # Step 5: Validation - share of rated items that received a profile
        total_rating_items = len(consecutive_items)
        tagged_items = len(item_profiles)
        coverage = tagged_items / total_rating_items if total_rating_items else 0.0
        # Display results
        col1, col2, col3 = st.columns(3)
        with col1:
            st.metric("Total Items", f"{total_rating_items:,}")
        with col2:
            st.metric("Tagged Items", f"{tagged_items:,}")
        with col3:
            st.metric("Coverage", f"{coverage:.1%}")
        if coverage < 0.1:
            st.warning(
                "⚠️ Low tag coverage detected. Consider increasing `top_n_labels` parameter."
            )
        else:
            st.success("✅ Good tag coverage for explanations!")
        return item_profiles, item_label_matrix, tags_final
def validate_user_coverage(item_profiles):
    """Report, for every group member, the share of rated items that have tags.

    Renders a per-member coverage table plus two summary metrics on the page.
    A member whose data cannot be read gets a warning and is left out.

    Args:
        item_profiles: Mapping keyed by ``str(consecutive item ID)``.

    Returns:
        ``False`` if no member could be evaluated or the mean coverage is below
        10 %, otherwise ``True``.
    """
    with st.spinner("🔍 Validating user tag coverage..."):
        rows = []
        for member_id in group_members:
            try:
                # Ratings are stored under the consecutive user ID
                internal_id = data_reader.get_new_user_id(member_id)
                rated = data_reader.dataset[
                    data_reader.dataset["userId"] == internal_id
                ]
                n_rated = len(rated)
                n_tagged = sum(
                    str(item) in item_profiles for item in rated["itemId"]
                )
                rows.append(
                    {
                        "member_id": member_id,
                        "total_items": n_rated,
                        "tagged_items": n_tagged,
                        "coverage": n_tagged / n_rated if n_rated > 0 else 0,
                    }
                )
            except Exception as e:
                st.warning(f"Could not validate coverage for user {member_id}: {e}")
        if not rows:
            st.error("❌ Could not validate any users' tag coverage.")
            return False
        coverage_df = pd.DataFrame(rows)
        mean_coverage = coverage_df["coverage"].mean()
        # Human-readable copy of the table for display
        table = coverage_df.assign(
            coverage=coverage_df["coverage"].map("{:.1%}".format)
        ).rename(
            columns={
                "member_id": "Member ID",
                "total_items": "Total Ratings",
                "tagged_items": "Tagged Items",
                "coverage": "Coverage",
            }
        )
        st.dataframe(table, use_container_width=True, hide_index=True)
        left, right = st.columns(2)
        with left:
            st.metric("Average Coverage", f"{mean_coverage:.1%}")
        with right:
            well_covered = sum(row["coverage"] > 0.1 for row in rows)
            st.metric(
                "Users with >10% coverage",
                f"{well_covered}/{len(rows)}",
            )
        if mean_coverage < 0.1:
            st.warning(
                "⚠️ Low average tag coverage. Explanations may be limited."
            )
            return False
        st.success("✅ Sufficient tag coverage for explanations!")
        return True
def prepare_user_history(item_profiles):
    """Collect each group member's interacted items in the format LORE4Groups expects.

    Returns:
        ``{original member ID: set of str(consecutive item ID)}``. A member
        whose ID cannot be resolved gets an empty set, and a warning is shown.

    Note:
        ``item_profiles`` is accepted for interface compatibility but is not
        used by this function.
    """
    histories = {}
    for member in group_members:
        try:
            internal_id = data_reader.get_new_user_id(member)
            frame = data_reader.dataset
            member_items = frame.loc[frame["userId"] == internal_id, "itemId"]
            histories[member] = set(member_items.astype(str))
        except Exception as e:
            st.warning(f"Could not prepare history for user {member}: {e}")
            histories[member] = set()
    return histories
def parse_movie_genres(
    movies_dat_path: Optional[str] = None,
    movies_data: Optional[str] = None,
) -> Dict[str, Dict[str, Union[str, Set[str]]]]:
    """
    Parse movie titles and genres in the MovieLens ``movies.dat`` format.

    Each line is ``MovieID::Title::Genres``, with genres separated by ``|``,
    e.g. ``1::Toy Story (1995)::Animation|Children's|Comedy``. Lines that do
    not split into exactly three ``::`` fields are skipped. Empty genre
    tokens are dropped, so an empty genre field gives an empty set.

    Args:
        movies_dat_path: Optional path to a ``movies.dat`` file (read as
            latin-1). Used when it points to an existing file.
        movies_data: Optional string with the same content. Used when no
            existing file path is given.

    Returns:
        ``{movie_id: {"title": str, "genres": set[str]}}``. If a movie ID
        appears twice, the later line wins. An empty dict is returned when no
        usable source is provided.
    """

    def process_lines(
        lines: Iterable[str],
    ) -> Dict[str, Dict[str, Union[str, Set[str]]]]:
        # Shared parsing for both sources; iterates lazily over `lines`.
        parsed: Dict[str, Dict[str, Union[str, Set[str]]]] = {}
        for line in lines:
            parts = line.strip().split("::")
            if len(parts) != 3:
                continue  # malformed line
            movie_id, title, genre_field = parts
            parsed[movie_id] = {
                "title": title,
                "genres": {genre for genre in genre_field.split("|") if genre},
            }
        return parsed

    # Handle file path source (the file is streamed line by line)
    if movies_dat_path and os.path.exists(movies_dat_path):
        with open(movies_dat_path, "r", encoding="latin-1") as f:
            return process_lines(f)
    # Handle direct string data source
    if movies_data:
        return process_lines(movies_data.strip().split("\n"))
    # Return an empty dict if no valid source is provided
    return {}
def store_genre_profiles(movies_content):
    """Save parsed movie genre data under ``movie_profiles`` in the session state.

    Returns:
        ``movies_content`` itself, so the call can be used inline.
    """
    st.session_state["movie_profiles"] = movies_content
    return movies_content
def compute_tree_layout(tree):
    """Compute plot coordinates ``{node_id: (x, y)}`` for a binary tree.

    Reads only ``tree.children_left`` and ``tree.children_right``, where
    ``-1`` marks a node without children (as in scikit-learn's ``tree_``).

    Layout rules:
    - Leaves get consecutive integer x positions, from left to right.
    - Every internal node is centred between its two children.
    - ``y`` is ``-depth``, so the root is at 0.

    Nodes at the same depth own disjoint leaf ranges, so they can never
    overlap, however deep the tree is.
    """
    positions = {}
    next_leaf_x = 0

    def place(node_id, depth):
        nonlocal next_leaf_x
        left = tree.children_left[node_id]
        right = tree.children_right[node_id]
        if left == -1:
            x = float(next_leaf_x)
            next_leaf_x += 1
        else:
            x = (place(left, depth + 1) + place(right, depth + 1)) / 2
        positions[node_id] = (x, -depth)
        return x

    place(0, 0)
    return positions


def create_decision_tree_visualization(
    clf, feature_names, item_id, item_title="Item"
):
    """Render a fitted decision tree as a Plotly node-link diagram.

    Args:
        clf: Fitted scikit-learn style tree classifier. ``clf.tree_`` is read,
            and ``clf.classes_`` when present. If ``clf`` is falsy or has no
            ``tree_``, a warning is shown instead.
        feature_names: Names indexed by the values of ``clf.tree_.feature``.
        item_id: Used to build a unique Streamlit element key.
        item_title: Shown in the subheader.

    Leaf nodes display LIKE when the predicted class index is 1, as assumed
    throughout this page.
    """
    if not clf or getattr(clf, "tree_", None) is None:
        st.warning("No decision tree available for visualization")
        return
    st.subheader(f"Decision Tree for {item_title}")
    MIN_NODE_SIZE = 30
    MAX_NODE_SIZE = 40
    # Get tree structure
    tree = clf.tree_
    total_samples = tree.n_node_samples[0]
    classes = getattr(clf, "classes_", None)
    positions = compute_tree_layout(tree)
    fig = go.Figure()
    # Edges: every segment goes into one line trace (None breaks the line
    # between segments). The left child (feature <= threshold) is "Yes",
    # the right child is "No".
    edge_x, edge_y = [], []
    for node_id, (x, y) in positions.items():
        left_child = tree.children_left[node_id]
        if left_child == -1:
            continue
        for child, label in (
            (left_child, "Yes"),
            (tree.children_right[node_id], "No"),
        ):
            child_x, child_y = positions[child]
            edge_x += [x, child_x, None]
            edge_y += [y, child_y, None]
            fig.add_annotation(
                x=(x + child_x) / 2,
                y=(y + child_y) / 2,
                text=label,
                showarrow=False,
                font=dict(size=10, color="gray"),
                bgcolor="white",
                bordercolor="gray",
                borderwidth=1,
            )
    if edge_x:
        fig.add_trace(
            go.Scatter(
                x=edge_x,
                y=edge_y,
                mode="lines",
                line=dict(color="gray", width=2),
                showlegend=False,
                hoverinfo="skip",
            )
        )
    # Nodes: one marker trace, with size scaled by the node's sample share.
    node_x, node_y, node_text, node_color, node_size = [], [], [], [], []
    for node_id, (x, y) in positions.items():
        proportion = tree.n_node_samples[node_id] / total_samples
        node_size.append(MIN_NODE_SIZE + proportion * (MAX_NODE_SIZE - MIN_NODE_SIZE))
        if tree.feature[node_id] != -2:  # split node (scikit-learn uses -2 for leaves)
            feature_name = feature_names[tree.feature[node_id]]
            threshold = tree.threshold[node_id]
            node_text.append(f"{feature_name}<br>≤ {threshold:.2f}")
            node_color.append("#e3f2fd")  # Light blue for internal nodes
        else:
            values = tree.value[node_id][0]
            predicted_class = int(np.argmax(values))
            confidence = values[predicted_class] / np.sum(values)
            if classes is not None and len(classes) == 1:
                # Single-class tree: index 0 is the only class, so decide by
                # its label instead of by its index.
                is_like = bool(classes[0] == 1)
            else:
                is_like = predicted_class == 1
            node_text.append(
                f"Predict: {'LIKE' if is_like else 'DISLIKE'}<br>Conf: {confidence:.2f}"
            )
            # Green for like, red for dislike
            node_color.append("#c8e6c9" if is_like else "#ffcdd2")
        node_x.append(x)
        node_y.append(y)
    fig.add_trace(
        go.Scatter(
            x=node_x,
            y=node_y,
            mode="markers+text",
            marker=dict(
                size=node_size,
                color=node_color,
                line=dict(color="black", width=2),
            ),
            text=node_text,
            textposition="middle center",
            showlegend=False,
            hoverinfo="text",
            hovertext=node_text,
        )
    )
    fig.update_layout(
        showlegend=False,
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        height=500,
        plot_bgcolor="white",  # TODO: ADAPT TO DARK MODE
    )
    st.plotly_chart(fig, width="stretch", key=f"decision_tree_{item_id}")
def evaluate_rule_for_item(
    rule: str, item_profiles: Dict, item_id: str, movie_genres: Optional[set] = None
) -> bool:
    """
    Decide whether ``rule`` holds for the item identified by ``item_id``.

    Args:
        rule: Technical rule text, e.g. ``"classic > 0.50"``,
            ``"Has genre: `Comedy`"`` or ``"Does NOT have tag: `gory`"``.
        item_profiles: Mapping from stringified item id to that item's
            features/tags; only membership tests (``in``) are performed on it.
        item_id: Identifier of the evaluated item (stringified for lookup).
        movie_genres: Genres of the item; required for genre rules.

    Returns:
        True if the rule's condition is satisfied, False otherwise — including
        when the rule text is in a format this function does not recognise.
    """
    profile = item_profiles.get(str(item_id), set())

    # Genre rules are judged case-insensitively against the movie's genres.
    # Without any genre information the rule is reported as unsatisfied.
    genre_hit = re.search(r"genre: `([^`]+)`", rule)
    if genre_hit is not None:
        if not movie_genres:
            return False
        wanted = genre_hit.group(1).lower()
        present = wanted in {name.lower() for name in movie_genres}
        return (not present) if "Does NOT have" in rule else present

    # Threshold rules: features are treated as binary, so ">"/">=" means
    # "feature present" and "<"/"<=" means "feature absent".
    threshold_hit = re.match(r"(.+?)\s*([<>]=?)\s*(\d+\.?\d*)", rule)
    if threshold_hit is not None:
        name, op, _threshold = threshold_hit.groups()
        present = name.strip() in profile
        return present if op.startswith(">") else not present

    # Explicit tag rules carry the tag name between backticks.
    if "`" in rule:
        if "Has tag:" in rule:
            return rule.split("`")[1] in profile
        if "Does NOT have tag:" in rule:
            return rule.split("`")[1] not in profile

    # Unparseable rule: treat as not satisfied.
    return False
def create_decision_path(
    factual_rules, item_id, item_profiles, item_title="Item", movie_genres=None
):
    """
    Render the rule-by-rule decision path that explains why an item is recommended.

    Each rule is formatted for humans, evaluated against the item with
    ``evaluate_rule_for_item`` and drawn as a green (satisfied) or red
    (unsatisfied) box. The boxes sit between an "Analyzing Item" start node
    and a final "RECOMMEND TO GROUP" node.

    Args:
        factual_rules: Sequence of technical rule strings on the decision path.
        item_id: Identifier of the explained item.
        item_profiles: Mapping of item id -> features/tags, passed through to
            ``evaluate_rule_for_item``.
        item_title: Display title of the item.
        movie_genres: Optional collection of the item's genres.
    """
    # Rule text can embed dataset tag/feature names and is rendered with
    # unsafe_allow_html, so escape it before interpolating into markup.
    from html import escape

    if not factual_rules:
        st.warning("No factual rules found for this item")
        return

    st.subheader(f"Why {item_title} is recommended:")

    if movie_genres:
        st.markdown(f"**Movie Genres:** {', '.join(movie_genres)}")
        st.markdown("")

    # HTML snippets are kept on one line each so that Markdown treats every
    # one as a single raw-HTML block instead of an indented code block.
    arrow_html = (
        "<div style='text-align: center; font-size: 30px; margin: 5px 0;'>↓</div>"
    )

    with st.container():
        # 1. Start node
        st.markdown(
            "<div style='background-color: #e3f2fd; border-radius: 10px; "
            "padding: 15px; text-align: center; border: 2px solid #2196f3; "
            "max-width: 400px; margin: auto;'><strong>Analyzing Item</strong></div>",
            unsafe_allow_html=True,
        )
        # 2. Arrow down to the rules
        st.markdown(arrow_html, unsafe_allow_html=True)

        # 3. One box per rule, stacked vertically
        st.markdown(
            "<h5 style='text-align: center;'>Evaluating Rules</h5>",
            unsafe_allow_html=True,
        )
        for rule_number, rule in enumerate(factual_rules, 1):
            formatted_rule = format_rule_with_context(rule, movie_genres or set())
            rule_satisfied = evaluate_rule_for_item(
                rule, item_profiles, item_id, movie_genres
            )
            if rule_satisfied:
                bg_color, border_color, result = "#c8e6c9", "#4caf50", "YES ✓"
            else:
                bg_color, border_color, result = "#ffcdd2", "#f44336", "NO ✗"

            readable = escape(str(formatted_rule["readable"]))
            context = escape(str(formatted_rule["context"]))
            st.markdown(
                f"<div style='background-color: {bg_color}; border-radius: 10px; "
                f"padding: 15px; margin: 10px auto; border: 2px solid {border_color}; "
                f"max-width: 600px;'>"
                f"<strong>Rule {rule_number}:</strong> {readable}<br>"
                f"<small style='color: #666;'>{context}</small><br>"
                f"<span style='color: {border_color}; font-weight: bold;'>"
                f"Result: {result}</span></div>",
                unsafe_allow_html=True,
            )

        # 4. Arrow down to the final result
        st.markdown(arrow_html, unsafe_allow_html=True)

        # 5. Final recommendation node
        st.markdown(
            "<div style='background-color: #4caf50; color: white; "
            "border-radius: 10px; padding: 20px 15px; text-align: center; "
            "font-weight: bold; font-size: 18px; max-width: 400px; margin: auto;'>"
            "🎬 RECOMMEND TO GROUP</div>",
            unsafe_allow_html=True,
        )
        # Full-width divider. ``width`` sets the horizontal size in pixels,
        # so the previous ``width=1`` produced an invisible 1px divider.
        st.divider()
def create_counterfactuals(counterfactuals, item_title="Item", movie_genres=None):
    """
    Render, for each group member, the scenarios under which the item would
    NOT be recommended.

    Args:
        counterfactuals: Mapping of member id -> list of scenarios. Each
            scenario is a list of technical rule strings that jointly lead to
            rejection.
        item_title: Display title of the item.
        movie_genres: Optional collection of the item's genres. It adds genre
            context to rule text and enables genre-based insights.
    """
    if not counterfactuals:
        st.info("No counterfactual scenarios available")
        return

    st.subheader(f"What would make {item_title} NOT recommended:")

    if movie_genres:
        st.markdown(f"**Current Genres:** {', '.join(movie_genres)}")
        st.markdown(
            "*Consider how changing these characteristics would affect the recommendation:*"
        )
        st.markdown("")

    genres_for_format = movie_genres or set()
    for member_id, scenarios in counterfactuals.items():
        with st.expander(f"Member **{member_id}** would reject if:", expanded=False):
            # Every scenario supplied for the member is shown; no cap is applied.
            for scenario_idx, scenario in enumerate(scenarios, 1):
                st.markdown(f"**Counterfactual {scenario_idx}:**")

                conditions = []
                for rule in scenario:
                    readable = format_rule_with_context(rule, genres_for_format)[
                        "readable"
                    ]
                    # Lower-case only the first character so the clause reads
                    # naturally after "If ". This keeps emphasis such as "NOT"
                    # and the original casing of genre/tag names.
                    conditions.append(readable[:1].lower() + readable[1:])

                if not conditions:
                    # Avoid rendering a dangling "If " for an empty scenario.
                    st.markdown("- *(no conditions recorded for this scenario)*")
                    st.markdown("")
                    continue

                st.markdown("- If " + " **AND** ".join(conditions))
                st.markdown(" → **Result: DO NOT RECOMMEND** ✗")

                if movie_genres:
                    genre_insight = get_genre_counterfactual_insight(
                        scenario, movie_genres
                    )
                    if genre_insight:
                        st.markdown(f" 💡 *Genre Insight: {genre_insight}*")
                st.markdown("")
# Genre-aware context phrases used by ``format_rule_with_context``.
#
# Structure: {movie genre (lower-case): {keyword (lower-case): phrase}}.
# When the explained movie has the outer genre and the lower-cased rule text
# contains the keyword (plain substring test), the phrase is appended to the
# rule's context. Keywords are matched as substrings, so e.g. "space" also
# hits "spaceship".
GENRE_INSIGHT_PATTERNS = {
    "adventure": {
        "action": "consistent with adventure themes",
        "thriller": "consistent with adventure themes",
        "quest": "central to adventure narratives",
    },
    "comedy": {
        "family": "aligns with comedy preferences",
        "children": "aligns with comedy preferences",
        "humor": "a key element of comedy",
    },
    "drama": {
        "romance": "fits dramatic storytelling",
        "emotional": "fits dramatic storytelling",
        "relationships": "focuses on character drama",
    },
    "romance": {
        "love story": "the central focus of the genre",
        "relationships": "explores romantic dynamics",
    },
    "action": {
        "stunts": "a hallmark of the action genre",
        "explosions": "common in action films",
        "superhero": "a popular subgenre of action",
    },
    "sci-fi": {
        "space": "aligns with sci-fi settings",
        "aliens": "a common sci-fi theme",
        "futuristic": "characteristic of sci-fi worlds",
        "dystopia": "a popular sci-fi subgenre",
    },
    "horror": {
        "scary": "defines the horror experience",
        "suspense": "builds tension in horror films",
        "monster": "a classic horror element",
        "supernatural": "common in horror narratives",
    },
    "thriller": {
        "suspense": "essential for building tension",
        "mystery": "central to the thriller plot",
        "crime": "a common element in thrillers",
    },
    "fantasy": {
        "magic": "a core component of fantasy",
        "mythology": "often inspires fantasy worlds",
        "epic": "relates to the scale of fantasy stories",
    },
    "mystery": {
        "detective": "central to the investigation",
        "whodunit": "defines the mystery plot",
        "crime": "often the basis of a mystery",
    },
}
def format_rule_with_context(rule: str, movie_genres: Optional[set] = None) -> dict:
    """
    Format ``rule`` for humans and enrich its context with genre-related notes.

    The base formatting comes from ``format_rule_with_context_simple``. If the
    movie's genres are known, the lower-cased rule text is checked for each genre
    name, and for the keywords listed under that genre in
    ``GENRE_INSIGHT_PATTERNS``. Any phrases found are sorted and appended to the
    "context" entry in parentheses, separated by ". ".

    Args:
        rule: The technical rule string (e.g. "classic > 0.50").
        movie_genres: Optional collection of the movie's genres.

    Returns:
        The dict produced by ``format_rule_with_context_simple``, possibly with
        an extended "context" value.
    """
    result = format_rule_with_context_simple(rule)
    if not movie_genres or not result:
        return result

    lowered_rule = rule.lower()
    notes = set()
    for genre_name in {g.lower() for g in movie_genres}:
        # Direct mention of one of the movie's own genres.
        if genre_name in lowered_rule:
            notes.add(f"relates to the movie's '{genre_name.title()}' genre")
        # Genre-specific keyword phrases.
        notes.update(
            phrase
            for keyword, phrase in GENRE_INSIGHT_PATTERNS.get(genre_name, {}).items()
            if keyword in lowered_rule
        )

    if notes:
        # Sorted so the output order is stable across runs.
        result["context"] += " (" + ". ".join(sorted(notes)) + ")"
    return result
def format_rule_with_context_simple(rule: str) -> dict:
    """
    Translate a technical rule string into a human-readable description.

    Recognised formats (a trailing "(k/n members)" support suffix is ignored):
      * genre rules:     "Has genre: `Comedy`" / "Does NOT have genre: `Comedy`"
      * threshold rules: "classic > 0.50", "classic <= 0.50", ... where ">"/">="
        means the feature is present and "<"/"<=" means it is absent
      * tag rules:       "Has tag: `classic`" / "Does NOT have tag: `classic`"
        (the same formats ``evaluate_rule_for_item`` understands)

    Args:
        rule: The technical rule string.

    Returns:
        A dict with "readable" (sentence fragment), "context" (one-line
        explanation) and "type" ("positive", "negative" or "unknown").
        Unrecognised rules are returned verbatim with type "unknown".
    """
    # Strip the member-support count, if any, before analysing the rule.
    rule_only = re.sub(r"\s*\(\d+/\d+ members\)", "", rule).strip()

    # Genre rules
    match = re.search(r"genre: `([^`]+)`", rule_only)
    if match:
        genre = match.group(1)
        if "Does NOT have" in rule_only:
            return {
                "readable": f"the movie is NOT in the '{genre}' genre",
                "context": f"This item is not classified as '{genre}'.",
                "type": "negative",
            }
        return {
            "readable": f"the movie is in the '{genre}' genre",
            "context": f"This item is classified as '{genre}'.",
            "type": "positive",
        }

    # Feature rules: either a threshold comparison or an explicit tag rule.
    feature = None
    is_present = False
    match = re.match(r"(.+?)\s*([<>]=?)\s*(\d+\.?\d*)", rule_only)
    if match:
        feature_name, operator, _ = match.groups()
        feature = feature_name.strip()
        is_present = operator in (">", ">=")
    elif "`" in rule_only and (
        "Has tag:" in rule_only or "Does NOT have tag:" in rule_only
    ):
        feature = rule_only.split("`")[1]
        is_present = "Has tag:" in rule_only

    if feature is not None:
        if is_present:
            return {
                "readable": f"it has the '{feature}' characteristic",
                "context": f"The item is associated with the '{feature}' tag.",
                "type": "positive",
            }
        return {
            "readable": f"it does NOT have the '{feature}' characteristic",
            "context": f"The item is not associated with the '{feature}' tag.",
            "type": "negative",
        }

    # Fallback for any rule format not handled above
    return {
        "readable": rule,
        "context": "A technical condition was met.",
        "type": "unknown",
    }
# Insight sentence per (lower-case) genre, used by get_genre_counterfactual_insight.
_GENRE_COUNTERFACTUAL_INSIGHTS = {
    "action": "More action-oriented content might be preferred.",
    "adventure": "More exciting, journey-based stories are preferred.",
    "animation": "Animated stories could capture a wider demographic.",
    "children's": "Content suitable for a younger audience is a key area for growth.",
    "comedy": "Lighter, humorous content would be better.",
    "crime": "Gritty, crime-focused narratives are a strong draw.",
    "documentary": "Factual, informative content would attract viewers interested in real-world stories.",
    "drama": "More serious, character-driven stories are favored.",
    "fantasy": "Imaginative, fantastical worlds and magic systems would be well-received.",
    "film-noir": "Stylized, cynical film-noir narratives would appeal to a niche audience.",
    "horror": "Content that creates a sense of dread and fear would be effective.",
    "musical": "Incorporating musical numbers could broaden the audience.",
    "mystery": "Intriguing mysteries that keep the audience guessing are popular.",
    "romance": "Romantic elements would improve appeal.",
    "sci-fi": "Science fiction elements would be welcome.",
    "thriller": "Suspenseful content would be more engaging.",
    "war": "Stories set against the backdrop of historical conflicts would be compelling.",
    "western": "Classic western themes of frontier justice would resonate well.",
}

# A rule is "negative" if it says NOT (as a whole word) or is a "<"/"<="
# threshold. This matches how format_rule_with_context_simple classifies rules.
_NEGATED_RULE_RE = re.compile(r"\bnot\b|<=?\s*-?\d")


def get_genre_counterfactual_insight(scenario_rules, movie_genres):
    """
    Return a genre-based insight for a counterfactual scenario, if any.

    A scenario yields an insight when one of its rules negates one of the
    movie's genres. Such a rule mentions the genre as a whole term and is
    negative ("NOT ..." or a "<"/"<=" threshold). The first match found while
    iterating ``movie_genres`` wins.

    Args:
        scenario_rules: Iterable of technical rule strings forming the scenario.
        movie_genres: Collection of the movie's genres (any case).

    Returns:
        The insight sentence for the matching genre, or None.
    """
    if not movie_genres:
        return None

    for genre in movie_genres:
        genre_lower = genre.lower()
        insight = _GENRE_COUNTERFACTUAL_INSIGHTS.get(genre_lower)
        if insight is None:
            continue
        # Whole-term match so that e.g. "war" does not fire on "award".
        genre_re = re.compile(
            rf"(?<![a-z0-9]){re.escape(genre_lower)}(?![a-z0-9])"
        )
        for rule in scenario_rules:
            rule_lower = rule.lower()
            if genre_re.search(rule_lower) and _NEGATED_RULE_RE.search(rule_lower):
                return insight
    return None
def create_genre_group_analysis(
    factual_rules_categorized, counterfactuals, movie_genres=None
):
    """
    Show group-level agreement (left column) next to per-member concerns
    (right column).

    Args:
        factual_rules_categorized: Group factual rules bucketed by consensus
            level, i.e. a dict with optional "unanimous", "majority" and
            "minority" lists. A plain list of rules or None is also accepted;
            a plain list is shown under a neutral "Group Factors" heading.
        counterfactuals: Mapping of member id -> list of counterfactual scenarios.
        movie_genres: Optional collection of the item's genres, used for
            per-member genre insights.
    """
    st.subheader("Group Consensus Analysis")

    if movie_genres:
        st.markdown(f"**Analyzing:** Movie with genres: {', '.join(movie_genres)}")
        st.markdown("")

    # Normalise input so that a missing (None) or uncategorized rule list
    # does not crash the .get()/.values() calls below.
    if isinstance(factual_rules_categorized, dict):
        categorized = factual_rules_categorized
        uncategorized = []
    else:
        categorized = {}
        uncategorized = list(factual_rules_categorized or [])

    col1, col2 = st.columns(2)

    with col1:
        st.markdown("#### **Why the GROUP agrees:**")

        if categorized.get("unanimous"):
            st.markdown("##### Unanimous Agreement")
            for rule in categorized["unanimous"]:
                st.success(f"{rule}")

        if categorized.get("majority"):
            st.markdown("##### Majority Consensus")
            for rule in categorized["majority"]:
                st.info(f"{rule}")

        if categorized.get("minority"):
            st.markdown("##### Minority Factors")
            for rule in categorized["minority"]:
                st.markdown(f"{rule}")

        if uncategorized:
            st.markdown("##### Group Factors")
            for rule in uncategorized:
                st.markdown(f"{rule}")

        if not uncategorized and not any(categorized.values()):
            st.markdown("No clear group consensus found.")

    with col2:
        st.markdown("#### Individual member concerns:")
        if counterfactuals:
            member_count = len(counterfactuals)
            subject = "member has" if member_count == 1 else "members have"
            st.markdown(f"*{member_count} {subject} specific concerns (φ):*")
            for member_id, scenarios in counterfactuals.items():
                st.markdown(f"- Member {member_id}: {len(scenarios)} Counterfactual(s)")
                if movie_genres:
                    member_genre_insight = get_member_genre_preference(
                        scenarios, movie_genres
                    )
                    if member_genre_insight:
                        st.badge(f"💡 *{member_genre_insight}*")
        else:
            st.markdown("No individual concerns identified")
def get_member_genre_preference(member_scenarios, movie_genres):
    """
    Summarise whether a member's counterfactuals involve the movie's genres.

    Args:
        member_scenarios: Iterable of scenarios, each an iterable of rule strings.
        movie_genres: Collection of the movie's genres.

    Returns:
        "Prefers different genre characteristics" if any rule in any scenario
        contains one of the genres (case-insensitive substring), else None.
    """
    if not movie_genres:
        return None

    mentions_genre = any(
        genre.lower() in rule.lower()
        for scenario in member_scenarios
        for rule in scenario
        for genre in movie_genres
    )
    return "Prefers different genre characteristics" if mentions_genre else None
def _lore_fidelity_badge_html(fidelity_score):
    """
    Build the HTML for the fidelity metric: a percentage plus a quality tag.

    Scores above 0.7 are tagged "High", above 0.3 "Moderate", otherwise "Low".
    The markup contains no blank lines, so Markdown keeps it as one HTML block.
    """
    if fidelity_score > 0.7:
        font_color = "rgb(52, 142, 79)"
        background_color = "rgba(224, 240, 229, 0.7)"
        quality_text = "High"
    elif fidelity_score > 0.3:
        font_color = "rgb(209, 126, 32)"
        background_color = "rgba(251, 236, 219, 0.7)"
        quality_text = "Moderate"
    else:
        font_color = "rgb(217, 72, 63)"
        background_color = "rgba(252, 228, 226, 0.7)"
        quality_text = "Low"

    tag_html = (
        f'<div style="background-color: {background_color}; color: {font_color}; '
        "padding: 0.25rem 0.5rem; border-radius: 0.5rem; font-weight: 600; "
        f'font-size: 1.5rem; display: inline-block;">{quality_text}</div>'
    )
    percentage_value = f"{fidelity_score * 100:.1f}%"
    # Every <div> opened here is also closed here (the outer column container
    # was previously left open).
    return (
        '<div style="display: flex; flex-direction: column">'
        '<div style="font-size: 0.875rem; padding-top: 0.25rem">Fidelity score</div>'
        '<div style="display: flex; align-items: baseline; gap: 8px; padding-bottom: 1rem;">'
        f'<div style="font-size: 2.25rem; font-weight: 400;">{percentage_value} </div>'
        f"{tag_html}"
        "</div>"
        "</div>"
    )


def _render_lore_item_explanation(item_id, exp_data, item_type, item_profiles):
    """Render the expander with the four explanation tabs for a single item."""
    decision_path_rules = exp_data.get("decision_path") or []
    # ``group_factual_rule`` is consumed as a categorized dict
    # (unanimous/majority/minority), so default to a dict, not a list.
    factual_rules = exp_data.get("group_factual_rule") or {}
    counterfactuals = exp_data.get("individual_counterfactuals") or {}

    item_title, movie_genres = get_item_metadata(item_id, item_type=item_type)

    with st.expander(
        f"**{item_id}** - {item_title} - Complete Explanation",
        expanded=True,
    ):
        threshold_info = exp_data.get("threshold_info")
        if threshold_info and threshold_info.get("was_overridden"):
            original_thresh = threshold_info.get("original_threshold", "N/A")
            final_thresh = threshold_info.get("final_threshold", 0)
            st.info(
                f"**Note:** Your selected 'Like' threshold of `{original_thresh}` "
                f"did not create enough class diversity for this item's local neighborhood. "
                f"To generate a valid explanation, the threshold was automatically adjusted "
                f"to the local mean rating of **`{final_thresh:.2f}`**.",
                # An empty string is not a valid Streamlit icon; use the info emoji.
                icon="ℹ️",
            )

        tab1, tab2, tab3, tab4 = st.tabs(
            [
                "🌳 Decision Tree",
                "📋 Decision Path",
                "🔄 Alternatives",
                "👥 Group Analysis",
            ]
        )

        with tab1:
            clf = exp_data.get("decision_tree")
            if clf:
                feature_names = exp_data.get("feature_names", [])
                create_decision_tree_visualization(
                    clf, feature_names, item_id, item_title
                )
            else:
                st.info(
                    "Decision tree visualization not available. This requires storing the trained classifier."
                )
                st.markdown(
                    "**To enable tree visualization:**\n"
                    "1. Modify the LORE4Groups explainer to store the trained DecisionTreeClassifier\n"
                    "2. Include feature names in the explanation results\n"
                    "3. The tree will show the actual decision nodes and paths"
                )

        with tab2:
            create_decision_path(
                decision_path_rules,
                item_id,
                item_profiles,
                item_title,
                movie_genres,
            )

        with tab3:
            create_counterfactuals(counterfactuals, item_title, movie_genres)

        with tab4:
            create_genre_group_analysis(factual_rules, counterfactuals, movie_genres)


def display_lore4groups(domain_name="movie", item_type="Movie"):
    """
    Render the LORE4Groups results stored in ``st.session_state.lore_explanation``.

    Shows summary metrics (items explained, fidelity, GILD diversity) and
    then one expandable explanation per item, with decision-tree,
    decision-path, counterfactual and group-analysis tabs.

    Args:
        domain_name: Currently unused; kept for backward compatibility.
        item_type: Label for explained items (e.g. "Movie"). It is used in
            headings and in fallback titles.
    """
    if "lore_explanation" not in st.session_state:
        st.warning(
            "No LORE4Groups explanations available. Please generate explanations first."
        )
        return

    results = st.session_state.lore_explanation
    metrics = ExplanationEvaluator().evaluate(results, explainer_type="LORE4Groups")
    details = results.get("details") or {}

    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric(f"{item_type}s Explained", len(details))
    with col2:
        # ``or 0.0`` also guards against an explicit None from the evaluator.
        fidelity_score = metrics.get("fidelity") or 0.0
        st.markdown(_lore_fidelity_badge_html(fidelity_score), unsafe_allow_html=True)
    with col3:
        gild_score = metrics.get("gild") or 0.0
        st.metric(
            "Explanation Diversity (GILD)",
            f"{gild_score:.4f}",
            help="Gaussian Intra-List Diversity. Measures how varied the generated factual rules are. Higher is better.",
        )

    if not details:
        st.info(
            "No explanations were generated. This may indicate insufficient data overlap between group members' preferences."
        )
        return

    # Guard against profiles not being prepared in this session.
    item_profiles = st.session_state.get("lore_item_profiles") or {}
    for item_id, exp_data in details.items():
        _render_lore_item_explanation(item_id, exp_data, item_type, item_profiles)
def get_item_metadata(item_id, data_reader=None, item_type="Item"):
    """
    Look up a display title and genre collection for ``item_id``.

    Lookup order:
      1. ``st.session_state.movie_profiles[str(item_id)]``. If it has a
         "title", return it together with its "genres" (or an empty set).
      2. ``data_reader.get_original_item_id`` if a reader is given, producing
         "<item_type> <original id>".
      3. "<item_type> <item_id>".

    Any exception is reported with ``st.warning``, and the generic title with
    an empty genre set is returned.

    Returns:
        Tuple of (title, genres).
    """
    fallback_title = f"{item_type} {item_id}"
    try:
        if "movie_profiles" in st.session_state:
            info = st.session_state.movie_profiles.get(str(item_id), {})
            if "title" in info:
                genres = info["genres"] if "genres" in info else set()
                return info["title"], genres

        if data_reader and hasattr(data_reader, "get_original_item_id"):
            original_id = data_reader.get_original_item_id(int(item_id))
            return f"{item_type} {original_id}", set()

        return fallback_title, set()
    except Exception as e:
        st.warning(f"Metadata retrieval failed for item {item_id}: {e}")
        return fallback_title, set()
def is_genre_related_rule(rule):
    """
    Return True if ``rule`` mentions any known MovieLens genre keyword.

    Matching is a case-insensitive substring test. A keyword therefore also
    matches inside longer words (e.g. "war" in "award").
    """
    lowered = rule.lower()
    for keyword in (
        "comedy",
        "fantasy",
        "thriller",
        "romance",
        "drama",
        "musical",
        "crime",
        "war",
        "mystery",
        "sci-fi",
        "western",
        "horror",
        "children's",
        "animation",
        "film-noir",
        "adventure",
        "action",
        "documentary",
    ):
        if keyword in lowered:
            return True
    return False
def analyze_genre_usage_in_explanations(details):
    """
    Summarise how often genre keywords appear in the group factual rules.

    Args:
        details: Mapping of item id -> explanation dict. Each explanation's
            "group_factual_rule" may be a flat list of rule strings, or a dict
            of category -> list of rule strings (e.g. "unanimous", "majority",
            "minority"). Both shapes are supported; previously a dict was
            iterated by its category names rather than by its rules.

    Returns:
        dict with:
          * "total_genres": number of distinct genre keywords found,
          * "genre_based_rules": integer (truncated) percentage of rules that
            mention a genre keyword,
          * "top_genres": up to three most frequent genre keywords.
    """
    # Keyword order matters: each rule is attributed to the first keyword it contains.
    genre_keywords = (
        "comedy",
        "fantasy",
        "thriller",
        "romance",
        "drama",
        "musical",
        "crime",
        "war",
        "mystery",
        "sci-fi",
        "western",
        "horror",
        "children's",
        "animation",
        "film-noir",
        "adventure",
        "action",
        "documentary",
    )
    genre_frequency = Counter()
    genre_based_rules = 0
    total_rules = 0

    for exp_data in details.values():
        raw_rules = exp_data.get("group_factual_rule") or []
        if isinstance(raw_rules, dict):
            # Categorized output: flatten every bucket into one list of rules.
            rules = [rule for bucket in raw_rules.values() for rule in (bucket or [])]
        else:
            rules = list(raw_rules)

        total_rules += len(rules)
        for rule in rules:
            rule_lower = rule.lower()
            matched = next((g for g in genre_keywords if g in rule_lower), None)
            if matched is not None:
                genre_based_rules += 1
                genre_frequency[matched] += 1

    share = genre_based_rules / total_rules * 100 if total_rules > 0 else 0
    return {
        "total_genres": len(genre_frequency),
        "genre_based_rules": int(share),
        "top_genres": [genre for genre, _ in genre_frequency.most_common(3)],
    }
def analyze_rule_features(factual_rules, movie_genres):
    """
    Count rules by the kind of feature they use.

    Categories are checked in this order (first match wins):
    "genre_related" (mentions one of the movie's genres), "tag_related"
    (contains "tag"), "threshold_based" (contains ">" or "<"), otherwise "other".

    Args:
        factual_rules: A flat list of rule strings, or a categorized dict of
            category -> list of rule strings (as stored under
            "group_factual_rule"). None is treated as no rules.
        movie_genres: Collection of the movie's genres; None means no genres.

    Returns:
        Dict of category -> count, or {} when there are no rules.
    """
    if isinstance(factual_rules, dict):
        # Categorized output: count the rules, not the category names.
        rules = [rule for bucket in factual_rules.values() for rule in (bucket or [])]
    else:
        rules = list(factual_rules or [])
    if not rules:
        return {}

    genres_lower = [genre.lower() for genre in (movie_genres or ())]
    feature_types = {
        "genre_related": 0,
        "tag_related": 0,
        "threshold_based": 0,
        "other": 0,
    }

    for rule in rules:
        rule_lower = rule.lower()
        if any(genre in rule_lower for genre in genres_lower):
            feature_types["genre_related"] += 1
        elif "tag" in rule_lower:
            feature_types["tag_related"] += 1
        elif ">" in rule or "<" in rule:
            feature_types["threshold_based"] += 1
        else:
            feature_types["other"] += 1

    return feature_types
# Main LORE4Groups interface
#
# Top-level page script for the LORE4Groups section. In order, it renders:
# the introduction, the parameter sliders, a "prepare item profiles" step,
# a "generate explanations" step (only once profiles exist), and finally the
# stored explanations via display_lore4groups(). Statement order matters:
# Streamlit lays widgets out in call order, and the button keys identify them.
st.header("Local Rule-Based Explanations (LORE4Groups)")
st.markdown("""
**LORE4Groups** generates explanations by building local decision trees around each recommended item:
1. **Local Neighborhood**: For each recommended item, find similar items that group members have rated
2. **Local Model**: Build a decision tree using item features to predict ratings in this neighborhood
3. **Rule Extraction**: Extract interpretable rules explaining why the item was recommended
4. **Group Consensus**: Aggregate individual explanations into group-level rules
The method provides both **factual** rules (why recommended) and **counterfactual** rules (what would change the recommendation).
""")
# Configuration section
# Slider values are only copied into cfg.explainer.lore4groups when the
# "Generate LORE4Groups Explanations" button is pressed (see below).
st.markdown("##### Configuration")
col1, col2 = st.columns(2)
with col1:
    rating_threshold = st.slider(
        "Rating Threshold for 'Like'",
        min_value=2.0,
        max_value=5.0,
        value=3.0,
        step=0.1,
        help="Ratings above this threshold are considered 'liked'",
    )
    similarity_threshold = st.slider(
        "Similarity Threshold",
        min_value=0.0,
        max_value=1.0,
        value=0.1,
        step=0.05,
        help="Minimum Jaccard similarity to consider items as neighbors",
    )
with col2:
    n_similar_for_tree = st.slider(
        "Similar Items for Tree",
        min_value=20,
        max_value=200,
        value=50,
        step=10,
        help="Number of similar items to use for building local decision tree",
    )
    min_neighbors = st.slider(
        "Minimum Neighbors",
        min_value=3,
        max_value=20,
        value=5,
        step=1,
        help="Minimum number of similar items needed to build explanation",
    )
# Initialize session state for LORE profiles.
# This only happens when the key is absent, so prepared profiles persist
# across later runs of the script within the same session.
if "lore_item_profiles" not in st.session_state:
    st.session_state.lore_item_profiles = None
    st.session_state.lore_item_matrix = None
    st.session_state.lore_tags_final = None
# Button to prepare data
if st.button("🔧 Prepare Item Profiles", key="prepare_profiles_btn"):
    # diagnose_and_create_aligned_profiles() is defined elsewhere. Judging by
    # its use here, it returns an indexable triple (item profiles, item
    # matrix, final tags), and a None first element signals failure.
    profiles_result = diagnose_and_create_aligned_profiles()
    if profiles_result[0] is not None:
        st.session_state.lore_item_profiles = profiles_result[0]
        st.session_state.lore_item_matrix = profiles_result[1]
        st.session_state.lore_tags_final = profiles_result[2]
        st.success("✅ Item profiles prepared successfully!")
    else:
        st.error("❌ Failed to prepare item profiles. Check data compatibility.")
# Main explanation generation (only offered once profiles exist)
if st.session_state.lore_item_profiles is not None:
    # Validate user coverage
    with st.expander("🔍 Validate Group Coverage", expanded=False):
        coverage_ok = validate_user_coverage(st.session_state.lore_item_profiles)
        st.session_state.lore_coverage_validated = coverage_ok
    # Generate explanations button
    if st.button("Generate LORE4Groups Explanations", key="generate_lore_btn"):
        # Check if we have recommendations
        if not recommended_items:
            st.error(
                "❌ No recommendations available. Please generate recommendations first."
            )
            st.stop()
        if st.session_state.lore_item_matrix is None:
            st.error(
                "❌ Item profiles are not fully prepared. Please click 'Prepare Item Profiles' again."
            )
            st.stop()
        with st.spinner(
            "Building local decision trees and generating explanations..."
        ):
            try:
                # Push the slider values into the shared config that the
                # explainer reads.
                cfg.explainer.lore4groups.rating_threshold_for_like = (
                    rating_threshold
                )
                cfg.explainer.lore4groups.similarity_threshold = (
                    similarity_threshold
                )
                cfg.explainer.lore4groups.n_similar_for_tree = n_similar_for_tree
                cfg.explainer.lore4groups.min_neighbors = min_neighbors
                # Filter recommendations to those with tag profiles.
                # NOTE(review): this list is only used as a pre-flight check;
                # find_explanation below receives the full
                # st.session_state.recommended_items — confirm this is intended.
                explainable_recs = [
                    str(item)
                    for item in recommended_items
                    if str(item) in st.session_state.lore_item_profiles
                ]
                if not explainable_recs:
                    st.error(
                        "❌ No recommended items have sufficient tag data for explanation!"
                    )
                    st.stop()
                # Prepare user history
                user_hist = prepare_user_history(
                    st.session_state.lore_item_profiles
                )
                # NOTE(review): the movies file path is hard-coded to MovieLens-1M
                # regardless of which dataset data_reader loaded — confirm.
                movies_content = parse_movie_genres(
                    movies_dat_path="datasets/ml-1m/movies.dat"
                )
                store_genre_profiles(movies_content)
                # Normalise each movie's genres to a set: keep existing sets,
                # split "|"-delimited strings, and use an empty set otherwise.
                genre_profiles = {
                    movie_id: data["genres"]
                    if isinstance(data["genres"], set)
                    else set(data["genres"].split("|"))
                    if isinstance(data["genres"], str)
                    else set()
                    for movie_id, data in movies_content.items()
                }
                # NOTE(review): repeats the check made before the spinner; it
                # only matters if a helper above resets lore_item_matrix.
                if st.session_state.lore_item_matrix is None:
                    st.error(
                        "❌ Item label matrix is not prepared. Please click 'Prepare Item Profiles' again."
                    )
                    st.stop()
                explainer = LORE4GroupsExplainer(
                    st.session_state.lore_item_profiles,
                    st.session_state.lore_item_matrix,
                    cfg,
                    genre_profiles=genre_profiles,
                )
                # Run the explainer over the current group's recommendations
                results = explainer.find_explanation(
                    st.session_state.recommended_items,
                    group_members,
                    user_hist,
                    data_reader.dataset,
                    model=model,
                    data_reader=data_reader,
                )
                # Store results so display_lore4groups() can render them on
                # this and later runs.
                st.session_state.lore_explanation = results
                st.success("🎉 LORE4Groups explanations generated successfully!")
            except Exception as e:
                st.error(f"❌ Error generating explanations: {str(e)}")
                st.exception(e)
else:
    st.info("👆 Please prepare item profiles first before generating explanations.")
# Always render whatever explanations are currently stored in session state.
display_lore4groups(domain_name="movie", item_type="Movie")