Files
py-grex/pygrex/explain/individual/post_hoc_association_rules.py
2026-05-22 10:02:10 +02:00

80 lines
2.6 KiB
Python

from typing import Any, Dict
from mlxtend.preprocessing import TransactionEncoder
from mlxtend.frequent_patterns import apriori, association_rules
import pandas as pd
from .explainer import Explainer
class ARPostHocExplainer(Explainer):
    """Explain recommendations post hoc with mined association rules.

    Each user's items form one transaction. Frequent itemsets are mined with
    Apriori, turned into association rules, and a recommended item is
    explained by the rule antecedents that the user has already interacted
    with.

    NOTE(review): ``self.dataset`` and ``self.get_user_items`` are not defined
    in this class; they are presumably provided by the ``Explainer`` base
    class -- confirm against ``explainer.py``.
    """

    # Columns retained in ``self.rules`` after mining.
    _RULE_COLUMNS = ["consequents", "antecedents", "confidence"]

    def __init__(
        self,
        model,
        recommendations,
        data,
        min_support=0.1,
        max_len=2,
        metric="lift",
        min_threshold=0.1,
        min_confidence=0.1,
        min_lift=0.1,
    ):
        """Store the mining hyper-parameters; rules are computed lazily.

        Args:
            model: Forwarded unchanged to ``Explainer.__init__``.
            recommendations: Forwarded unchanged to ``Explainer.__init__``.
            data: Forwarded unchanged to ``Explainer.__init__``.
            min_support: Minimum support passed to ``apriori``.
            max_len: Maximum itemset length passed to ``apriori``.
            metric: Rule metric passed to ``association_rules``.
            min_threshold: Minimum value of ``metric`` for a rule to be kept
                by ``association_rules``.
            min_confidence: Rules must have confidence strictly greater than
                this value.
            min_lift: Rules must have lift strictly greater than this value.
        """
        super().__init__(model, recommendations, data)
        self.AR = None
        self.min_support = min_support
        self.max_len = max_len
        self.metric = metric
        self.min_threshold = min_threshold
        self.min_confidence = min_confidence
        self.min_lift = min_lift
        # Cache of mined rules; filled by compute_association_rules().
        self.rules: pd.DataFrame | None = None

    def get_rules_for_getting(self, item_id: int) -> pd.DataFrame:
        """Return the mined rules whose consequent is ``item_id``.

        Rules are mined on first use and cached in ``self.rules``.

        Args:
            item_id: Item that the returned rules must lead to.

        Returns:
            The rows of ``self.rules`` whose ``consequents`` equals
            ``item_id``, or an empty frame with the rule columns if no rules
            are available.
        """
        if self.rules is None:
            self.compute_association_rules()
        rules = self.rules
        if rules is None:
            # Keep the expected columns so callers can still access
            # ``.antecedents`` on the empty result.
            return pd.DataFrame(columns=self._RULE_COLUMNS)
        return rules[rules.consequents == item_id]

    def compute_association_rules(self):
        """Mine association rules from the dataset and cache them.

        The result is stored in ``self.rules`` with the columns
        ``consequents``, ``antecedents`` and ``confidence``.

        NOTE: only one element of each antecedent/consequent itemset is kept.
        With the default ``max_len=2`` every rule side holds exactly one item,
        so nothing is lost; with a larger ``max_len`` the remaining items of a
        multi-item side are dropped.
        """
        # One transaction (list of item ids) per user, built in a single pass
        # over the data instead of re-filtering the whole dataset per user.
        item_sets = [
            list(items)
            for _, items in self.dataset.groupby("userId", sort=False)["itemId"]
        ]
        te = TransactionEncoder()
        te_ary = te.fit(item_sets).transform(item_sets)
        # The te_ary object is a NumPy array, which is a valid input for a DataFrame.
        # Pylance may raise a false positive here due to incomplete type stubs for mlxtend.
        df = pd.DataFrame(te_ary.astype(bool), columns=te.columns_)  # type: ignore
        frequent_itemsets = apriori(
            df, min_support=self.min_support, use_colnames=True, max_len=self.max_len
        )
        if frequent_itemsets.empty:
            # association_rules() rejects an empty itemset frame; no frequent
            # itemsets simply means there are no rules to explain with.
            self.rules = pd.DataFrame(columns=self._RULE_COLUMNS)
            return
        rules = association_rules(
            frequent_itemsets, metric=self.metric, min_threshold=self.min_threshold
        )
        keep = (rules["confidence"] > self.min_confidence) & (
            rules["lift"] > self.min_lift
        )
        # Copy so the column assignments below write to an independent frame
        # rather than to a slice of the unfiltered rules.
        rules = rules[keep].copy()
        rules["consequents"] = rules["consequents"].apply(lambda x: next(iter(x)))
        rules["antecedents"] = rules["antecedents"].apply(lambda x: next(iter(x)))
        self.rules = rules[self._RULE_COLUMNS]

    def explain_recommendation_to_user(
        self, user_id: int, item_id: int
    ) -> Dict[str, Any]:
        """Explain why ``item_id`` was recommended to ``user_id``.

        Args:
            user_id: User receiving the recommendation.
            item_id: Recommended item to explain.

        Returns:
            A dict with the single key ``"antecedents"`` holding the set of
            rule antecedents for ``item_id`` that are also among the values
            returned by ``self.get_user_items(user_id)``.
        """
        user_ratings = self.get_user_items(user_id)
        rules = self.get_rules_for_getting(item_id)
        explanations = rules[rules.antecedents.isin(user_ratings)]
        return {"antecedents": set(explanations.antecedents)}