Files
py-grex/notebooks/expgrs_toy_example.ipynb
2026-05-22 10:02:10 +02:00

633 lines
21 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "7fa3d250",
"metadata": {},
"source": [
"# Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "6b55c6e8",
"metadata": {},
"outputs": [],
"source": [
"# Local application/library specific imports\n",
"from pygrex.config import cfg\n",
"from pygrex.data_reader import DataReader, GroupInteractionHandler\n",
"# from pygrex.evaluator import SlidingWindowEvaluator\n",
"from pygrex.explain import RuleBasedGroupRecExplainer\n",
"from pygrex.models import ALS\n",
"from pygrex.recommender import GroupRecommender\n",
"from pygrex.utils import AggregationStrategy\n",
"from pygrex.evaluator import ExplanationEvaluator\n",
"\n",
"import time\n",
"import pandas as pd\n",
"import pickle\n",
"import os\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "adbf9967",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ Data preparation complete.\n",
"\n",
"--- Data Summary ---\n",
"👥 Unique Users: 610\n",
"📦 Unique Items: 9,724\n",
"⭐ Total Ratings: 100,836\n",
"👨‍👩‍👧‍👦 Number of Groups: 17\n",
"\n",
"Processed Ratings DataFrame Head:\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>userId</th>\n",
" <th>itemId</th>\n",
" <th>rating</th>\n",
" <th>timestamp</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>964982703</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>964981247</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0</td>\n",
" <td>5</td>\n",
" <td>1</td>\n",
" <td>964982224</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0</td>\n",
" <td>43</td>\n",
" <td>1</td>\n",
" <td>964983815</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>46</td>\n",
" <td>1</td>\n",
" <td>964982931</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" userId itemId rating timestamp\n",
"0 0 0 1 964982703\n",
"1 0 2 1 964981247\n",
"2 0 5 1 964982224\n",
"3 0 43 1 964983815\n",
"4 0 46 1 964982931"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Read the ratings file.\n",
"data = DataReader(**cfg.data.test)\n",
"data.make_consecutive_ids_in_dataset()\n",
"data.binarize(binary_threshold=1)\n",
"\n",
"# Read the file with the group ids\n",
"group_handler = GroupInteractionHandler(**cfg.data.groups)\n",
"available_groups = group_handler.read_groups(\"groupsWithHighRatings5.txt\")\n",
"print(\"✅ Data preparation complete.\\n\")\n",
"\n",
"# --- Display Data Summary ---\n",
"print(\"--- Data Summary ---\")\n",
"print(f\"👥 Unique Users: {data.num_user:,}\")\n",
"print(f\"📦 Unique Items: {data.num_item:,}\")\n",
"print(f\"⭐ Total Ratings: {len(data.get_raw_dataset()):,}\")\n",
"print(f\"👨‍👩‍👧‍👦 Number of Groups: {len(available_groups):,}\")\n",
"print(\"\\nProcessed Ratings DataFrame Head:\")\n",
"display(data.dataset.head())"
]
},
{
"cell_type": "markdown",
"id": "5fc94aef",
"metadata": {},
"source": [
"## Step 2: Model Training & Evaluation\n",
"\n",
"With the data prepared, we now select and train a recommendation model. We will use **Alternating Least Squares (ALS)**, a matrix factorization technique for implicit feedback. After training, we will evaluate its performance using a train/test split to measure its Hit Ratio and NDCG."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8c13c283",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--- 2.1 Model Training ---\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\usuar\\miniconda3\\envs\\pygrex-exp-grs\\Lib\\site-packages\\implicit\\cpu\\als.py:95: RuntimeWarning: OpenBLAS is configured to use 8 threads. It is highly recommended to disable its internal threadpool by setting the environment variable 'OPENBLAS_NUM_THREADS=1' or by calling 'threadpoolctl.threadpool_limits(1, \"blas\")'. Having OpenBLAS use a threadpool can lead to severe performance issues here.\n",
" check_blas_config()\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4a7883a7b94a4a13952cb1d9cf9a33a4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/10 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ Model trained successfully in 1.00 seconds!\n"
]
}
],
"source": [
"print(\"--- 2.1 Model Training ---\")\n",
"\n",
"# Train the recommendation model\n",
"model = ALS(**cfg.model.als)\n",
"\n",
"# Train the model\n",
"start_time = time.time()\n",
"model.fit(data)\n",
"end_time = time.time()\n",
"training_time = end_time - start_time\n",
"\n",
"print(f\"✅ Model trained successfully in {training_time:.2f} seconds!\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "047fe521",
"metadata": {},
"outputs": [],
"source": [
"print(\"\\n--- 2.2 Offline Model Evaluation ---\")\n",
"# For evaluation, a new model instance must be created.\n",
"# The evaluation function handles its own internal data splitting and training.\n",
"eval_model = ALS(**cfg.model.als)\n",
"\n",
"# Define evaluation parameters\n",
"test_size = 0.2\n",
"top_n = 10\n",
"\n",
"print(f\"Running evaluation with a {test_size*100:.0f}% test split (Top-{top_n})...\")\n",
"\n",
"# Run the evaluation\n",
"evaluation_scores = run_evaluation_with_proper_split(\n",
" data_reader=data,\n",
" model=eval_model,\n",
" test_size=test_size,\n",
" top_n=top_n,\n",
")\n",
"\n",
"# Display evaluation results\n",
"print(\"\\n--- Evaluation Results ---\")\n",
"print(f\"Hit Ratio @{top_n}: {evaluation_scores.get('Hit Ratio', 0.0):.2%}\")\n",
"print(f\"NDCG @{top_n}: {evaluation_scores.get('NDCG', 0.0):.4f}\")\n",
"print(f\"Evaluation Time: {evaluation_scores.get('evaluation_time', 0):.1f}s\")"
]
},
{
"cell_type": "markdown",
"id": "49cb2659",
"metadata": {},
"source": [
"## Step 3: Group Recommendation\n",
"\n",
"Now that we have a trained model, we can generate recommendations for a group. We will select a group, choose an aggregation strategy to combine individual member preferences, and generate a Top-10 list of recommended items."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0a138815",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--- 3. Group Recommendation ---\n",
"Generating Top-10 recommendations for group: 522_385_234_452_594\n",
"👥 Group Members: [522, 385, 234, 452, 594]\n",
"📊 Aggregation Strategy: AVG_PREDICTIONS\n",
"\n",
"✅ Recommendations generated successfully!\n",
"\n",
"Top 10 Recommended Items:\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Rank</th>\n",
" <th>Item ID</th>\n",
" <th>Aggregated Score</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>543</td>\n",
" <td>4.636274</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>757</td>\n",
" <td>4.582981</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3</td>\n",
" <td>564</td>\n",
" <td>4.504107</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4</td>\n",
" <td>441</td>\n",
" <td>4.488708</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" <td>379</td>\n",
" <td>4.341830</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>6</td>\n",
" <td>475</td>\n",
" <td>4.279482</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>7</td>\n",
" <td>43</td>\n",
" <td>4.268454</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>8</td>\n",
" <td>19</td>\n",
" <td>4.225248</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>9</td>\n",
" <td>748</td>\n",
" <td>4.178329</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>10</td>\n",
" <td>64</td>\n",
" <td>4.147735</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Rank Item ID Aggregated Score\n",
"0 1 543 4.636274\n",
"1 2 757 4.582981\n",
"2 3 564 4.504107\n",
"3 4 441 4.488708\n",
"4 5 379 4.341830\n",
"5 6 475 4.279482\n",
"6 7 43 4.268454\n",
"7 8 19 4.225248\n",
"8 9 748 4.178329\n",
"9 10 64 4.147735"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"print(\"--- 3. Group Recommendation ---\")\n",
"\n",
"# Select a group and strategy\n",
"selected_group_id = available_groups[0] # Let's use the first group as an example\n",
"group_members = group_handler.parse_group_members(selected_group_id)\n",
"aggregation_strategy = AggregationStrategy.AVG_PREDICTIONS # Use the simple average strategy\n",
"top_k = 10\n",
"\n",
"print(f\"Generating Top-{top_k} recommendations for group: {selected_group_id}\")\n",
"print(f\"👥 Group Members: {group_members}\")\n",
"print(f\"📊 Aggregation Strategy: {aggregation_strategy.name}\")\n",
"\n",
"# --- Generate Recommendations ---\n",
"# 1. Instantiate the GroupRecommender\n",
"group_recommender = GroupRecommender(data=data)\n",
"\n",
"# 2. Setup the recommendation process\n",
"group_recommender.setup_recommendation(\n",
" model=model,\n",
" members=group_members, # type: ignore\n",
" data=data,\n",
" aggregation_strategy=aggregation_strategy,\n",
" )\n",
"\n",
"\n",
"# 3. Get the final recommendation list\n",
"recommended_items = group_recommender.get_group_recommendations(top_k=top_k)\n",
"recommendation_scores = group_recommender.get_recommendation_scores()\n",
"\n",
"print(\"\\n✅ Recommendations generated successfully!\")\n",
"\n",
"# --- Display Results ---\n",
"rec_data = [\n",
" {\n",
" \"Rank\": i + 1,\n",
" \"Item ID\": item_id,\n",
" \"Aggregated Score\": recommendation_scores.get(item_id, 0.0),\n",
" }\n",
" for i, item_id in enumerate(recommended_items) # type: ignore\n",
"]\n",
"\n",
"rec_df = pd.DataFrame(rec_data)\n",
"print(f\"\\nTop {top_k} Recommended Items:\")\n",
"display(rec_df)"
]
},
{
"cell_type": "markdown",
"id": "6268a2ed",
"metadata": {},
"source": [
"## Step 4: Explanation (EXPGRS)\n",
"\n",
"Finally, we generate an explanation for one of the recommendations. We will use the **EXPGRS** method to find a ruled based explanation. This method calculates the Model Fidelity: the percentage of the Top-N list that can be explained by pre-computed association rules from cached files.\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "367063db",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--- 4. Rule based Explanation (EXPGRS) ---\n",
"Explanation Fidelity:\n",
"10.00%\n",
"--------------------\n",
"Advanced Explanation Fidelity:\n",
"0.00%\n",
"--------------------\n",
"Explanation Diversity (GILD):\n",
"0.0000\n",
"--------------------\n"
]
}
],
"source": [
"print(\"--- 4. Rule based Explanation (EXPGRS) ---\")\n",
"\n",
"\n",
"def load_cached_data_rules(min_support, min_confidence, rating_threshold):\n",
" \"\"\"\n",
" Loads pre-computed association rules from the cached_rules folder.\n",
" Returns the loaded object (typically a dict with key \"rules\") if found, None otherwise.\n",
" Searches several common locations to be robust in notebooks.\n",
" \"\"\"\n",
" from pathlib import Path\n",
"\n",
" filename = f\"rules_sup{min_support:.2f}_conf{min_confidence:.1f}_rating{rating_threshold:.0f}\"\n",
" possible_extensions = [\".pkl\", \".pickle\", \".json\"]\n",
"\n",
" cwd = Path.cwd()\n",
" search_dirs = [\n",
" cwd / \"cached_rules\", # current working directory\n",
" cwd.parent / \"cached_rules\", # parent (useful when running from notebooks/)\n",
" Path(__file__).resolve().parent / \"cached_rules\" if '__file__' in globals() else None, # script dir if available\n",
" ]\n",
" search_dirs = [p for p in search_dirs if p is not None]\n",
"\n",
" tried_paths = []\n",
" for base in search_dirs:\n",
" for ext in possible_extensions:\n",
" filepath = base / f\"{filename}{ext}\"\n",
" tried_paths.append(str(filepath))\n",
" if filepath.exists():\n",
" try:\n",
" if ext in [\".pkl\", \".pickle\"]:\n",
" with open(filepath, \"rb\") as f:\n",
" return pickle.load(f)\n",
" elif ext == \".json\":\n",
" import json\n",
" with open(filepath, \"r\") as f:\n",
" return json.load(f)\n",
" except Exception as e:\n",
" print(f\"Error loading cached rules from {filepath}: {e}\")\n",
" continue\n",
"\n",
" print(\"Cached rules not found. Tried paths:\")\n",
" for p in tried_paths:\n",
" print(\" -\", p)\n",
" return None\n",
"\n",
"\n",
"def get_user_history(rating_threshold):\n",
" \"\"\"\n",
" Generates the user interaction history based only on the rating threshold.\n",
" The keys of the returned dictionary are the ORIGINAL user IDs.\n",
" \"\"\"\n",
" df_filtered = data.dataset[data.dataset[\"rating\"] >= rating_threshold]\n",
"\n",
" # Group by the 'userId' column (which contains the new, consecutive IDs)\n",
" history_by_new_id = df_filtered.groupby(\"userId\")[\"itemId\"].apply(set).to_dict()\n",
"\n",
" # Create the final dictionary mapping original user IDs to sets of new item IDs\n",
" history_by_original_id = {}\n",
" for new_id, item_set in history_by_new_id.items():\n",
" try:\n",
" original_id = data.get_original_user_id(int(new_id))\n",
" # The explainer needs the item IDs to be strings to match the rules\n",
" history_by_original_id[original_id] = {str(item) for item in item_set}\n",
" except (ValueError, KeyError):\n",
" continue\n",
"\n",
" return history_by_original_id\n",
"\n",
"# ----------------------------------------------------------------------- #\n",
"\n",
"min_support = 0.1\n",
"min_confidence = 0.1\n",
"rating_threshold = 1\n",
"minimum_members = 1\n",
"\n",
"# Load cached rules (no Streamlit dependencies)\n",
"expected_filename = f\"rules_sup{min_support:.2f}_conf{min_confidence:.1f}_rating{rating_threshold:.0f}\"\n",
"cached_data_rules = load_cached_data_rules(min_support, min_confidence, rating_threshold)\n",
"if cached_data_rules is None:\n",
" print(\"⚠️ Cached rules not found.\")\n",
" print(\"Looked for:\", \", \".join(\n",
" [os.path.join(\"cached_rules\", expected_filename + ext) for ext in [\".pkl\", \".pickle\", \".json\"]]\n",
" ))\n",
" raise SystemExit(\"Please place the cached rules file in the 'cached_rules/' folder.\")\n",
"\n",
"# Extract rules from loaded structure\n",
"cached_rules = cached_data_rules.get(\"rules\") if isinstance(cached_data_rules, dict) else None\n",
"if cached_rules is None:\n",
" raise ValueError(\n",
" \"Loaded cached rules file does not contain a 'rules' key. Check the file format.\"\n",
" )\n",
"\n",
"# Get user history\n",
"user_history = get_user_history(rating_threshold)\n",
"\n",
"# Create explainer with cached rules\n",
"explainer = RuleBasedGroupRecExplainer(\n",
" rules=cached_rules,\n",
" data=data,\n",
" pool_recommendations=recommended_items,\n",
" members=group_members,\n",
" user_history=user_history,\n",
" min_members_threshold=minimum_members,\n",
")\n",
"\n",
"# Compute explanations and metrics\n",
"fidelity_score = explainer.find_explanation()\n",
"advanced_fidelity_score = explainer.compute_group_fidelity_advanced()\n",
"explanation_details = explainer.get_explanation_details()\n",
"\n",
"explanation_results = {\n",
" \"fidelity\": fidelity_score,\n",
" \"advanced_fidelity\": advanced_fidelity_score,\n",
" \"details\": explanation_details,\n",
"}\n",
"\n",
"# Evaluate results\n",
"_evaluator = ExplanationEvaluator()\n",
"metrics = _evaluator.evaluate(explanation_results, explainer_type=\"EXPGRS\")\n",
"\n",
"print(\"Explanation Fidelity:\")\n",
"print(f\"{metrics.get('fidelity', 0.0):.2%}\")\n",
"print(\"-\" * 20)\n",
"\n",
"print(\"Advanced Explanation Fidelity:\")\n",
"print(f\"{explanation_results.get('advanced_fidelity', 0.0):.2%}\")\n",
"print(\"-\" * 20)\n",
"\n",
"print(\"Explanation Diversity (GILD):\")\n",
"print(f\"{metrics.get('gild', 0.0):.4f}\")\n",
"print(\"-\" * 20)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "pygrex-exp-grs",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}