633 lines
21 KiB
Plaintext
633 lines
21 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "7fa3d250",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Imports"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "6b55c6e8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Local application/library specific imports\n",
|
|
"from pygrex.config import cfg\n",
|
|
"from pygrex.data_reader import DataReader, GroupInteractionHandler\n",
|
|
"# from pygrex.evaluator import SlidingWindowEvaluator\n",
|
|
"from pygrex.explain import RuleBasedGroupRecExplainer\n",
|
|
"from pygrex.models import ALS\n",
|
|
"from pygrex.recommender import GroupRecommender\n",
|
|
"from pygrex.utils import AggregationStrategy\n",
|
|
"from pygrex.evaluator import ExplanationEvaluator\n",
|
|
"\n",
|
|
"import time\n",
|
|
"import pandas as pd\n",
|
|
"import pickle\n",
|
|
"import os\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "adbf9967",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"✅ Data preparation complete.\n",
|
|
"\n",
|
|
"--- Data Summary ---\n",
|
|
"👥 Unique Users: 610\n",
|
|
"📦 Unique Items: 9,724\n",
|
|
"⭐ Total Ratings: 100,836\n",
|
|
"👨👩👧👦 Number of Groups: 17\n",
|
|
"\n",
|
|
"Processed Ratings DataFrame Head:\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>userId</th>\n",
|
|
" <th>itemId</th>\n",
|
|
" <th>rating</th>\n",
|
|
" <th>timestamp</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>0</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>964982703</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>0</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>964981247</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>0</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>964982224</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>0</td>\n",
|
|
" <td>43</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>964983815</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>0</td>\n",
|
|
" <td>46</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>964982931</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" userId itemId rating timestamp\n",
|
|
"0 0 0 1 964982703\n",
|
|
"1 0 2 1 964981247\n",
|
|
"2 0 5 1 964982224\n",
|
|
"3 0 43 1 964983815\n",
|
|
"4 0 46 1 964982931"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Read the ratings file.\n",
|
|
"data = DataReader(**cfg.data.test)\n",
|
|
"data.make_consecutive_ids_in_dataset()\n",
|
|
"data.binarize(binary_threshold=1)\n",
|
|
"\n",
|
|
"# Read the file with the group ids\n",
|
|
"group_handler = GroupInteractionHandler(**cfg.data.groups)\n",
|
|
"available_groups = group_handler.read_groups(\"groupsWithHighRatings5.txt\")\n",
|
|
"print(\"✅ Data preparation complete.\\n\")\n",
|
|
"\n",
|
|
"# --- Display Data Summary ---\n",
|
|
"print(\"--- Data Summary ---\")\n",
|
|
"print(f\"👥 Unique Users: {data.num_user:,}\")\n",
|
|
"print(f\"📦 Unique Items: {data.num_item:,}\")\n",
|
|
"print(f\"⭐ Total Ratings: {len(data.get_raw_dataset()):,}\")\n",
|
|
"print(f\"👨👩👧👦 Number of Groups: {len(available_groups):,}\")\n",
|
|
"print(\"\\nProcessed Ratings DataFrame Head:\")\n",
|
|
"display(data.dataset.head())"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "5fc94aef",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Step 2: Model Training & Evaluation\n",
|
|
"\n",
|
|
"With the data prepared, we now select and train a recommendation model. We will use **Alternating Least Squares (ALS)**, a matrix factorization technique for implicit feedback. After training, we will evaluate its performance using a train/test split to measure its Hit Ratio and NDCG."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "8c13c283",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"--- 2.1 Model Training ---\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"c:\\Users\\usuar\\miniconda3\\envs\\pygrex-exp-grs\\Lib\\site-packages\\implicit\\cpu\\als.py:95: RuntimeWarning: OpenBLAS is configured to use 8 threads. It is highly recommended to disable its internal threadpool by setting the environment variable 'OPENBLAS_NUM_THREADS=1' or by calling 'threadpoolctl.threadpool_limits(1, \"blas\")'. Having OpenBLAS use a threadpool can lead to severe performance issues here.\n",
|
|
" check_blas_config()\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "4a7883a7b94a4a13952cb1d9cf9a33a4",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/10 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"✅ Model trained successfully in 1.00 seconds!\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(\"--- 2.1 Model Training ---\")\n",
|
|
"\n",
|
|
"# Train the recommendation model\n",
|
|
"model = ALS(**cfg.model.als)\n",
|
|
"\n",
|
|
"# Train the model\n",
|
|
"start_time = time.time()\n",
|
|
"model.fit(data)\n",
|
|
"end_time = time.time()\n",
|
|
"training_time = end_time - start_time\n",
|
|
"\n",
|
|
"print(f\"✅ Model trained successfully in {training_time:.2f} seconds!\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "047fe521",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"print(\"\\n--- 2.2 Offline Model Evaluation ---\")\n",
|
|
"# For evaluation, a new model instance must be created.\n",
|
|
"# The evaluation function handles its own internal data splitting and training.\n",
|
|
"eval_model = ALS(**cfg.model.als)\n",
|
|
"\n",
|
|
"# Define evaluation parameters\n",
|
|
"test_size = 0.2\n",
|
|
"top_n = 10\n",
|
|
"\n",
|
|
"print(f\"Running evaluation with a {test_size*100:.0f}% test split (Top-{top_n})...\")\n",
|
|
"\n",
|
|
"# Run the evaluation\n",
|
|
"evaluation_scores = run_evaluation_with_proper_split(\n",
|
|
" data_reader=data,\n",
|
|
" model=eval_model,\n",
|
|
" test_size=test_size,\n",
|
|
" top_n=top_n,\n",
|
|
")\n",
|
|
"\n",
|
|
"# Display evaluation results\n",
|
|
"print(\"\\n--- Evaluation Results ---\")\n",
|
|
"print(f\"Hit Ratio @{top_n}: {evaluation_scores.get('Hit Ratio', 0.0):.2%}\")\n",
|
|
"print(f\"NDCG @{top_n}: {evaluation_scores.get('NDCG', 0.0):.4f}\")\n",
|
|
"print(f\"Evaluation Time: {evaluation_scores.get('evaluation_time', 0):.1f}s\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "49cb2659",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Step 3: Group Recommendation\n",
|
|
"\n",
|
|
"Now that we have a trained model, we can generate recommendations for a group. We will select a group, choose an aggregation strategy to combine individual member preferences, and generate a Top-10 list of recommended items."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "0a138815",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"--- 3. Group Recommendation ---\n",
|
|
"Generating Top-10 recommendations for group: 522_385_234_452_594\n",
|
|
"👥 Group Members: [522, 385, 234, 452, 594]\n",
|
|
"📊 Aggregation Strategy: AVG_PREDICTIONS\n",
|
|
"\n",
|
|
"✅ Recommendations generated successfully!\n",
|
|
"\n",
|
|
"Top 10 Recommended Items:\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>Rank</th>\n",
|
|
" <th>Item ID</th>\n",
|
|
" <th>Aggregated Score</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>1</td>\n",
|
|
" <td>543</td>\n",
|
|
" <td>4.636274</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>2</td>\n",
|
|
" <td>757</td>\n",
|
|
" <td>4.582981</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>3</td>\n",
|
|
" <td>564</td>\n",
|
|
" <td>4.504107</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>4</td>\n",
|
|
" <td>441</td>\n",
|
|
" <td>4.488708</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>5</td>\n",
|
|
" <td>379</td>\n",
|
|
" <td>4.341830</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>5</th>\n",
|
|
" <td>6</td>\n",
|
|
" <td>475</td>\n",
|
|
" <td>4.279482</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>6</th>\n",
|
|
" <td>7</td>\n",
|
|
" <td>43</td>\n",
|
|
" <td>4.268454</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>7</th>\n",
|
|
" <td>8</td>\n",
|
|
" <td>19</td>\n",
|
|
" <td>4.225248</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>8</th>\n",
|
|
" <td>9</td>\n",
|
|
" <td>748</td>\n",
|
|
" <td>4.178329</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>9</th>\n",
|
|
" <td>10</td>\n",
|
|
" <td>64</td>\n",
|
|
" <td>4.147735</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" Rank Item ID Aggregated Score\n",
|
|
"0 1 543 4.636274\n",
|
|
"1 2 757 4.582981\n",
|
|
"2 3 564 4.504107\n",
|
|
"3 4 441 4.488708\n",
|
|
"4 5 379 4.341830\n",
|
|
"5 6 475 4.279482\n",
|
|
"6 7 43 4.268454\n",
|
|
"7 8 19 4.225248\n",
|
|
"8 9 748 4.178329\n",
|
|
"9 10 64 4.147735"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"print(\"--- 3. Group Recommendation ---\")\n",
|
|
"\n",
|
|
"# Select a group and strategy\n",
|
|
"selected_group_id = available_groups[0] # Let's use the first group as an example\n",
|
|
"group_members = group_handler.parse_group_members(selected_group_id)\n",
|
|
"aggregation_strategy = AggregationStrategy.AVG_PREDICTIONS # Use the simple average strategy\n",
|
|
"top_k = 10\n",
|
|
"\n",
|
|
"print(f\"Generating Top-{top_k} recommendations for group: {selected_group_id}\")\n",
|
|
"print(f\"👥 Group Members: {group_members}\")\n",
|
|
"print(f\"📊 Aggregation Strategy: {aggregation_strategy.name}\")\n",
|
|
"\n",
|
|
"# --- Generate Recommendations ---\n",
|
|
"# 1. Instantiate the GroupRecommender\n",
|
|
"group_recommender = GroupRecommender(data=data)\n",
|
|
"\n",
|
|
"# 2. Setup the recommendation process\n",
|
|
"group_recommender.setup_recommendation(\n",
|
|
" model=model,\n",
|
|
" members=group_members, # type: ignore\n",
|
|
" data=data,\n",
|
|
" aggregation_strategy=aggregation_strategy,\n",
|
|
" )\n",
|
|
"\n",
|
|
"\n",
|
|
"# 3. Get the final recommendation list\n",
|
|
"recommended_items = group_recommender.get_group_recommendations(top_k=top_k)\n",
|
|
"recommendation_scores = group_recommender.get_recommendation_scores()\n",
|
|
"\n",
|
|
"print(\"\\n✅ Recommendations generated successfully!\")\n",
|
|
"\n",
|
|
"# --- Display Results ---\n",
|
|
"rec_data = [\n",
|
|
" {\n",
|
|
" \"Rank\": i + 1,\n",
|
|
" \"Item ID\": item_id,\n",
|
|
" \"Aggregated Score\": recommendation_scores.get(item_id, 0.0),\n",
|
|
" }\n",
|
|
" for i, item_id in enumerate(recommended_items) # type: ignore\n",
|
|
"]\n",
|
|
"\n",
|
|
"rec_df = pd.DataFrame(rec_data)\n",
|
|
"print(f\"\\nTop {top_k} Recommended Items:\")\n",
|
|
"display(rec_df)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "6268a2ed",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Step 4: Explanation (EXPGRS)\n",
|
|
"\n",
|
|
"Finally, we generate an explanation for one of the recommendations. We will use the **EXPGRS** method to find a ruled based explanation. This method calculates the Model Fidelity: the percentage of the Top-N list that can be explained by pre-computed association rules from cached files.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "367063db",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"--- 4. Rule based Explanation (EXPGRS) ---\n",
|
|
"Explanation Fidelity:\n",
|
|
"10.00%\n",
|
|
"--------------------\n",
|
|
"Advanced Explanation Fidelity:\n",
|
|
"0.00%\n",
|
|
"--------------------\n",
|
|
"Explanation Diversity (GILD):\n",
|
|
"0.0000\n",
|
|
"--------------------\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(\"--- 4. Rule based Explanation (EXPGRS) ---\")\n",
|
|
"\n",
|
|
"\n",
|
|
"def load_cached_data_rules(min_support, min_confidence, rating_threshold):\n",
|
|
" \"\"\"\n",
|
|
" Loads pre-computed association rules from the cached_rules folder.\n",
|
|
" Returns the loaded object (typically a dict with key \"rules\") if found, None otherwise.\n",
|
|
" Searches several common locations to be robust in notebooks.\n",
|
|
" \"\"\"\n",
|
|
" from pathlib import Path\n",
|
|
"\n",
|
|
" filename = f\"rules_sup{min_support:.2f}_conf{min_confidence:.1f}_rating{rating_threshold:.0f}\"\n",
|
|
" possible_extensions = [\".pkl\", \".pickle\", \".json\"]\n",
|
|
"\n",
|
|
" cwd = Path.cwd()\n",
|
|
" search_dirs = [\n",
|
|
" cwd / \"cached_rules\", # current working directory\n",
|
|
" cwd.parent / \"cached_rules\", # parent (useful when running from notebooks/)\n",
|
|
" Path(__file__).resolve().parent / \"cached_rules\" if '__file__' in globals() else None, # script dir if available\n",
|
|
" ]\n",
|
|
" search_dirs = [p for p in search_dirs if p is not None]\n",
|
|
"\n",
|
|
" tried_paths = []\n",
|
|
" for base in search_dirs:\n",
|
|
" for ext in possible_extensions:\n",
|
|
" filepath = base / f\"{filename}{ext}\"\n",
|
|
" tried_paths.append(str(filepath))\n",
|
|
" if filepath.exists():\n",
|
|
" try:\n",
|
|
" if ext in [\".pkl\", \".pickle\"]:\n",
|
|
" with open(filepath, \"rb\") as f:\n",
|
|
" return pickle.load(f)\n",
|
|
" elif ext == \".json\":\n",
|
|
" import json\n",
|
|
" with open(filepath, \"r\") as f:\n",
|
|
" return json.load(f)\n",
|
|
" except Exception as e:\n",
|
|
" print(f\"Error loading cached rules from {filepath}: {e}\")\n",
|
|
" continue\n",
|
|
"\n",
|
|
" print(\"Cached rules not found. Tried paths:\")\n",
|
|
" for p in tried_paths:\n",
|
|
" print(\" -\", p)\n",
|
|
" return None\n",
|
|
"\n",
|
|
"\n",
|
|
"def get_user_history(rating_threshold):\n",
|
|
" \"\"\"\n",
|
|
" Generates the user interaction history based only on the rating threshold.\n",
|
|
" The keys of the returned dictionary are the ORIGINAL user IDs.\n",
|
|
" \"\"\"\n",
|
|
" df_filtered = data.dataset[data.dataset[\"rating\"] >= rating_threshold]\n",
|
|
"\n",
|
|
" # Group by the 'userId' column (which contains the new, consecutive IDs)\n",
|
|
" history_by_new_id = df_filtered.groupby(\"userId\")[\"itemId\"].apply(set).to_dict()\n",
|
|
"\n",
|
|
" # Create the final dictionary mapping original user IDs to sets of new item IDs\n",
|
|
" history_by_original_id = {}\n",
|
|
" for new_id, item_set in history_by_new_id.items():\n",
|
|
" try:\n",
|
|
" original_id = data.get_original_user_id(int(new_id))\n",
|
|
" # The explainer needs the item IDs to be strings to match the rules\n",
|
|
" history_by_original_id[original_id] = {str(item) for item in item_set}\n",
|
|
" except (ValueError, KeyError):\n",
|
|
" continue\n",
|
|
"\n",
|
|
" return history_by_original_id\n",
|
|
"\n",
|
|
"# ----------------------------------------------------------------------- #\n",
|
|
"\n",
|
|
"min_support = 0.1\n",
|
|
"min_confidence = 0.1\n",
|
|
"rating_threshold = 1\n",
|
|
"minimum_members = 1\n",
|
|
"\n",
|
|
"# Load cached rules (no Streamlit dependencies)\n",
|
|
"expected_filename = f\"rules_sup{min_support:.2f}_conf{min_confidence:.1f}_rating{rating_threshold:.0f}\"\n",
|
|
"cached_data_rules = load_cached_data_rules(min_support, min_confidence, rating_threshold)\n",
|
|
"if cached_data_rules is None:\n",
|
|
" print(\"⚠️ Cached rules not found.\")\n",
|
|
" print(\"Looked for:\", \", \".join(\n",
|
|
" [os.path.join(\"cached_rules\", expected_filename + ext) for ext in [\".pkl\", \".pickle\", \".json\"]]\n",
|
|
" ))\n",
|
|
" raise SystemExit(\"Please place the cached rules file in the 'cached_rules/' folder.\")\n",
|
|
"\n",
|
|
"# Extract rules from loaded structure\n",
|
|
"cached_rules = cached_data_rules.get(\"rules\") if isinstance(cached_data_rules, dict) else None\n",
|
|
"if cached_rules is None:\n",
|
|
" raise ValueError(\n",
|
|
" \"Loaded cached rules file does not contain a 'rules' key. Check the file format.\"\n",
|
|
" )\n",
|
|
"\n",
|
|
"# Get user history\n",
|
|
"user_history = get_user_history(rating_threshold)\n",
|
|
"\n",
|
|
"# Create explainer with cached rules\n",
|
|
"explainer = RuleBasedGroupRecExplainer(\n",
|
|
" rules=cached_rules,\n",
|
|
" data=data,\n",
|
|
" pool_recommendations=recommended_items,\n",
|
|
" members=group_members,\n",
|
|
" user_history=user_history,\n",
|
|
" min_members_threshold=minimum_members,\n",
|
|
")\n",
|
|
"\n",
|
|
"# Compute explanations and metrics\n",
|
|
"fidelity_score = explainer.find_explanation()\n",
|
|
"advanced_fidelity_score = explainer.compute_group_fidelity_advanced()\n",
|
|
"explanation_details = explainer.get_explanation_details()\n",
|
|
"\n",
|
|
"explanation_results = {\n",
|
|
" \"fidelity\": fidelity_score,\n",
|
|
" \"advanced_fidelity\": advanced_fidelity_score,\n",
|
|
" \"details\": explanation_details,\n",
|
|
"}\n",
|
|
"\n",
|
|
"# Evaluate results\n",
|
|
"_evaluator = ExplanationEvaluator()\n",
|
|
"metrics = _evaluator.evaluate(explanation_results, explainer_type=\"EXPGRS\")\n",
|
|
"\n",
|
|
"print(\"Explanation Fidelity:\")\n",
|
|
"print(f\"{metrics.get('fidelity', 0.0):.2%}\")\n",
|
|
"print(\"-\" * 20)\n",
|
|
"\n",
|
|
"print(\"Advanced Explanation Fidelity:\")\n",
|
|
"print(f\"{explanation_results.get('advanced_fidelity', 0.0):.2%}\")\n",
|
|
"print(\"-\" * 20)\n",
|
|
"\n",
|
|
"print(\"Explanation Diversity (GILD):\")\n",
|
|
"print(f\"{metrics.get('gild', 0.0):.4f}\")\n",
|
|
"print(\"-\" * 20)\n"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "pygrex-exp-grs",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.11"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|