from typing import Callable, List, Optional, Union
import matplotlib.pyplot as plt
import numpy as np
import torch
from beexai.evaluate.metrics.metrics import CustomMetric
from beexai.utils.time_seed import time_function
[docs]
class Comprehensiveness(CustomMetric):
    """Implementation of the comprehensiveness metric.

    Computes the comprehensiveness of the model by replacing the most
    important features with a baseline value and computing the difference
    in prediction with the original input.

    References:
        - `ERASER: A Benchmark to Evaluate Rationalized NLP Models
          <https://arxiv.org/abs/1911.03429>`

    Attributes:
        model (Callable): model to explain
        task (str): task to perform
        device (str): device to use

    Methods:
        get_comp: computes the comprehensiveness of the model
        get_mr_list: computes the comprehensiveness of the model
            for different ratios of features removed
    """

    def get_comp(
        self,
        x_in: torch.Tensor,
        feature_by_importance: torch.Tensor,
        removal_ratio: Union[float, list] = 0.3,
        label: Optional[Union[int, list, np.ndarray, torch.Tensor]] = None,
        baseline: str = "zero",
    ) -> float:
        """Computes the comprehensiveness of the model.

        For each removal ratio, the ``int(n_features * ratio)`` first
        columns of ``feature_by_importance`` are treated as removed: the
        corresponding entries of the input are left at the baseline value
        while every other feature keeps its original value. The score is
        the mean difference between the prediction on the full input and
        the prediction on the perturbed input (absolute difference when
        ``self.task == "regression"``).

        Args:
            x_in (torch.Tensor): input data, indexed as
                (instance, feature)
            feature_by_importance (torch.Tensor): indexes of
                most important features in descending order
            removal_ratio (float, list): ratio of features to remove.
                A single Python or NumPy scalar (int or float) is accepted.
                If a list is provided, the function will compute the
                average comprehensiveness over the list of ratios.
            label (int, list, np.ndarray, torch.Tensor, optional): label(s)
                of interest. Defaults to None. A list of labels for each
                instance can be provided.
            baseline (str, optional): baseline to use. Defaults to "zero"

        Returns:
            float: comprehensiveness score

        Raises:
            ValueError: if ``removal_ratio`` is an empty sequence.
        """
        # A lone ratio can be any real scalar (e.g. ``0``, ``1`` or a
        # ``np.float32``), not only a builtin ``float``; anything else is
        # treated as an iterable of ratios.
        if isinstance(removal_ratio, (int, float, np.integer, np.floating)):
            ratios = [removal_ratio]
        else:
            ratios = list(removal_ratio)
        if not ratios:
            raise ValueError("removal_ratio must contain at least one ratio")

        # Everything below depends only on the inputs, not on the ratio, so
        # it is computed once instead of once per ratio.
        # NOTE(review): assumes select_output is deterministic for a fixed
        # (x_in, label) pair -- confirm for stochastic models.
        self.check_shape(x_in, feature_by_importance)
        pred_allf, max_arg = self.select_output(x_in, label)
        n_feats = x_in.shape[1]
        # When no label is given, fall back to the second value returned by
        # select_output (presumably the predicted label -- TODO confirm).
        target = label if label is not None else max_arg
        # Column vector of row indexes; it broadcasts against the 2-D
        # column-index tensor so that row i is paired with its own kept
        # feature indexes.
        row_idx = torch.arange(len(feature_by_importance))[:, None]

        total = 0.0
        for rm_ratio in ratios:
            n_feats_rm = int(n_feats * rm_ratio)
            # Start from the baseline and copy back only the kept features.
            # Presumably the baseline has the same shape as x_in -- verify
            # against choose_baseline.
            input_rmf = self.choose_baseline(x_in, baseline, device=self.device)
            kept = feature_by_importance[:, n_feats_rm:]
            input_rmf[row_idx, kept] = x_in[row_idx, kept]

            pred_rmf = self.get_predlb(input_rmf, target)
            diff = pred_allf - pred_rmf
            if self.task == "regression":
                diff = torch.abs(diff)
            total += torch.mean(diff, dim=0).item()
        return total / len(ratios)

    def get_mr_list(
        self,
        n_features: int,
        x_test: torch.Tensor,
        orders: torch.Tensor,
        n_plot: int,
        baseline: str = "zero",
        label: Optional[Union[int, list, np.ndarray, torch.Tensor]] = None,
    ) -> List[float]:
        """Compute the comprehensiveness of the model for different
        ratios of features removed.

        The ratios evaluated are ``i / n_features`` for ``i`` in
        ``range(n_plot)``.

        Args:
            n_features (int): number of features
            x_test (torch.Tensor): test data
            orders (torch.Tensor): indexes of most important features
                in descending order
            n_plot (int): number of points to plot
            baseline (str, optional): baseline to use. Defaults to "zero".
            label (int, list, np.ndarray, torch.Tensor, optional): label(s)
                of interest.

        Returns:
            list: list of comprehensiveness scores, one per ratio
        """
        assert 0 <= n_plot, "n_plot must be positive"
        return [
            self.get_comp(
                x_in=x_test,
                feature_by_importance=orders,
                removal_ratio=n_removed / n_features,
                label=label,
                baseline=baseline,
            )
            for n_removed in range(n_plot)
        ]
