# Source code for beexai.evaluate.metrics.metrics

from typing import Callable, Optional, Tuple, Union

import numpy as np
import torch

from beexai.utils.convert import convert_to_tensor


class CustomMetric:
    """Base class for all metrics.

    Attributes:
        model (Callable): model to explain. ``predict_proba`` is called for
            classification and ``predict`` for regression.
        task (str): task to perform, ``"classification"`` or ``"regression"``.
        device (str): device on which model outputs and baselines are built.

    Methods:
        check_shape: check that inputs and attributions are matching 2-D tensors
        select_output: select the output of the model for a given label
        get_predlb: get the prediction of the model for a given label
        choose_baseline: choose a baseline for removal based metrics
    """

    def __init__(self, model: Callable, task: str, device: str = "cpu"):
        assert task in [
            "classification",
            "regression",
        ], f"task must be in ['classification', 'regression'], found {task}"
        self.model = model
        self.task = task
        self.device = device

    def check_shape(self, x_in: torch.Tensor, attributions: torch.Tensor) -> None:
        """Check the shape of the attributions.

        Args:
            x_in (torch.Tensor): input data
            attributions (torch.Tensor): attributions

        Raises:
            AssertionError: if either tensor is not 2-dimensional or if their
                shapes differ.
        """
        assert (
            x_in.ndim == 2 and attributions.ndim == 2
        ), f"""Input tensor and attributions tensor must be 2-dimensional.
            Found dimensions {x_in.ndim} and {attributions.ndim}"""
        assert (
            x_in.shape == attributions.shape
        ), f"""Input tensor and attributions tensor must have the same shape.
            Found shapes {x_in.shape} and {attributions.shape}"""

    @staticmethod
    def _scalar_label(label) -> Optional[int]:
        """Return ``label`` as a Python int if it is a single integer label.

        Python ints, numpy integer scalars and 0-d numpy arrays / torch
        tensors are all treated as one label shared by every row. Anything
        else (``None``, lists, 1-d arrays or tensors) returns ``None``.
        """
        if isinstance(label, (int, np.integer)):
            return int(label)
        if isinstance(label, (np.ndarray, torch.Tensor)) and label.ndim == 0:
            return int(label)
        return None

    def select_output(
        self,
        x_in: torch.Tensor,
        label: Optional[Union[int, list, np.ndarray, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Select the output of the model for a given label.

        If label is None, return the maximum probability of the model and
        its argmax (for classification).

        Args:
            x_in (torch.Tensor): input data
            label (int, list, np.ndarray, torch.Tensor, optional): label(s) of
                interest. Defaults to None. A single integer (Python int,
                numpy integer or 0-d tensor/array) applies to every row; a
                sequence gives one label per instance.

        Returns:
            torch.Tensor: output of the model, moved to ``self.device``
            torch.Tensor: label of the output (``None`` for regression; the
                given sequence unchanged when a per-instance label is passed)
        """
        if self.task != "classification":
            with torch.no_grad():
                res = self.model.predict(x_in).reshape(-1)
            return convert_to_tensor(res, self.device), None

        with torch.no_grad():
            pred = self.model.predict_proba(x_in)
        scalar = self._scalar_label(label)
        if scalar is not None:
            res = pred[:, scalar]
            max_arg = torch.full((pred.shape[0],), scalar, dtype=torch.long)
        elif isinstance(label, (list, np.ndarray, torch.Tensor)):
            # One label per row: pick pred[i, label[i]] for every row i.
            res = pred[torch.arange(pred.shape[0]), label]
            max_arg = label
        else:
            res = torch.max(pred, dim=1).values
            max_arg = torch.argmax(pred, dim=1)
        return convert_to_tensor(res, self.device), max_arg

    def get_predlb(
        self,
        x_in: torch.Tensor,
        label: Optional[Union[int, list, np.ndarray, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Get the prediction of the model for a given label.

        If label is None, return the prediction of the model for the max
        probability (for classification). For regression the label is
        ignored and the raw (flattened) prediction is returned.

        Args:
            x_in (torch.Tensor): input data
            label (int, list, np.ndarray, torch.Tensor, optional): label(s) of
                interest. Defaults to None. A single integer (Python int,
                numpy integer or 0-d tensor/array) applies to every row; a
                sequence gives one label per instance.

        Returns:
            torch.Tensor: prediction of the model, on ``self.device``
        """
        if self.task != "classification":
            with torch.no_grad():
                res = self.model.predict(x_in).reshape(-1)
            return convert_to_tensor(res, self.device)

        # A single forward pass serves both the argmax (label is None) and
        # the gathered probabilities.
        with torch.no_grad():
            pred = self.model.predict_proba(x_in)
        pred = convert_to_tensor(pred, self.device)
        n_rows = pred.shape[0]

        scalar = self._scalar_label(label)
        if scalar is not None:
            indexes = torch.full((n_rows,), scalar, dtype=torch.long)
        elif isinstance(label, (list, np.ndarray, torch.Tensor)):
            indexes = torch.as_tensor(label, dtype=torch.long)
        else:
            indexes = torch.argmax(pred, dim=1)

        rows = torch.arange(n_rows, device=pred.device)
        # Copy into a default-dtype tensor on self.device, which matches the
        # dtype/device the per-row gather has always returned.
        res = torch.zeros(n_rows, device=self.device)
        res.copy_(pred[rows, indexes.to(pred.device)])
        return res

    def choose_baseline(
        self,
        x_in: torch.Tensor,
        baseline: str,
        n_samples: int = 100,
        device: str = "cpu",
    ) -> torch.Tensor:
        """Choose a baseline for removal based metrics.

        Args:
            x_in (torch.Tensor): input data, indexed as (rows, features)
            baseline (str): baseline to use, one of "mean", "median", "zero",
                "multiple", "normal", "uniform", "min", "max"
            n_samples (int, optional): number of randomly drawn rows averaged
                by the "multiple" baseline. Defaults to 100.
            device (str, optional): device the returned baseline is moved to.
                Defaults to "cpu".

        Returns:
            torch.Tensor: baseline with the same shape as ``x_in``

        Raises:
            AssertionError: if ``baseline`` is not an authorized value.
        """
        authorized_baseline = [
            "mean",
            "median",
            "zero",
            "multiple",
            "normal",
            "uniform",
            "min",
            "max",
        ]
        assert (
            baseline in authorized_baseline
        ), f"baseline must be in {authorized_baseline}"
        res = torch.ones(x_in.shape, device=self.device)
        if baseline == "mean":
            res = torch.mul(torch.mean(x_in, dim=0), res)
        elif baseline == "median":
            # torch.median returns the lower of the two middle values when
            # the number of rows is even (not their average).
            res = torch.mul(torch.median(x_in, dim=0).values, res)
        elif baseline == "zero":
            res = res * 0
        elif baseline == "multiple":
            samples = torch.randperm(x_in.shape[0])[:n_samples]
            res = torch.mean(x_in[samples], dim=0)
            res = torch.mul(res, torch.ones(x_in.shape, device=self.device))
        elif baseline == "uniform":
            for i in range(x_in.shape[1]):
                max_val = torch.max(x_in[:, i])
                min_val = torch.min(x_in[:, i])
                # Draw on x_in's device so the sample never meets the
                # min/max values on a different device.
                sample = torch.rand(x_in.shape[0], device=x_in.device)
                res[:, i] = sample * (max_val - min_val) + min_val
        elif baseline == "normal":
            for i in range(x_in.shape[1]):
                mean = torch.mean(x_in[:, i])
                std = torch.std(x_in[:, i])
                res[:, i] = torch.normal(mean, std, size=(x_in.shape[0],))
        elif baseline == "min":
            # In-place assignment broadcasts the per-feature minimum to every
            # row and keeps res's dtype and device.
            res[:] = torch.min(x_in, dim=0).values
        elif baseline == "max":
            res[:] = torch.max(x_in, dim=0).values
        res = res.to(device)
        return res