from typing import Dict, Any

from pygrex.utils import calculate_gild_for_explanations


class ExplanationEvaluator:
    """
    A unified evaluator for different explanation methods.

    This class takes the results generated by an explainer and calculates
    a standard set of quality metrics, such as Fidelity and Diversity (GILD).
    The evaluator keeps no state between calls.
    """

    def __init__(self):
        """Initializes the ExplanationEvaluator."""
        # This class is stateless, so there is nothing to set up.

    def evaluate(
        self, explanation_results: Dict[str, Any], explainer_type: str
    ) -> Dict[str, float]:
        """
        Calculates all relevant metrics for a given explanation result.

        Args:
            explanation_results: The dictionary returned by an explainer's
                `find_explanation` method. The keys read here are
                "fidelity" and "details".
            explainer_type: A string identifier for the explainer used
                (e.g., "LORE4Groups", "EXPGRS"). It is forwarded unchanged
                to the GILD utility.

        Returns:
            A dictionary with the keys "fidelity" and "gild". Both are 0.0
            when `explanation_results` is empty or None.
        """
        # Nothing to score: report zeros instead of failing on missing keys.
        if not explanation_results:
            return {"fidelity": 0.0, "gild": 0.0}

        fidelity = self._calculate_fidelity(explanation_results)
        gild = self._calculate_gild(explanation_results, explainer_type)

        return {"fidelity": fidelity, "gild": gild}

    def _calculate_fidelity(self, explanation_results: Dict[str, Any]) -> float:
        """
        Extracts the fidelity score from the explanation results.

        Fidelity is computed by the explainer itself, as it's the ratio of
        items it was able to explain. This method standardizes its retrieval.

        Args:
            explanation_results: The explainer's result dictionary.

        Returns:
            The value stored under "fidelity" converted to a float, or 0.0
            when the key is missing or its value is not numeric (e.g. None).
        """
        raw_fidelity = explanation_results.get("fidelity", 0.0)
        try:
            # Coerce so the declared float return type always holds
            # (ints, numeric strings and NumPy scalars all convert).
            return float(raw_fidelity)
        except (TypeError, ValueError):
            # A non-numeric entry is treated the same as a missing one.
            return 0.0

    def _calculate_gild(
        self, explanation_results: Dict[str, Any], explainer_type: str
    ) -> float:
        """
        Calculates the Gaussian Inter-List Diversity (GILD) of the explanations.

        This is a wrapper around the utility function that handles the details.
        It uses the 'details' part of the explanation results.

        Args:
            explanation_results: The explainer's result dictionary.
            explainer_type: Identifier of the explainer, passed through to
                `calculate_gild_for_explanations`.

        Returns:
            0.0 when there are no details to evaluate; otherwise whatever
            `calculate_gild_for_explanations` returns for them.
        """
        explanation_details = explanation_results.get("details", {})
        # Skip the utility call entirely when nothing was explained.
        if not explanation_details:
            return 0.0

        # The GILD function is called from one central place.
        return calculate_gild_for_explanations(explanation_details, explainer_type)
|