576 lines
18 KiB
Plaintext
576 lines
18 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "7fa3d250",
|
|
"metadata": {},
|
|
"source": [
|
|
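"# Group Recommendation and Counterfactual Explanation with PyGREX\n",
"\n",
"## Step 1: Imports & Data Preparation\n",
"\n",
"We import the PyGREX components, load the ratings data, remap user and item ids to consecutive integers, binarize the ratings, and read the predefined user groups."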
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "6b55c6e8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# Limit OpenBLAS to a single thread *before* numpy is first imported (pandas and\n",
"# pygrex both pull it in): `implicit` warns that OpenBLAS's own thread pool can\n",
"# severely slow down ALS. `setdefault` keeps any value already set in the env.\n",
"import os\n",
"\n",
"os.environ.setdefault(\"OPENBLAS_NUM_THREADS\", \"1\")\n",
"\n",
"# Standard library\n",
"import time\n",
"\n",
"# Third-party\n",
"import pandas as pd\n",
"from IPython.display import display\n",
"\n",
"# Local application/library specific imports\n",
"from pygrex.config import cfg\n",
"from pygrex.data_reader import DataReader, GroupInteractionHandler\n",
"from pygrex.evaluator import run_evaluation_with_proper_split\n",
"from pygrex.explain import SlidingWindowExplainer\n",
"from pygrex.models import ALS\n",
"from pygrex.recommender import GroupRecommender\n",
"from pygrex.utils import AggregationStrategy"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "adbf9967",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"✅ Data preparation complete.\n",
|
|
"\n",
|
|
"--- Data Summary ---\n",
|
|
"👥 Unique Users: 610\n",
|
|
"📦 Unique Items: 9,724\n",
|
|
"⭐ Total Ratings: 100,836\n",
|
|
"👨👩👧👦 Number of Groups: 17\n",
|
|
"\n",
|
|
"Processed Ratings DataFrame Head:\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>userId</th>\n",
|
|
" <th>itemId</th>\n",
|
|
" <th>rating</th>\n",
|
|
" <th>timestamp</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>0</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>964982703</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>0</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>964981247</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>0</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>964982224</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>0</td>\n",
|
|
" <td>43</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>964983815</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>0</td>\n",
|
|
" <td>46</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>964982931</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" userId itemId rating timestamp\n",
|
|
"0 0 0 1 964982703\n",
|
|
"1 0 2 1 964981247\n",
|
|
"2 0 5 1 964982224\n",
|
|
"3 0 43 1 964983815\n",
|
|
"4 0 46 1 964982931"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
"# Parameters for this step (kept here so they are easy to find and change).\n",
"GROUPS_FILE = \"groupsWithHighRatings5.txt\"  # file of predefined groups to load\n",
"BINARY_THRESHOLD = 1  # threshold passed to DataReader.binarize\n",
"\n",
"# Read the ratings file, remap user/item ids to consecutive integers and\n",
"# binarize the ratings for the implicit-feedback model trained below.\n",
"data = DataReader(**cfg.data.test)\n",
"data.make_consecutive_ids_in_dataset()\n",
"data.binarize(binary_threshold=BINARY_THRESHOLD)\n",
"\n",
"# Read the file with the group ids\n",
"group_handler = GroupInteractionHandler(**cfg.data.groups)\n",
"available_groups = group_handler.read_groups(GROUPS_FILE)\n",
"# Step 3 uses available_groups[0]; fail early with a clear message instead.\n",
"assert len(available_groups) > 0, f\"No groups were read from {GROUPS_FILE}\"\n",
"print(\"✅ Data preparation complete.\\n\")\n",
"\n",
"# --- Display Data Summary ---\n",
"print(\"--- Data Summary ---\")\n",
"print(f\"👥 Unique Users: {data.num_user:,}\")\n",
"print(f\"📦 Unique Items: {data.num_item:,}\")\n",
"print(f\"⭐ Total Ratings: {len(data.get_raw_dataset()):,}\")\n",
"print(f\"👨👩👧👦 Number of Groups: {len(available_groups):,}\")\n",
"print(\"\\nProcessed Ratings DataFrame Head:\")\n",
"display(data.dataset.head())"
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "5fc94aef",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Step 2: Model Training & Evaluation\n",
|
|
"\n",
|
|
"With the data prepared, we now select and train a recommendation model. We use **Alternating Least Squares (ALS)**, a matrix factorization technique for implicit feedback. We then evaluate the same ALS configuration offline: a fresh model instance is trained on a train/test split, and we report its Hit Ratio@10 and NDCG@10 on the held-out test set."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "8c13c283",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"--- 2.1 Model Training ---\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"c:\\Users\\usuar\\miniconda3\\envs\\pygrex-exp-grs\\Lib\\site-packages\\implicit\\cpu\\als.py:95: RuntimeWarning: OpenBLAS is configured to use 8 threads. It is highly recommended to disable its internal threadpool by setting the environment variable 'OPENBLAS_NUM_THREADS=1' or by calling 'threadpoolctl.threadpool_limits(1, \"blas\")'. Having OpenBLAS use a threadpool can lead to severe performance issues here.\n",
|
|
" check_blas_config()\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "4a2c30b182994b868f98ba8d9d2d7d8f",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/10 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"✅ Model trained successfully in 1.08 seconds!\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
"print(\"--- 2.1 Model Training ---\")\n",
"\n",
"# Build the ALS recommendation model from the project configuration.\n",
"# NOTE: `implicit` warns when OpenBLAS runs its own thread pool; setting\n",
"# OPENBLAS_NUM_THREADS=1 before numpy is first imported silences it.\n",
"model = ALS(**cfg.model.als)\n",
"\n",
"# Fit on the full dataset and time it. perf_counter() is monotonic and\n",
"# high-resolution, so it is the right clock for measuring durations\n",
"# (time.time() is wall-clock and can jump if the system clock changes).\n",
"start_time = time.perf_counter()\n",
"model.fit(data)\n",
"training_time = time.perf_counter() - start_time\n",
"\n",
"print(f\"✅ Model trained successfully in {training_time:.2f} seconds!\")"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "047fe521",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"print(\"\\n--- 2.2 Offline Model Evaluation ---\")\n",
"# run_evaluation_with_proper_split does its own data splitting and training,\n",
"# so it receives a separate, untrained ALS built from the same configuration\n",
"# rather than the `model` fitted in 2.1.\n",
"eval_model = ALS(**cfg.model.als)\n",
"\n",
"# Evaluation settings: fraction held out for testing, and the ranking cutoff.\n",
"test_size, top_n = 0.2, 10\n",
"\n",
"print(f\"Running evaluation with a {test_size*100:.0f}% test split (Top-{top_n})...\")\n",
"\n",
"evaluation_scores = run_evaluation_with_proper_split(\n",
"    data_reader=data, model=eval_model, test_size=test_size, top_n=top_n\n",
")\n",
"\n",
"print(\"\\n--- Evaluation Results ---\")\n",
"# Metrics missing from the result dict are reported as 0.\n",
"hit_ratio = evaluation_scores.get('Hit Ratio', 0.0)\n",
"ndcg = evaluation_scores.get('NDCG', 0.0)\n",
"eval_seconds = evaluation_scores.get('evaluation_time', 0)\n",
"print(f\"Hit Ratio @{top_n}: {hit_ratio:.2%}\")\n",
"print(f\"NDCG @{top_n}: {ndcg:.4f}\")\n",
"print(f\"Evaluation Time: {eval_seconds:.1f}s\")"
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "49cb2659",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Step 3: Group Recommendation\n",
|
|
"\n",
|
|
"Now that we have a trained model, we can generate recommendations for a group. We will select a group, choose an aggregation strategy to combine individual member preferences, and generate a Top-10 list of recommended items."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "0a138815",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"--- 3. Group Recommendation ---\n",
|
|
"Generating Top-10 recommendations for group: 522_385_234_452_594\n",
|
|
"👥 Group Members: [522, 385, 234, 452, 594]\n",
|
|
"📊 Aggregation Strategy: AVG_PREDICTIONS\n",
|
|
"\n",
|
|
"✅ Recommendations generated successfully!\n",
|
|
"\n",
|
|
"Top 10 Recommended Items:\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>Rank</th>\n",
|
|
" <th>Item ID</th>\n",
|
|
" <th>Aggregated Score</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>1</td>\n",
|
|
" <td>543</td>\n",
|
|
" <td>4.636274</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>2</td>\n",
|
|
" <td>757</td>\n",
|
|
" <td>4.582981</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>3</td>\n",
|
|
" <td>564</td>\n",
|
|
" <td>4.504107</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>4</td>\n",
|
|
" <td>441</td>\n",
|
|
" <td>4.488708</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>5</td>\n",
|
|
" <td>379</td>\n",
|
|
" <td>4.341830</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>5</th>\n",
|
|
" <td>6</td>\n",
|
|
" <td>475</td>\n",
|
|
" <td>4.279482</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>6</th>\n",
|
|
" <td>7</td>\n",
|
|
" <td>43</td>\n",
|
|
" <td>4.268454</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>7</th>\n",
|
|
" <td>8</td>\n",
|
|
" <td>19</td>\n",
|
|
" <td>4.225248</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>8</th>\n",
|
|
" <td>9</td>\n",
|
|
" <td>748</td>\n",
|
|
" <td>4.178329</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>9</th>\n",
|
|
" <td>10</td>\n",
|
|
" <td>64</td>\n",
|
|
" <td>4.147735</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" Rank Item ID Aggregated Score\n",
|
|
"0 1 543 4.636274\n",
|
|
"1 2 757 4.582981\n",
|
|
"2 3 564 4.504107\n",
|
|
"3 4 441 4.488708\n",
|
|
"4 5 379 4.341830\n",
|
|
"5 6 475 4.279482\n",
|
|
"6 7 43 4.268454\n",
|
|
"7 8 19 4.225248\n",
|
|
"8 9 748 4.178329\n",
|
|
"9 10 64 4.147735"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
"print(\"--- 3. Group Recommendation ---\")\n",
"\n",
"# Select a group and strategy\n",
"selected_group_id = available_groups[0]  # Let's use the first group as an example\n",
"group_members = group_handler.parse_group_members(selected_group_id)\n",
"aggregation_strategy = AggregationStrategy.AVG_PREDICTIONS  # Use the simple average strategy\n",
"top_k = 10  # length of the group recommendation list\n",
"\n",
"print(f\"Generating Top-{top_k} recommendations for group: {selected_group_id}\")\n",
"print(f\"👥 Group Members: {group_members}\")\n",
"print(f\"📊 Aggregation Strategy: {aggregation_strategy.name}\")\n",
"\n",
"# --- Generate Recommendations ---\n",
"# 1. Instantiate the GroupRecommender\n",
"group_recommender = GroupRecommender(data=data)\n",
"\n",
"# 2. Setup the recommendation process using the model trained in Step 2\n",
"group_recommender.setup_recommendation(\n",
"    model=model,\n",
"    members=group_members,  # type: ignore\n",
"    data=data,\n",
"    aggregation_strategy=aggregation_strategy,\n",
")\n",
"\n",
"# 3. Get the final recommendation list and the aggregated score of each item\n",
"recommended_items = group_recommender.get_group_recommendations(top_k=top_k)\n",
"recommendation_scores = group_recommender.get_recommendation_scores()\n",
"\n",
"# Step 4 explains recommended_items[0]; fail here with a clear message rather\n",
"# than with an IndexError later if nothing was recommended.\n",
"assert recommended_items is not None and len(recommended_items) > 0, (\n",
"    f\"No recommendations were generated for group {selected_group_id}\"\n",
")\n",
"\n",
"print(\"\\n✅ Recommendations generated successfully!\")\n",
"\n",
"# --- Display Results ---\n",
"# Rank is 1-based; items without an aggregated score are shown as 0.0.\n",
"rec_data = [\n",
"    {\n",
"        \"Rank\": i + 1,\n",
"        \"Item ID\": item_id,\n",
"        \"Aggregated Score\": recommendation_scores.get(item_id, 0.0),\n",
"    }\n",
"    for i, item_id in enumerate(recommended_items)  # type: ignore\n",
"]\n",
"\n",
"rec_df = pd.DataFrame(rec_data)\n",
"print(f\"\\nTop {top_k} Recommended Items:\")\n",
"display(rec_df)"
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "6268a2ed",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Step 4: Explanation (Sliding Window)\n",
|
|
"\n",
|
|
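"Finally, we generate an explanation for one of the recommendations. We use the **Sliding Window** method to find a counterfactual explanation: the items the group has rated are ranked by a set of weighted criteria, and a fixed-size window slides over that ranking to search for items whose removal changes the recommendation. The method answers the question: *\"Which small set of items, if removed from the group's history, would cause our target item to disappear from the recommendation list?\"*\n"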
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "367063db",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"--- 4. Counterfactual Explanation (Sliding Window) ---\n",
|
|
"Generating explanation for recommended item: 543\n",
|
|
"Sliding Window Size: 3\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "6c1d8da938db475ab005c2378f99feae",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/10 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "221213362ff442f7b9e736536844c6b9",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/10 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"If the group had not interacted with these items [np.int64(480)],\n",
|
|
"the item of interest 543 would not have appeared on the recommendation list;\n",
|
|
"instead, 303 would have been recommended.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
"print(\"--- 4. Counterfactual Explanation (Sliding Window) ---\")\n",
"\n",
"# Select a target item from our recommendation list to explain (the top-ranked one)\n",
"target_item = recommended_items[0]\n",
"\n",
"# Configure the explainer\n",
"window_size = 3\n",
"# These weights determine how to rank items from the group's history\n",
"# before attempting to remove them to find an explanation.\n",
"# All ranking criteria are currently given the same weight (1.0).\n",
"ranking_weights = {\n",
"    \"popularity\": 1.0,\n",
"    \"intensity\": 1.0,\n",
"    \"rating\": 1.0,\n",
"    \"relevance\": 1.0,\n",
"    \"trend\": 1.0,\n",
"}\n",
"\n",
"print(f\"Generating explanation for recommended item: {target_item}\")\n",
"print(f\"Sliding Window Size: {window_size}\\n\")\n",
"\n",
"# --- Generate Explanation ---\n",
"# 1. Get all items previously rated by the group\n",
"items_rated_by_group = group_handler.get_rated_items_by_all_group_members(\n",
"    group=group_members, original_data=data\n",
")\n",
"\n",
"# 2. Instantiate the explainer with the same model and aggregation strategy\n",
"#    that produced the recommendations in Step 3\n",
"explainer = SlidingWindowExplainer(\n",
"    config=cfg,  # NOTE(review): earlier note said this explainer does not use it; confirm and drop if unused\n",
"    data=data,\n",
"    group_handler=group_handler,\n",
"    members=group_members,\n",
"    target_item=target_item,\n",
"    aggregation_strategy=aggregation_strategy,\n",
"    model=model,\n",
"    window_size=window_size,\n",
")\n",
"\n",
"# 3. Find the explanation, reusing the predictions computed in Step 3\n",
"explanations = explainer.find_explanation(\n",
"    items_rated_by_group=items_rated_by_group,\n",
"    group_predictions=group_recommender.get_individual_predictions(),\n",
"    top_recommendation=group_recommender.get_top_recommendation(),\n",
"    ranking_weights=ranking_weights,\n",
")"
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "pygrex-exp-grs",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.11"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|