"""Source code for ``beexai.evaluate.metrics.get_results``.

Computes explanation-quality metrics and gathers them into a pandas DataFrame.
"""

from collections import defaultdict
from typing import Callable, List, Optional, Union

import numpy as np
import pandas as pd
import torch

from beexai.evaluate.metrics.auc_tp import compute_auc
from beexai.evaluate.metrics.complexity import compute_complex
from beexai.evaluate.metrics.comprehensiveness import compute_comp
from beexai.evaluate.metrics.faithfulnesscorr import compute_faith_corr
from beexai.evaluate.metrics.infidelity import compute_inf
from beexai.evaluate.metrics.monotonicity import compute_mono
from beexai.evaluate.metrics.sensitivity import compute_sens
from beexai.evaluate.metrics.sparseness import compute_spar
from beexai.evaluate.metrics.sufficiency import compute_suff
from beexai.explanation.explaining import GeneralExplainer
from beexai.utils.path import get_path

# Show every column when a DataFrame is printed instead of truncating wide
# tables (this changes pandas' global display state for the whole process).
pd.options.display.max_columns = None


# Every metric name accepted by ``get_all_metrics`` (also its default selection).
_ALL_METRICS = [
    "FaithCorr",
    "Infidelity",
    "Sensitivity",
    "Comprehensiveness",
    "Sufficiency",
    "Monotonicity",
    "AUC_TP",
    "Complexity",
    "Sparseness",
]


def _mean_pairwise_abs_diff(x: torch.Tensor) -> torch.Tensor:
    """Return the element-wise mean of ``|x[i] - x[j]|`` over all pairs ``i < j``.

    This equals stacking every pairwise absolute difference along a new first
    axis and averaging over it, but avoids materialising the O(n^2) pairs. For
    values sorted along dim 0, ``s_0 <= ... <= s_{n-1}``, the identity
    ``sum_{i<j} (s_j - s_i) = sum_k (2k - n + 1) * s_k`` gives the pair sum
    after a single sort.

    Args:
        x (torch.Tensor): floating-point tensor whose first dimension indexes
            samples.

    Returns:
        torch.Tensor: tensor with shape ``x.shape[1:]``.

    Raises:
        ValueError: if ``x`` has fewer than two samples (no pairs exist).
    """
    n = x.shape[0]
    if n < 2:
        raise ValueError(
            "At least two samples are required to estimate the sensitivity "
            "radius; pass `radius` explicitly."
        )
    # Shifting by the mean leaves every pairwise difference unchanged (the
    # weights below sum to zero) but reduces float cancellation in the sum.
    centered = x - x.mean(dim=0, keepdim=True)
    sorted_vals, _ = torch.sort(centered, dim=0)
    weights = 2 * torch.arange(n, dtype=x.dtype, device=x.device) - (n - 1)
    # Reshape so the weights broadcast along the sample axis for any rank.
    weights = weights.view(-1, *([1] * (x.dim() - 1)))
    n_pairs = n * (n - 1) / 2
    return (weights * sorted_vals).sum(dim=0) / n_pairs


