Files
py-grex/pygrex/utils/association_rules.py
2026-05-22 10:02:10 +02:00

256 lines
9.0 KiB
Python

from mlxtend.preprocessing import TransactionEncoder
from mlxtend.frequent_patterns import fpgrowth, association_rules
import pandas as pd
from pygrex.data_reader.data_reader import DataReader
from typing import List, Optional, Union
class AssociationRules:
    """
    A class to represent association rules mining for recommendation systems.

    This class implements association rules mining using the FP-Growth algorithm
    to discover frequent itemsets and generate association rules from user-item
    interaction data. It can be used to find patterns in user behavior and
    generate item recommendations based on item associations.
    """

    def __init__(
        self,
        data: "DataReader",
        min_support: float = 0.2,
        min_confidence: float = 0.2,
        rating_threshold: float = 4.0,
    ) -> None:
        """Initialize the association rules miner with data and parameters.

        Args:
            data: The DataReader object containing user-item interactions with
                ratings. Its ``dataset`` attribute is read as a DataFrame with
                ``userId``, ``itemId`` and ``rating`` columns.
            min_support: Minimum support threshold for frequent itemsets.
                Must be in (0, 1]. Default is 0.2.
            min_confidence: Minimum confidence threshold for association rules.
                Must be in (0, 1]. Default is 0.2.
            rating_threshold: Minimum rating for an interaction to count as
                positive. Must be non-negative. Default is 4.0.

        Raises:
            ValueError: If support, confidence, or rating_threshold values are invalid.
        """
        self._validate_parameters(min_support, min_confidence, rating_threshold)
        self.data = data
        self.min_support = min_support
        self.min_confidence = min_confidence
        self.rating_threshold = rating_threshold
        # Populated by compute(); None until then.
        self._frequent_itemsets: Optional[pd.DataFrame] = None
        self._association_rules: Optional[pd.DataFrame] = None

    def _validate_parameters(
        self, min_support: float, min_confidence: float, rating_threshold: float
    ) -> None:
        """Validate initialization parameters.

        Args:
            min_support: Minimum support threshold to validate.
            min_confidence: Minimum confidence threshold to validate.
            rating_threshold: Rating threshold to validate.

        Raises:
            ValueError: If any parameter is invalid (including NaN).
        """
        # The chained comparisons are False for NaN, so NaN is rejected here.
        if not (0 < min_support <= 1):
            raise ValueError("min_support must be between 0 and 1")
        if not (0 < min_confidence <= 1):
            raise ValueError("min_confidence must be between 0 and 1")
        # Written as ``not >= 0`` rather than ``< 0`` so NaN is rejected too.
        if not rating_threshold >= 0:
            raise ValueError("rating_threshold must be non-negative")

    def get_df_filtered_by_rating_threshold(self) -> pd.DataFrame:
        """Return the interactions whose rating meets ``rating_threshold``.

        Returns:
            A new DataFrame with the rows of ``self.data.dataset`` whose
            ``rating`` is ``>= self.rating_threshold``. The source dataset is
            not modified.

        Raises:
            ValueError: If no interaction meets the threshold.
        """
        dataset = self.data.dataset
        # Boolean-mask indexing already builds a new DataFrame, so there is no
        # need to copy the whole dataset first.
        df_filtered = dataset[dataset["rating"] >= self.rating_threshold]
        if df_filtered.empty:
            raise ValueError(
                f"No interactions found with rating >= {self.rating_threshold}"
            )
        return df_filtered

    def _prepare_transactions(self) -> List[List[str]]:
        """Prepare transaction data from the dataset.

        Filters the dataset by rating threshold and groups items by user to
        create one transaction per user.

        Returns:
            A list of transactions. Each transaction is the list of item IDs,
            as strings, that a user has positively interacted with.

        Raises:
            ValueError: If no interaction meets the rating threshold.
        """
        df_filtered = self.get_df_filtered_by_rating_threshold()
        # Group items by user to create transactions.
        per_user_items = df_filtered.groupby("userId")["itemId"].apply(list)
        # Item IDs are stringified so they compare equal to the stringified
        # IDs used by get_recommendations_for_items.
        return [[str(item) for item in transaction] for transaction in per_user_items]

    def _mine_frequent_itemsets(
        self, transactions: List[List[Union[str, int]]]
    ) -> pd.DataFrame:
        """Mine frequent itemsets using the FP-Growth algorithm.

        Args:
            transactions: List of transactions to mine frequent itemsets from.

        Returns:
            DataFrame containing frequent itemsets with their support values.

        Raises:
            ValueError: If no frequent itemsets are found.
        """
        # One-hot encode transactions into a matrix with one column per item.
        transaction_encoder = TransactionEncoder()
        transaction_matrix = transaction_encoder.fit_transform(transactions)
        df_encoded = pd.DataFrame(
            transaction_matrix,  # type: ignore
            columns=transaction_encoder.columns_,
        )
        # use_colnames=True makes itemsets contain item IDs, not column indices.
        frequent_itemsets = fpgrowth(
            df_encoded, min_support=self.min_support, use_colnames=True
        )
        if frequent_itemsets.empty:
            raise ValueError(
                f"No frequent itemsets found with min_support={self.min_support}"
            )
        return frequent_itemsets

    def _generate_association_rules(
        self, frequent_itemsets: pd.DataFrame
    ) -> pd.DataFrame:
        """Generate association rules from frequent itemsets.

        Args:
            frequent_itemsets: DataFrame containing frequent itemsets.

        Returns:
            DataFrame containing association rules with their metrics.

        Raises:
            ValueError: If no association rules are found.
        """
        # NOTE(review): some mlxtend releases changed this signature (an added
        # ``num_itemsets`` argument); confirm against the pinned mlxtend version.
        rules = association_rules(
            frequent_itemsets, metric="confidence", min_threshold=self.min_confidence
        )
        if rules.empty:
            raise ValueError(
                f"No association rules found with min_confidence={self.min_confidence}"
            )
        return rules

    def compute(self) -> pd.DataFrame:
        """Compute association rules from the dataset.

        This method performs the complete association rules mining process:
        1. Prepares transactions from the dataset
        2. Mines frequent itemsets using FP-Growth
        3. Generates association rules from frequent itemsets

        The frequent itemsets and rules are cached on the instance for
        get_frequent_itemsets() and get_recommendations_for_items().

        Returns:
            DataFrame containing association rules with metrics including
            antecedents, consequents, support, confidence, lift, etc.

        Raises:
            ValueError: If the dataset is empty, no transactions meet the
                criteria, or no rules can be generated with the given parameters.
        """
        if self.data.dataset.empty:
            raise ValueError("Dataset is empty")
        transactions = self._prepare_transactions()
        if not transactions:
            raise ValueError("No transactions found after filtering")
        self._frequent_itemsets = self._mine_frequent_itemsets(transactions)  # type: ignore
        self._association_rules = self._generate_association_rules(
            self._frequent_itemsets
        )
        return self._association_rules

    def get_frequent_itemsets(self) -> Optional[pd.DataFrame]:
        """Get the computed frequent itemsets.

        Returns:
            DataFrame containing frequent itemsets if compute() has been called,
            None otherwise.
        """
        return self._frequent_itemsets

    def get_recommendations_for_items(
        self, items: List[Union[str, int]], top_k: int = 10
    ) -> pd.DataFrame:
        """Get item recommendations based on association rules.

        A rule ``X -> Y`` is applicable when every item of ``X`` is among
        ``items``. It is kept only if ``Y`` contains at least one item not
        already in ``items``. Item IDs are compared as strings.

        Args:
            items: List of item IDs the user already has.
            top_k: Maximum number of recommendations to return. Default is 10.

        Returns:
            DataFrame with columns antecedents, consequents, confidence, lift,
            and support, sorted by confidence in descending order. When no rule
            applies, the DataFrame is empty but still has these columns.

        Raises:
            RuntimeError: If compute() hasn't been called yet.
            ValueError: If items list is empty.
        """
        if self._association_rules is None:
            raise RuntimeError("Must call compute() before getting recommendations")
        if not items:
            raise ValueError("Items list cannot be empty")

        columns = ["antecedents", "consequents", "confidence", "lift", "support"]
        owned = {str(item) for item in items}
        rules = self._association_rules

        # The rule fires only if the user holds the whole antecedent.
        antecedent_held = rules["antecedents"].apply(
            lambda antecedent: {str(item) for item in antecedent} <= owned
        )
        # Skip rules that would only "recommend" items the user already has.
        adds_new_item = rules["consequents"].apply(
            lambda consequent: not {str(item) for item in consequent} <= owned
        )
        matching_rules = rules[antecedent_held & adds_new_item]

        if matching_rules.empty:
            return pd.DataFrame(columns=columns)
        return matching_rules.nlargest(top_k, "confidence")[columns]

    def __str__(self) -> str:
        """Return string representation of the AssociationRules object."""
        return (
            f"AssociationRules(min_support={self.min_support}, "
            f"min_confidence={self.min_confidence}, "
            f"rating_threshold={self.rating_threshold})"
        )

    def __repr__(self) -> str:
        """Return detailed string representation of the AssociationRules object."""
        return self.__str__()