from typing import Callable, Optional, Union
import numpy as np
import torch
from beexai.evaluate.metrics.metrics import CustomMetric
from beexai.utils.time_seed import time_function
class Infidelity(CustomMetric):
    """Implementation of the infidelity metric.

    Perturbs the input and computes the mean-squared error between the
    perturbation applied to the attributions and the difference in
    prediction between the original input and the perturbed input.

    References:
        - `On the (In)fidelity and Sensitivity for Explanations
          <https://arxiv.org/abs/1901.09392>`

    Attributes:
        model (callable): model to explain
        task (str): task to perform
        device (str): device to use
        std (float): std of the noise

    Methods:
        get_inf: computes the infidelity of the model
    """

    def __init__(
        self, model: Callable, task: str, std: float = 0.003, device: str = "cpu"
    ):
        """Store the noise scale and forward the rest to ``CustomMetric``.

        Args:
            model (callable): model to explain
            task (str): task to perform
            std (float, optional): std of the Gaussian noise. Defaults to 0.003.
            device (str, optional): device to use. Defaults to "cpu".
        """
        super().__init__(model, task, device)
        self.std = std

    def __get_noises__(self, x_in: torch.Tensor, k: int = 5) -> torch.Tensor:
        """Build the perturbations used by :meth:`get_inf`.

        For every std value, ``k`` Gaussian samples with mean 0 are drawn,
        each with shape ``(x_in.shape[0], x_in.shape[1])``. Note that the
        returned tensor is NOT the raw noise: it is ``x_in - noise``.

        Args:
            x_in (torch.Tensor): input data; indexed as a 2-D tensor.
            k (int, optional): number of samples drawn per std. Defaults to 5.

        Returns:
            torch.Tensor: ``x_in - noise``, stacked along the first axis
            (``k`` entries per std value).
        """
        n_shape = (k, x_in.shape[0], x_in.shape[1])
        # A scalar std becomes a one-element 1-D tensor, so scalar and
        # sequence stds share the same code path below.
        stds = torch.tensor(self.std, device=self.device).float().reshape(-1)
        all_noises = torch.concatenate(
            [torch.normal(0, std, n_shape, device=self.device) for std in stds]
        )
        return x_in - all_noises

    def get_inf(
        self,
        x_in: torch.Tensor,
        attributions: torch.Tensor,
        label: Optional[Union[int, list, torch.Tensor, np.ndarray]] = None,
    ) -> float:
        """Computes the infidelity of the model.

        Args:
            x_in (torch.Tensor): input data
            attributions (torch.Tensor): attributions for each instance
            label (int, list, np.ndarray, torch.Tensor, optional): label(s) of interest.
                Defaults to None. A list of labels for each instance can be provided.

        Returns:
            float: infidelity score (squared error averaged over the
            perturbations, then over the instances)
        """
        self.check_shape(x_in, attributions)
        pred, max_arg = self.select_output(x_in, label=label)
        pert = self.__get_noises__(x_in)
        infs = torch.zeros((x_in.shape[0], len(pert)), device=self.device)
        for j, noise in enumerate(pert):
            # `noise` is `x_in - gaussian_sample`, so the perturbed input
            # below is the Gaussian sample itself (up to float rounding).
            # NOTE(review): this looks like the "noisy baseline" perturbation
            # of the referenced paper - confirm.
            pert_in = x_in - noise
            pert_pred = self.get_predlb(pert_in, max_arg)
            diff = (pred - pert_pred).squeeze()
            if self.task == "regression":
                diff = torch.abs(diff)
            # Row-wise dot product between perturbation and attributions.
            perturbed_term = torch.einsum("ij,ij->i", noise, attributions)
            infs[:, j] = (perturbed_term - diff) ** 2
        mean_inf = torch.mean(infs, dim=1)
        return torch.mean(mean_inf, dim=0).item()
@time_function
def compute_inf(
    model: Callable,
    rand_model: Optional[Callable],
    task: str,
    x_test: torch.Tensor,
    attributions: torch.Tensor,
    rand_attrib: Optional[torch.Tensor],
    randmodel_attributions: Optional[torch.Tensor],
    label: Union[int, list, torch.Tensor, np.ndarray],
    metrics: dict,
    device: str = "cpu",
    inf_std: float = 0.003,
) -> dict:
    """Compute the infidelity metric.

    Scores are written in place under ``metrics["Infidelity"]`` with the keys
    ``"original"``, ``"random"`` (only if ``rand_attrib`` is given) and
    ``"random_model"`` (only if ``rand_model`` is given). The ``"Infidelity"``
    entry is created when it does not exist yet.

    Args:
        model (callable): base model
        rand_model (callable, optional): reference model (random model).
            Skipped when None.
        task (str): task of the model
        x_test (torch.Tensor): test data
        attributions (torch.Tensor): attributions for base model
        rand_attrib (torch.Tensor, optional): random attributions.
            Skipped when None.
        randmodel_attributions (torch.Tensor, optional): attributions for
            reference model; used only when ``rand_model`` is not None.
        label (Union[int, list, np.ndarray, torch.Tensor]): label(s) of interest
        metrics (dict): dictionary of metrics
        device (str, optional): device to use. Defaults to "cpu"
        inf_std (float, optional): std of the noise.
            Defaults to 0.003.

    Returns:
        dict: dict of metrics (the same object as ``metrics``).
        NOTE(review): the ``time_function`` decorator may change what callers
        actually receive - confirm against its implementation.
    """
    # Avoid a KeyError when the caller did not pre-create the sub-dictionary.
    scores = metrics.setdefault("Infidelity", {})

    inf = Infidelity(model, task, inf_std, device)
    scores["original"] = inf.get_inf(x_test, attributions, label)

    if rand_attrib is not None:
        scores["random"] = inf.get_inf(x_test, rand_attrib, label)

    if rand_model is not None:
        randmodel_inf = Infidelity(rand_model, task, inf_std, device)
        scores["random_model"] = randmodel_inf.get_inf(
            x_test, randmodel_attributions, label
        )
    return metrics