Files
py-grex/pygrex/evaluator/explainer_evaluator.py
2026-05-22 10:02:10 +02:00

69 lines
2.4 KiB
Python

from typing import Dict, Any
from pygrex.utils import calculate_gild_for_explanations
class ExplanationEvaluator:
    """
    A unified evaluator for different explanation methods.

    This class takes the results generated by an explainer and calculates
    a standard set of quality metrics, such as Fidelity and Diversity (GILD).
    Every metric is returned as a plain Python ``float``; a metric that is
    missing or ``None`` in the explainer output is reported as ``0.0``.
    """

    def __init__(self):
        """Initializes the ExplanationEvaluator (it holds no state)."""

    def evaluate(
        self, explanation_results: Dict[str, Any], explainer_type: str
    ) -> Dict[str, float]:
        """
        Calculates all relevant metrics for a given explanation result.

        Args:
            explanation_results: The dictionary returned by an explainer's
                `find_explanation` method. Its optional ``"fidelity"`` entry
                is read as the fidelity score, and its optional ``"details"``
                entry is passed to the GILD utility.
            explainer_type: A string identifier for the explainer used
                (e.g., "LORE4Groups", "EXPGRS"); forwarded to the GILD
                utility.

        Returns:
            A dictionary with the keys ``"fidelity"`` and ``"gild"``, both
            floats. Both are ``0.0`` when ``explanation_results`` is empty
            (or otherwise falsy, e.g. ``None``).
        """
        if not explanation_results:
            return {"fidelity": 0.0, "gild": 0.0}

        fidelity = self._calculate_fidelity(explanation_results)
        gild = self._calculate_gild(explanation_results, explainer_type)
        return {"fidelity": fidelity, "gild": gild}

    @staticmethod
    def _as_float(value: Any) -> float:
        """Return ``value`` as a float, mapping ``None`` to ``0.0``."""
        # Nothing here guarantees the stored/returned value is already a
        # float (it may be None or another numeric type), so normalize it to
        # honour the ``Dict[str, float]`` contract of ``evaluate``.
        return 0.0 if value is None else float(value)

    def _calculate_fidelity(self, explanation_results: Dict[str, Any]) -> float:
        """
        Extracts the fidelity score from the explanation results.

        Fidelity is computed by the explainer itself, as it's the ratio of
        items it was able to explain. This method standardizes its retrieval:
        a missing or ``None`` ``"fidelity"`` entry yields ``0.0``.
        """
        return self._as_float(explanation_results.get("fidelity"))

    def _calculate_gild(
        self, explanation_results: Dict[str, Any], explainer_type: str
    ) -> float:
        """
        Calculates the Gaussian Inter-List Diversity (GILD) of the explanations.

        This is a wrapper around ``calculate_gild_for_explanations`` applied to
        the ``"details"`` part of the explanation results. Missing, ``None``
        or empty details yield ``0.0`` without calling the utility.
        """
        explanation_details = explanation_results.get("details", {})
        if not explanation_details:
            return 0.0

        gild_score = calculate_gild_for_explanations(
            explanation_details, explainer_type
        )
        return self._as_float(gild_score)