[docs]
def plot_comp(
    n_features: int,
    comp_list: List[float],
    rand_comp_list: List[float],
    randmodel_comp_list: List[float],
    n_plot: int,
    same_fig: bool = False,
    save_path: Optional[str] = None,
) -> None:
    """Plot the comprehensiveness of the base model, the random
    explainer and the random model for different ratios of features
    removed.

    The x axis is ``np.arange(n_plot) / n_features``. The base-model curve
    is always drawn; the random-explainer and random-model curves are drawn
    only when their list is non-empty.

    Args:
        n_features (int): number of features
        comp_list (list): comprehensiveness list for the base model
        rand_comp_list (list): comprehensiveness list for
            the random explainer. May be empty or None to skip it.
        randmodel_comp_list (list): comprehensiveness list for
            the random model. May be empty or None to skip it.
        n_plot (int): number of points to plot
        same_fig (bool, optional): whether to draw every curve on the same
            axes (True) or one stacked subplot per curve (False).
            Defaults to False.
        save_path (str, optional): path to save the plot. Defaults to None.
    """
    assert 0 <= n_plot, "n_plot must be positive"
    x_axis = np.arange(n_plot) / n_features

    # Collect only the curves that actually have data, so that subplot
    # positions are numbered over the curves drawn (a missing random
    # explainer must not leave a hole in the grid) and empty lists are never
    # plotted against a non-empty x axis.
    curves = [(comp_list, "Comprehensiveness")]
    if rand_comp_list:
        curves.append((rand_comp_list, "Random Comprehensiveness"))
    if randmodel_comp_list:
        curves.append((randmodel_comp_list, "Random Model Comprehensiveness"))

    if not same_fig:
        fig = plt.figure()
        n_rows = len(curves)
        for position, (values, curve_label) in enumerate(curves, start=1):
            axis = fig.add_subplot(n_rows, 1, position)
            axis.plot(x_axis, values, label=curve_label)
            axis.legend()
        # pyplot labels apply to the current axes, i.e. the last subplot.
        plt.xlabel("Ratio of features removed")
        plt.ylabel("Mean absolute difference")
    else:
        for values, curve_label in curves:
            plt.plot(x_axis, values, label=curve_label)
        plt.xlabel("Ratio of features removed")
        plt.ylabel("Mean probability difference")
        plt.legend()

    if save_path is not None:
        plt.savefig(save_path)
    plt.show()
[docs]
@time_function
def compute_comp(
    model: Callable,
    rand_model: Callable,
    task: str,
    x_test: torch.Tensor,
    ord_feat: torch.Tensor,
    rand_ord_feat: torch.Tensor,
    randmodel_ord_feat: torch.Tensor,
    n_plot: int,
    removal_ratio: Union[float, list],
    label: Union[int, list, np.ndarray, torch.Tensor],
    metrics: dict,
    baseline: str = "zero",
    print_plot: bool = False,
    device: str = "cpu",
) -> dict:
    """Computes the comprehensiveness of the base model, the random
    explainer and the random model.

    The random-explainer score is computed only when ``rand_ord_feat`` is
    not None, and the random-model score only when ``rand_model`` is not
    None. Scores are written into ``metrics["Comprehensiveness"]`` under
    the keys ``"original"``, ``"random"`` and ``"random model"``; the
    ``"Comprehensiveness"`` entry is created if it is missing.

    Args:
        model (Callable): model to explain
        rand_model (Callable): reference model, or None to skip it
        task (str): task to perform
        x_test (torch.Tensor): test data
        ord_feat (torch.Tensor): indexes of most important features
            in descending order for the base model
        rand_ord_feat (torch.Tensor): indexes of most important features
            in descending order for the random explainer, or None to skip it
        randmodel_ord_feat (torch.Tensor): indexes of most important features
            in descending order for the random model
        n_plot (int): number of points to plot
        removal_ratio (float, list): ratio of features to remove. If a list is
            provided, the function will compute the average comprehensiveness
            over the list of ratios.
        label (Union[int, list, np.ndarray, torch.Tensor]): label(s) of interest
        metrics (dict): dictionary of metrics, updated in place
        baseline (str, optional): baseline to use. Defaults to "zero".
        print_plot (bool, optional): whether to display the plot.
            Defaults to False.
        device (str, optional): device to use. Defaults to "cpu".

    Returns:
        dict: dict of metrics (the same object as ``metrics``).
            NOTE(review): the ``time_function`` decorator may wrap or
            extend this return value -- confirm against its definition.
    """
    n_features = x_test.shape[1]
    comp = Comprehensiveness(model, task, device)
    use_ref = rand_model is not None
    use_random = rand_ord_feat is not None
    randmodel_comp = Comprehensiveness(rand_model, task, device) if use_ref else None

    if print_plot:
        comp_list = comp.get_mr_list(
            n_features, x_test, ord_feat, n_plot, baseline, label
        )
        # Curves that are not requested stay empty so plot_comp skips them.
        rand_comp_list = []
        if use_random:
            rand_comp_list = comp.get_mr_list(
                n_features, x_test, rand_ord_feat, n_plot, baseline, label
            )
        randmodel_comp_list = []
        if use_ref:
            randmodel_comp_list = randmodel_comp.get_mr_list(
                n_features, x_test, randmodel_ord_feat, n_plot, baseline, label
            )
        plot_comp(n_features, comp_list, rand_comp_list, randmodel_comp_list, n_plot)

    comp_score = comp.get_comp(x_test, ord_feat, removal_ratio, label, baseline)
    # Tolerate a metrics dict that has no "Comprehensiveness" entry yet
    # instead of failing with a KeyError after the score has been computed.
    comp_scores = metrics.setdefault("Comprehensiveness", {})
    comp_scores["original"] = comp_score
    if use_random:
        comp_scores["random"] = comp.get_comp(
            x_test, rand_ord_feat, removal_ratio, label, baseline
        )
    if use_ref:
        comp_scores["random model"] = randmodel_comp.get_comp(
            x_test, randmodel_ord_feat, removal_ratio, label, baseline
        )
    return metrics