617 lines
20 KiB
Plaintext
617 lines
20 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "7fa3d250",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Imports"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "6b55c6e8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Local application/library specific imports\n",
|
|
"from pygrex.config import cfg\n",
|
|
"from pygrex.data_reader import DataReader, GroupInteractionHandler\n",
|
|
"from pygrex.evaluator import run_evaluation_with_proper_split\n",
|
|
"from pygrex.explain.groups.lore4groups_explainer import LORE4GroupsExplainer\n",
|
|
"from pygrex.models import ALS\n",
|
|
"from pygrex.recommender import GroupRecommender\n",
|
|
"from pygrex.utils import AggregationStrategy\n",
|
|
"from pygrex.evaluator import ExplanationEvaluator\n",
|
|
"\n",
|
|
"import time\n",
|
|
"import pandas as pd\n",
|
|
"import os\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "adbf9967",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"✅ Data preparation complete.\n",
|
|
"\n",
|
|
"--- Data Summary ---\n",
|
|
"👥 Unique Users: 610\n",
|
|
"📦 Unique Items: 9,724\n",
|
|
"⭐ Total Ratings: 100,836\n",
|
|
"👨👩👧👦 Number of Groups: 17\n",
|
|
"\n",
|
|
"Processed Ratings DataFrame Head:\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>userId</th>\n",
|
|
" <th>itemId</th>\n",
|
|
" <th>rating</th>\n",
|
|
" <th>timestamp</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>0</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>4.0</td>\n",
|
|
" <td>964982703</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>0</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>4.0</td>\n",
|
|
" <td>964981247</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>0</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>4.0</td>\n",
|
|
" <td>964982224</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>0</td>\n",
|
|
" <td>43</td>\n",
|
|
" <td>5.0</td>\n",
|
|
" <td>964983815</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>0</td>\n",
|
|
" <td>46</td>\n",
|
|
" <td>5.0</td>\n",
|
|
" <td>964982931</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" userId itemId rating timestamp\n",
|
|
"0 0 0 4.0 964982703\n",
|
|
"1 0 2 4.0 964981247\n",
|
|
"2 0 5 4.0 964982224\n",
|
|
"3 0 43 5.0 964983815\n",
|
|
"4 0 46 5.0 964982931"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Read the ratings file.\n",
|
|
"data = DataReader(**cfg.data.test)\n",
|
|
"data.make_consecutive_ids_in_dataset()\n",
|
|
"# data.binarize(binary_threshold=1)\n",
|
|
"\n",
|
|
"# Read the file with the group ids\n",
|
|
"group_handler = GroupInteractionHandler(**cfg.data.groups)\n",
|
|
"available_groups = group_handler.read_groups(\"groupsWithHighRatings5.txt\")\n",
|
|
"print(\"✅ Data preparation complete.\\n\")\n",
|
|
"\n",
|
|
"# --- Display Data Summary ---\n",
|
|
"print(\"--- Data Summary ---\")\n",
|
|
"print(f\"👥 Unique Users: {data.num_user:,}\")\n",
|
|
"print(f\"📦 Unique Items: {data.num_item:,}\")\n",
|
|
"print(f\"⭐ Total Ratings: {len(data.get_raw_dataset()):,}\")\n",
|
|
"print(f\"👨👩👧👦 Number of Groups: {len(available_groups):,}\")\n",
|
|
"print(\"\\nProcessed Ratings DataFrame Head:\")\n",
|
|
"display(data.dataset.head())"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "5fc94aef",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Step 2: Model Training & Evaluation\n",
|
|
"\n",
|
|
"With the data prepared, we now select and train a recommendation model. We will use **Alternating Least Squares (ALS)**, a matrix factorization technique for implicit feedback. After training, we will evaluate its performance using a train/test split to measure its Hit Ratio and NDCG."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "8c13c283",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"--- 2.1 Model Training ---\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"c:\\Users\\usuar\\miniconda3\\envs\\pygrex-exp-grs\\Lib\\site-packages\\implicit\\cpu\\als.py:95: RuntimeWarning: OpenBLAS is configured to use 8 threads. It is highly recommended to disable its internal threadpool by setting the environment variable 'OPENBLAS_NUM_THREADS=1' or by calling 'threadpoolctl.threadpool_limits(1, \"blas\")'. Having OpenBLAS use a threadpool can lead to severe performance issues here.\n",
|
|
" check_blas_config()\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "40a7ec46978a413e80b045c0f60fbce6",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/10 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"✅ Model trained successfully in 0.95 seconds!\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(\"--- 2.1 Model Training ---\")\n",
|
|
"\n",
|
|
"# Train the recommendation model\n",
|
|
"model = ALS(**cfg.model.als)\n",
|
|
"\n",
|
|
"# Train the model\n",
|
|
"start_time = time.time()\n",
|
|
"model.fit(data)\n",
|
|
"end_time = time.time()\n",
|
|
"training_time = end_time - start_time\n",
|
|
"\n",
|
|
"print(f\"✅ Model trained successfully in {training_time:.2f} seconds!\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "047fe521",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"print(\"\\n--- 2.2 Offline Model Evaluation ---\")\n",
|
|
"# For evaluation, a new model instance must be created.\n",
|
|
"# The evaluation function handles its own internal data splitting and training.\n",
|
|
"eval_model = ALS(**cfg.model.als)\n",
|
|
"\n",
|
|
"# Define evaluation parameters\n",
|
|
"test_size = 0.2\n",
|
|
"top_n = 10\n",
|
|
"\n",
|
|
"print(f\"Running evaluation with a {test_size*100:.0f}% test split (Top-{top_n})...\")\n",
|
|
"\n",
|
|
"# Run the evaluation\n",
|
|
"evaluation_scores = run_evaluation_with_proper_split(\n",
|
|
" data_reader=data,\n",
|
|
" model=eval_model,\n",
|
|
" test_size=test_size,\n",
|
|
" top_n=top_n,\n",
|
|
")\n",
|
|
"\n",
|
|
"# Display evaluation results\n",
|
|
"print(\"\\n--- Evaluation Results ---\")\n",
|
|
"print(f\"Hit Ratio @{top_n}: {evaluation_scores.get('Hit Ratio', 0.0):.2%}\")\n",
|
|
"print(f\"NDCG @{top_n}: {evaluation_scores.get('NDCG', 0.0):.4f}\")\n",
|
|
"print(f\"Evaluation Time: {evaluation_scores.get('evaluation_time', 0):.1f}s\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "49cb2659",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Step 3: Group Recommendation\n",
|
|
"\n",
|
|
"Now that we have a trained model, we can generate recommendations for a group. We will select a group, choose an aggregation strategy to combine individual member preferences, and generate a Top-10 list of recommended items."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "0a138815",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"--- 3. Group Recommendation ---\n",
|
|
"Generating Top-10 recommendations for group: 522_385_234_452_594\n",
|
|
"👥 Group Members: [522, 385, 234, 452, 594]\n",
|
|
"📊 Aggregation Strategy: AVG_PREDICTIONS\n",
|
|
"\n",
|
|
"✅ Recommendations generated successfully!\n",
|
|
"\n",
|
|
"Top 10 Recommended Items:\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>Rank</th>\n",
|
|
" <th>Item ID</th>\n",
|
|
" <th>Aggregated Score</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>1</td>\n",
|
|
" <td>543</td>\n",
|
|
" <td>4.636274</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>2</td>\n",
|
|
" <td>757</td>\n",
|
|
" <td>4.582981</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>3</td>\n",
|
|
" <td>564</td>\n",
|
|
" <td>4.504107</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>4</td>\n",
|
|
" <td>441</td>\n",
|
|
" <td>4.488708</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>5</td>\n",
|
|
" <td>379</td>\n",
|
|
" <td>4.341830</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>5</th>\n",
|
|
" <td>6</td>\n",
|
|
" <td>475</td>\n",
|
|
" <td>4.279482</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>6</th>\n",
|
|
" <td>7</td>\n",
|
|
" <td>43</td>\n",
|
|
" <td>4.268454</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>7</th>\n",
|
|
" <td>8</td>\n",
|
|
" <td>19</td>\n",
|
|
" <td>4.225248</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>8</th>\n",
|
|
" <td>9</td>\n",
|
|
" <td>748</td>\n",
|
|
" <td>4.178329</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>9</th>\n",
|
|
" <td>10</td>\n",
|
|
" <td>64</td>\n",
|
|
" <td>4.147735</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" Rank Item ID Aggregated Score\n",
|
|
"0 1 543 4.636274\n",
|
|
"1 2 757 4.582981\n",
|
|
"2 3 564 4.504107\n",
|
|
"3 4 441 4.488708\n",
|
|
"4 5 379 4.341830\n",
|
|
"5 6 475 4.279482\n",
|
|
"6 7 43 4.268454\n",
|
|
"7 8 19 4.225248\n",
|
|
"8 9 748 4.178329\n",
|
|
"9 10 64 4.147735"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"print(\"--- 3. Group Recommendation ---\")\n",
|
|
"\n",
|
|
"# Select a group and strategy\n",
|
|
"selected_group_id = available_groups[0] # Let's use the first group as an example\n",
|
|
"group_members = group_handler.parse_group_members(selected_group_id)\n",
|
|
"aggregation_strategy = AggregationStrategy.AVG_PREDICTIONS # Use the simple average strategy\n",
|
|
"top_k = 10\n",
|
|
"\n",
|
|
"print(f\"Generating Top-{top_k} recommendations for group: {selected_group_id}\")\n",
|
|
"print(f\"👥 Group Members: {group_members}\")\n",
|
|
"print(f\"📊 Aggregation Strategy: {aggregation_strategy.name}\")\n",
|
|
"\n",
|
|
"# --- Generate Recommendations ---\n",
|
|
"# 1. Instantiate the GroupRecommender\n",
|
|
"group_recommender = GroupRecommender(data=data)\n",
|
|
"\n",
|
|
"# 2. Setup the recommendation process\n",
|
|
"group_recommender.setup_recommendation(\n",
|
|
" model=model,\n",
|
|
" members=group_members, # type: ignore\n",
|
|
" data=data,\n",
|
|
" aggregation_strategy=aggregation_strategy,\n",
|
|
" )\n",
|
|
"\n",
|
|
"\n",
|
|
"# 3. Get the final recommendation list\n",
|
|
"recommended_items = group_recommender.get_group_recommendations(top_k=top_k)\n",
|
|
"recommendation_scores = group_recommender.get_recommendation_scores()\n",
|
|
"\n",
|
|
"print(\"\\n✅ Recommendations generated successfully!\")\n",
|
|
"\n",
|
|
"# --- Display Results ---\n",
|
|
"rec_data = [\n",
|
|
" {\n",
|
|
" \"Rank\": i + 1,\n",
|
|
" \"Item ID\": item_id,\n",
|
|
" \"Aggregated Score\": recommendation_scores.get(item_id, 0.0),\n",
|
|
" }\n",
|
|
" for i, item_id in enumerate(recommended_items) # type: ignore\n",
|
|
"]\n",
|
|
"\n",
|
|
"rec_df = pd.DataFrame(rec_data)\n",
|
|
"print(f\"\\nTop {top_k} Recommended Items:\")\n",
|
|
"display(rec_df)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "6268a2ed",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Step 4: Explanation (LORE4Groups)\n",
|
|
"\n",
|
|
"Finally, we generate an explanation for the recommended items using **LORE4Groups**, a local rule-based method. It:\n",
|
|
"- builds a local neighborhood of similar items using tag profiles\n",
|
|
"- trains a simple decision tree per item to predict 'like' vs 'not like'\n",
|
|
"- extracts interpretable rules explaining why items were recommended.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "367063db",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"--- 4. Local Rule-Based Explanation (LORE4Groups) ---\n",
|
|
"Explanation Fidelity:\n",
|
|
"25.00%\n",
|
|
"--------------------\n",
|
|
"Explanation Diversity (GILD):\n",
|
|
"0.8871\n",
|
|
"--------------------\n",
|
|
"Items with explanations:\n",
|
|
"['475', '43']\n",
|
|
"--------------------\n",
|
|
"Sample item: 475\n",
|
|
"Decision Path (rules): ['nudity (rear) <= 0.50', 'twins <= 0.50']\n",
|
|
"Group Factual Rules: {'unanimous': [], 'majority': [], 'minority': ['70mm <= 0.50 (1/5 members)', 'franchise <= 0.50 (1/5 members)', 'futuristmoviescom <= 0.50 (1/5 members)', 'nudity (rear) <= 0.50 (1/5 members)', 'owned <= 0.50 (1/5 members)', 'seen at the cinema <= 0.50 (1/5 members)', 'sequel <= 0.50 (1/5 members)', 'twins <= 0.50 (1/5 members)']}\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(\"--- 4. Local Rule-Based Explanation (LORE4Groups) ---\")\n",
|
|
"\n",
|
|
"# 1) Build tag-based item profiles aligned with the ratings dataset\n",
|
|
"# ---------------------------------------------------------------\n",
|
|
"# Read tags file from config\n",
|
|
"_tags_path = cfg.data.tags.tags_file\n",
|
|
"if not os.path.exists(_tags_path):\n",
|
|
" raise SystemExit(f\"Tags file not found at: {_tags_path}\")\n",
|
|
"\n",
|
|
"# Read tags and align original -> consecutive item ids\n",
|
|
"tags_df = pd.read_csv(_tags_path)\n",
|
|
"consecutive_items = set(data.dataset[\"itemId\"].unique())\n",
|
|
"original_to_consecutive = {}\n",
|
|
"for item_consec in consecutive_items:\n",
|
|
" try:\n",
|
|
" item_orig = data.get_original_item_id(int(item_consec))\n",
|
|
" original_to_consecutive[item_orig] = int(item_consec)\n",
|
|
" except (ValueError, KeyError):\n",
|
|
" continue\n",
|
|
"\n",
|
|
"# Keep only tags for items present in ratings\n",
|
|
"tags_df = tags_df[tags_df[\"movieId\"].isin(original_to_consecutive.keys())].copy()\n",
|
|
"if len(tags_df) == 0:\n",
|
|
" raise SystemExit(\"No tag data matches items in ratings dataset.\")\n",
|
|
"\n",
|
|
"# Normalize labels (keep full label as tag, lowercase)\n",
|
|
"tags_df[\"label\"] = tags_df[\"label\"].astype(str).str.lower().str.strip()\n",
|
|
"# Map to consecutive ids\n",
|
|
"tags_df[\"movieId\"] = tags_df[\"movieId\"].map(original_to_consecutive).astype(int)\n",
|
|
"\n",
|
|
"# Keep top-N most frequent labels to reduce sparsity\n",
|
|
"_top_n = cfg.explainer.lore4groups.top_n_labels\n",
|
|
"top_labels = (\n",
|
|
" tags_df[\"label\"].value_counts().nlargest(_top_n).index.tolist()\n",
|
|
")\n",
|
|
"tags_final = tags_df[tags_df[\"label\"].isin(top_labels)].copy()\n",
|
|
"\n",
|
|
"# Item profiles: {str(itemId): set(labels)}\n",
|
|
"item_profiles = (\n",
|
|
" tags_final.groupby(\"movieId\")[\"label\"].apply(set).to_dict()\n",
|
|
")\n",
|
|
"item_profiles = {str(k): v for k, v in item_profiles.items()}\n",
|
|
"\n",
|
|
"# Item-label matrix (rows: itemId as str, cols: labels, values: 0/1)\n",
|
|
"item_label_matrix = tags_final.assign(value=1).pivot_table(\n",
|
|
" index=\"movieId\", columns=\"label\", values=\"value\", fill_value=0\n",
|
|
")\n",
|
|
"item_label_matrix.index = item_label_matrix.index.astype(str)\n",
|
|
"\n",
|
|
"# 2) Prepare user history in required format\n",
|
|
"# ------------------------------------------\n",
|
|
"user_hist = {}\n",
|
|
"for user_id_orig in group_members:\n",
|
|
" try:\n",
|
|
" user_id_consec = data.get_new_user_id(user_id_orig)\n",
|
|
" hist_items = set(\n",
|
|
" data.dataset[data.dataset[\"userId\"] == user_id_consec][\"itemId\"].astype(str)\n",
|
|
" )\n",
|
|
" user_hist[user_id_orig] = hist_items\n",
|
|
" except Exception:\n",
|
|
" user_hist[user_id_orig] = set()\n",
|
|
"\n",
|
|
"# Filter recommendations to those we can explain (must exist in profiles)\n",
|
|
"explainable_recs = [str(i) for i in recommended_items if str(i) in item_profiles]\n",
|
|
"if not explainable_recs:\n",
|
|
" print(\"⚠️ No recommended items have sufficient tag data for explanation.\")\n",
|
|
"else:\n",
|
|
" # 3) Run LORE4Groups explainer\n",
|
|
" explainer = LORE4GroupsExplainer(\n",
|
|
" item_profiles=item_profiles,\n",
|
|
" item_label_matrix=item_label_matrix,\n",
|
|
" config=cfg,\n",
|
|
" genre_profiles=None, # optional, omitted for toy example\n",
|
|
" )\n",
|
|
"\n",
|
|
" results = explainer.find_explanation(\n",
|
|
" explainable_recs,\n",
|
|
" group_members,\n",
|
|
" user_hist,\n",
|
|
" data.dataset,\n",
|
|
" model=model,\n",
|
|
" data_reader=data,\n",
|
|
" )\n",
|
|
"\n",
|
|
" fidelity = results.get(\"fidelity\", 0.0)\n",
|
|
" details = results.get(\"details\", {})\n",
|
|
"\n",
|
|
" print(\"Explanation Fidelity:\")\n",
|
|
" print(f\"{fidelity:.2%}\")\n",
|
|
" print(\"-\" * 20)\n",
|
|
"\n",
|
|
" # Compute GILD diversity like in the app\n",
|
|
" evaluator = ExplanationEvaluator()\n",
|
|
" metrics = evaluator.evaluate({\"fidelity\": fidelity, \"details\": details}, explainer_type=\"LORE4Groups\")\n",
|
|
" print(\"Explanation Diversity (GILD):\")\n",
|
|
" print(f\"{metrics.get('gild', 0.0):.4f}\")\n",
|
|
" print(\"-\" * 20)\n",
|
|
"\n",
|
|
" print(\"Items with explanations:\")\n",
|
|
" print(list(details.keys()))\n",
|
|
" print(\"-\" * 20)\n",
|
|
"\n",
|
|
" # Optionally preview one item's explanation summary if available\n",
|
|
" if details:\n",
|
|
" first_item, exp = next(iter(details.items()))\n",
|
|
" decision_path = exp.get(\"decision_path\", [])\n",
|
|
" group_factual = exp.get(\"group_factual_rule\", [])\n",
|
|
" print(f\"Sample item: {first_item}\")\n",
|
|
" print(\"Decision Path (rules):\", decision_path)\n",
|
|
" print(\"Group Factual Rules:\", group_factual)\n"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "pygrex-exp-grs",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.11"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|