def get_all_metrics(
    x_test: torch.Tensor,
    label: Optional[Union[int, list, np.ndarray, torch.Tensor]],
    model: Callable,
    exp: GeneralExplainer,
    ref_model: Optional[Callable] = None,
    refmodel_exp: Optional[GeneralExplainer] = None,
    baseline: str = "zero",
    auc_metric: str = "mse",
    subratio_faith: float = 0.2,
    comp_ratio: Union[float, list] = 0.3,
    suff_ratio: Union[float, list] = 0.3,
    inf_std: Optional[Union[float, torch.Tensor, np.ndarray]] = None,
    save_path: Optional[str] = None,
    metrics_to_get: Optional[List[str]] = None,
    print_plot: bool = False,
    attributions: Optional[torch.Tensor] = None,
    attributions_ref: Optional[torch.Tensor] = None,
    device: str = "cpu",
    use_ref: bool = False,
    use_random: bool = False,
    radius: Optional[Union[float, torch.Tensor]] = None,
) -> pd.DataFrame:
    """Compute the requested explanation metrics and gather them in a DataFrame.

    Args:
        x_test (torch.Tensor, np.ndarray or pd.DataFrame): test data. NumPy
            arrays and DataFrames are converted to a float32 tensor on
            ``device``.
        label (int, list, np.ndarray, torch.Tensor or None): label(s) of
            interest; required positionally but may be None. A list of labels
            can be provided, one for each instance.
        model (Callable): model to explain.
        exp (GeneralExplainer): explainer for the model to explain.
        ref_model (Callable, optional): reference (random) model. Only used
            when ``use_ref`` is True.
        refmodel_exp (GeneralExplainer, optional): explainer for the reference
            model. Required when ``use_ref`` is True.
        baseline (str, optional): baseline to use for the metrics. Defaults to
            "zero". Must be one of ["mean", "median", "zero", "multiple",
            "normal", "uniform"].
        auc_metric (str, optional): performance metric used by AUC_TP.
            Defaults to "mse". Must be one of ["mse", "accuracy"].
        subratio_faith (float, optional): ratio of features used by the
            faithfulness-correlation metric. Defaults to 0.2.
        comp_ratio (float or list, optional): ratio of features to remove for
            comprehensiveness. Defaults to 0.3.
        suff_ratio (float or list, optional): ratio of features to keep for
            sufficiency. Defaults to 0.3.
        inf_std (float, torch.Tensor or np.ndarray, optional): std of the
            noise added for infidelity. Defaults to None, meaning the
            per-feature std of ``x_test`` (``torch.std(x_test, dim=0)``).
        save_path (str, optional): if given, the result is written to CSV at
            ``get_path(save_path)``. Defaults to None.
        metrics_to_get (list of str, optional): metrics to compute, a subset
            of ["FaithCorr", "Infidelity", "Sensitivity", "Comprehensiveness",
            "Sufficiency", "Monotonicity", "AUC_TP", "Complexity",
            "Sparseness"]. Defaults to None, meaning all of them.
        print_plot (bool, optional): whether to plot the figures and print
            the metrics. Defaults to False.
        attributions (torch.Tensor, optional): precomputed attributions for
            the model to explain. Defaults to None (computed with ``exp``).
        attributions_ref (torch.Tensor, optional): precomputed attributions
            for the reference model. Defaults to None (computed with
            ``refmodel_exp``).
        device (str, optional): device to use. Defaults to "cpu".
        use_ref (bool, optional): whether to also evaluate the reference
            model. Defaults to False.
        use_random (bool, optional): whether to also evaluate uniformly random
            attributions. Defaults to False.
        radius (float or torch.Tensor, optional): radius for the sensitivity
            metric. Defaults to None, meaning the per-feature mean absolute
            difference over all pairs of rows of ``x_test``.

    Returns:
        pd.DataFrame: a single-row dataframe. Columns are the metric names, or
        a (metric, comparator) MultiIndex when random attributions and/or the
        reference model are also evaluated.

    Raises:
        ValueError: if an unknown metric is requested, if ``use_ref`` is True
            without ``refmodel_exp``, or if the default sensitivity radius is
            needed but ``x_test`` has fewer than two rows.
    """
    if metrics_to_get is None:
        metrics_to_get = list(_ALL_METRICS)
    for metric in metrics_to_get:
        if metric not in _ALL_METRICS:
            raise ValueError(
                f"Metric {metric} not recognized.\nChoose from: {_ALL_METRICS}"
            )
    if use_ref and refmodel_exp is None:
        # Fail fast: the reference explainer is needed for feature ordering.
        raise ValueError(
            "use_ref=True requires `refmodel_exp`, the explainer of the "
            "reference model."
        )

    if isinstance(x_test, pd.DataFrame):
        x_test = torch.tensor(x_test.values, dtype=torch.float32, device=device)
    if isinstance(x_test, np.ndarray):
        x_test = torch.tensor(x_test, dtype=torch.float32, device=device)

    metrics = defaultdict(dict)
    task = exp.task
    # Regression explanations are requested with absolute=True.
    use_abs = task == "regression"
    subsize_faith = int(subratio_faith * x_test.shape[1])

    if attributions is not None:
        attributions = attributions.to(device)
    else:
        attributions = exp.explain(x_test, label=label, absolute=use_abs)
    orders = exp.feature_order(attributions)

    if use_ref:
        if attributions_ref is not None:
            randmodel_attributions = attributions_ref.to(device)
        else:
            randmodel_attributions = refmodel_exp.explain(
                x_test, label=label, absolute=use_abs
            )
        randmodel_orders = refmodel_exp.feature_order(randmodel_attributions)
    else:
        # None disables the reference-model comparison in every metric helper.
        ref_model = None
        randmodel_attributions = None
        randmodel_orders = None
        refmodel_exp = None

    if use_random:
        # Uniform random attributions spanning the range of the real ones.
        lb, ub = torch.min(attributions), torch.max(attributions)
        rand_attrib = torch.rand(attributions.shape, device=device) * (ub - lb) + lb
        rand_orders = exp.feature_order(rand_attrib)
    else:
        rand_attrib = None
        rand_orders = None

    if "FaithCorr" in metrics_to_get:
        metrics = compute_faith_corr(
            model,
            ref_model,
            task,
            subsize_faith,
            x_test,
            attributions,
            rand_attrib,
            randmodel_attributions,
            label,
            metrics,
            device,
        )
    if "Infidelity" in metrics_to_get:
        # Default noise scale is only computed when infidelity is requested.
        if inf_std is None:
            inf_std = torch.std(x_test, dim=0)
        metrics = compute_inf(
            model,
            ref_model,
            task,
            x_test,
            attributions,
            rand_attrib,
            randmodel_attributions,
            label,
            metrics,
            device,
            inf_std,
        )
    if "Sensitivity" in metrics_to_get:
        # Default radius is only computed when sensitivity is requested.
        if radius is None:
            radius = _mean_pairwise_abs_diff(x_test)
        metrics = compute_sens(
            model,
            ref_model,
            task,
            x_test,
            label,
            metrics,
            exp,
            refmodel_exp,
            device,
            use_random,
            attributions,
            rand_attrib,
            randmodel_attributions,
            radius,
        )

    # Passed to the curve-based metrics; one more than the number of features.
    n_plot = x_test.shape[1] + 1
    if "Comprehensiveness" in metrics_to_get:
        metrics = compute_comp(
            model,
            ref_model,
            task,
            x_test,
            orders,
            rand_orders,
            randmodel_orders,
            n_plot,
            comp_ratio,
            label,
            metrics,
            baseline,
            print_plot,
            device,
        )
    if "Sufficiency" in metrics_to_get:
        metrics = compute_suff(
            model,
            ref_model,
            task,
            x_test,
            orders,
            rand_orders,
            randmodel_orders,
            n_plot,
            suff_ratio,
            label,
            metrics,
            baseline,
            print_plot,
            device,
        )
    if "Monotonicity" in metrics_to_get:
        metrics = compute_mono(
            model,
            ref_model,
            task,
            x_test,
            orders,
            rand_orders,
            randmodel_orders,
            label,
            metrics,
            baseline,
            device,
        )
    if "AUC_TP" in metrics_to_get:
        metrics = compute_auc(
            model,
            ref_model,
            task,
            x_test,
            orders,
            rand_orders,
            randmodel_orders,
            metrics,
            baseline,
            auc_metric,
            print_plot,
            device,
        )
    if "Complexity" in metrics_to_get:
        metrics = compute_complex(
            model,
            ref_model,
            task,
            attributions,
            rand_attrib,
            randmodel_attributions,
            metrics,
            device,
        )
    if "Sparseness" in metrics_to_get:
        metrics = compute_spar(
            model,
            ref_model,
            task,
            attributions,
            rand_attrib,
            randmodel_attributions,
            metrics,
            device,
        )

    comparators = ["Original"]
    if use_random:
        comparators.append("Random")
    if use_ref:
        comparators.append("Random Model")
    if len(comparators) > 1:
        cols = pd.MultiIndex.from_product([metrics_to_get, comparators])
    else:
        cols = metrics_to_get
    # NOTE(review): assumes each compute_* helper stored exactly one value per
    # comparator, in the order built above — confirm against the helpers.
    values = np.array([list(metrics[metric].values()) for metric in metrics_to_get])
    values = np.reshape(values, (1, -1))
    df = pd.DataFrame(values, columns=cols)
    if print_plot:
        print("-------------------")
        print(df)
    if save_path is not None:
        save_path = get_path(save_path)
        df.to_csv(save_path)
    return df