Source code for beexai.evaluate.metrics.faithfulnesscorr

from typing import Callable, Optional, Union

import numpy as np
import torch

from beexai.evaluate.metrics.metrics import CustomMetric
from beexai.utils.time_seed import time_function


class FaithfulnessCorrelation(CustomMetric):
    """Implementation of the faithfulness correlation metric.

    Computes the faithfulness of the model by removing a fixed number of
    features and computing the Pearson correlation between the aggregated
    attributions of the removed features and the difference in prediction
    with the original input.

    References:
        - `Synthetic Benchmarks for Scientific Research in Explainable
          Machine Learning <https://arxiv.org/abs/2106.12543>`_

    Attributes:
        model (callable): model to explain
        task (str): task to perform
        device (str): device to use

    Methods:
        get_faithfulness: computes the faithfulness of the model
    """

    def get_faithfulness(
        self,
        x_in: torch.Tensor,
        attributions: torch.Tensor,
        n_features_subset: int,
        label: Optional[Union[int, list, np.ndarray, torch.Tensor]] = None,
        n_repeats: int = 20,
        baseline: str = "zero",
    ) -> float:
        """Computes the faithfulness of the model.

        For each of ``n_repeats`` rounds a random subset of distinct feature
        columns is perturbed with the baseline for every instance. The
        resulting change in prediction is paired with the mean attribution of
        the perturbed features. The score is the Pearson correlation between
        these two series, pooled over all rounds and instances.

        Feature subsets are drawn from torch's global random generator, so
        the result is only reproducible if that generator is seeded.

        Args:
            x_in (torch.Tensor): input data, 2-D (instances x features)
            attributions (torch.Tensor): attributions for each instance
            n_features_subset (int): number of features to remove. Values
                outside ``[1, n_features]`` are clamped to that range.
            label (int, list, np.ndarray, torch.Tensor, optional): label(s)
                of interest. Defaults to None. A list of labels for each
                instance can be provided.
            n_repeats (int, optional): number of times to repeat the
                sampling. Defaults to 20.
            baseline (str, optional): baseline to use. Defaults to "zero".

        Returns:
            float: faithfulness score. NaN when either the prediction
                differences or the attribution means have zero variance.

        Raises:
            AssertionError: if ``n_repeats`` is not strictly positive.
        """
        self.check_shape(x_in, attributions)
        n_samples, n_features = x_in.shape
        # A subset must hold at least one feature and cannot hold more
        # distinct features than the input has.
        n_features_subset = min(max(n_features_subset, 1), n_features)
        assert n_repeats > 0, "n_repeats must be > 0"

        pred, max_arg = self.select_output(x_in, label=label)
        # Sample WITHOUT replacement: a repeated index would perturb fewer
        # features than requested while still being counted several times in
        # the attribution mean below.
        feature_subsets = torch.stack(
            [
                torch.randperm(n_features)[:n_features_subset]
                for _ in range(n_repeats)
            ]
        )
        baseline_values = self.choose_baseline(x_in, baseline, device=self.device)

        deltas = torch.zeros((n_repeats, n_samples), device=self.device)
        sums = torch.zeros((n_repeats, n_samples), device=self.device)
        # Column vector of row indices; broadcast against a 1-D column index
        # it selects the same feature subset for every instance.
        row_idx = torch.arange(n_samples)[:, None]

        for j in range(n_repeats):
            col_idx = feature_subsets[j]
            mask = torch.ones(x_in.shape, device=self.device)
            mask[row_idx, col_idx] = baseline_values[row_idx, col_idx]
            # NOTE(review): the perturbation is multiplicative. With the
            # default "zero" baseline the selected features become 0; with a
            # non-zero baseline they become x * baseline rather than the
            # baseline itself -- confirm this is intended.
            x_pert = torch.multiply(x_in, mask)
            pred_new = self.get_predlb(x_pert, max_arg)
            deltas[j] = (pred - pred_new).squeeze()
            if self.task == "regression":
                # Only the magnitude of the change is used for regression.
                deltas[j] = torch.abs(deltas[j])
            # Mean instead of sum: the subset size is constant, and Pearson
            # correlation is unaffected by a constant positive scale.
            sums[j] = torch.mean(attributions[row_idx, col_idx], dim=1).squeeze()

        deltas_flat = deltas.flatten().detach().cpu().numpy()
        sums_flat = sums.flatten().detach().cpu().numpy()
        faithfulness = np.corrcoef(deltas_flat, sums_flat)[0, 1]
        return faithfulness.item()
@time_function
def compute_faith_corr(
    model: Callable,
    rand_model: Callable,
    task: str,
    subset_size_faithfulness: int,
    x_test: torch.Tensor,
    attributions: torch.Tensor,
    rand_attrib: torch.Tensor,
    randmodel_attributions: torch.Tensor,
    label: Union[int, list, np.ndarray, torch.Tensor],
    metrics: dict,
    device: str = "cpu",
) -> dict:
    """Compute the faithfulness correlation metric.

    Scores the base model's attributions and, when supplied, the random
    explainer's attributions (on the base model) and the random model's
    attributions (on the random model). Results are written in place under
    ``metrics["FaithCorr"]``, which must already exist, using the keys
    ``"original"``, ``"random"`` and ``"random model"``.

    Args:
        model (callable): base model
        rand_model (callable): reference model (random model); skipped
            when None
        task (str): task of the model
        subset_size_faithfulness (int): number of features to remove,
            capped at the number of columns of ``x_test``
        x_test (torch.Tensor): test data
        attributions (torch.Tensor): attributions for each instance for
            the base model
        rand_attrib (torch.Tensor): attributions for each instance for
            the random explainer; skipped when None
        randmodel_attributions (torch.Tensor): attributions for each
            instance for the random model
        label (int, list, np.ndarray, torch.Tensor): label(s) of interest
        metrics (dict): dictionary of metrics
        device (str, optional): device to use. Defaults to "cpu".

    Returns:
        dict: the same ``metrics`` dictionary, updated
    """
    n_columns = x_test.shape[1]
    base_metric = FaithfulnessCorrelation(model, task, device)
    # Never request more features than the data provides.
    subset_size = min(subset_size_faithfulness, n_columns)

    def _score(metric, attrib):
        # Every evaluation shares the same data, subset size and labels.
        return metric.get_faithfulness(x_test, attrib, subset_size, label)

    metrics["FaithCorr"]["original"] = _score(base_metric, attributions)

    if rand_attrib is not None:
        metrics["FaithCorr"]["random"] = _score(base_metric, rand_attrib)

    if rand_model is not None:
        metrics["FaithCorr"]["random model"] = _score(
            FaithfulnessCorrelation(rand_model, task, device),
            randmodel_attributions,
        )

    return metrics