Source code for beexai.evaluate.metrics.sparseness

from typing import Callable

import torch

from beexai.evaluate.metrics.metrics import CustomMetric
from beexai.utils.time_seed import time_function


class Sparseness(CustomMetric):
    """Implementation of the sparseness metric.

    Scores how concentrated an attribution vector is using the Gini index.
    For non-negative values sorted in ascending order
    ``a_(1) <= ... <= a_(n)``::

        G = 1 - 2 * sum_j (n - j + 0.5) * a_(j) / (n * sum_j a_(j))

    ``G`` is 0 when every feature receives the same attribution and reaches
    ``1 - 1/n`` when all of the attribution mass sits on a single feature.

    References:
        - `Synthetic Benchmarks for Scientific Research in Explainable
          Machine Learning <https://arxiv.org/abs/2106.12543>`

    Attributes:
        model (callable): model to explain
        task (str): task to perform
        device (str): device to use

    Methods:
        get_sparsity: computes the mean Gini sparseness over a batch of
            attributions
    """

    def get_sparsity(self, attribution: torch.Tensor) -> torch.Tensor:
        """Compute the mean Gini sparseness of a batch of attributions.

        Args:
            attribution (torch.Tensor): 2-D tensor of shape
                ``(n_instances, n_features)``. Each row must already be
                sorted in ascending order of absolute value (as
                ``compute_spar`` does before calling this method); unsorted
                rows do not yield the Gini index.

        Returns:
            torch.Tensor: 0-d tensor holding the sparseness averaged over
            all instances (not one score per instance).
        """
        # NOTE(review): check_shape is inherited from CustomMetric and is
        # called with the same tensor twice -- confirm this is the intent.
        self.check_shape(attribution, attribution)
        n_features = attribution.shape[1]
        magnitudes = torch.abs(attribution)

        # Gini weight of the j-th smallest value (1-indexed) is n - j + 0.5,
        # i.e. [n - 0.5, n - 1.5, ..., 0.5]. Building the weights on the
        # input's device and letting type promotion choose the dtype avoids
        # the float32 / fixed-device accumulator that silently truncated
        # float64 inputs and broke when self.device differed from the input.
        weights = torch.arange(n_features, 0, -1, device=attribution.device) - 0.5

        weighted_sum = torch.sum(magnitudes * weights, dim=1)
        attrib_sum = torch.sum(magnitudes, dim=1)
        # The 1e-8 term keeps all-zero rows finite (they score 1).
        spars = 1 - 2 * weighted_sum / (attrib_sum * n_features + 1e-8)
        return torch.mean(spars, dim=0)
def _sorted_magnitudes(attributions: torch.Tensor) -> torch.Tensor:
    """Return ``|attributions|`` sorted in ascending order along dim 1."""
    return torch.sort(torch.abs(attributions), dim=1).values


@time_function
def compute_spar(
    model: Callable,
    rand_model: Callable,
    task: str,
    attributions: torch.Tensor,
    rand_attrib: torch.Tensor,
    rand_model_attributions: torch.Tensor,
    metrics: dict,
    device="cpu",
) -> dict:
    """Compute the sparseness metric and record it in ``metrics``.

    Args:
        model (callable): base model
        rand_model (callable): reference model (random model); when not
            None, ``rand_model_attributions`` must also be provided
        task (str): task of the model
        attributions (torch.Tensor): attributions for base model and base
            explainer
        rand_attrib (torch.Tensor): random attributions, or None to skip
        rand_model_attributions (torch.Tensor): attributions for reference
            model and base explainer
        metrics (dict): dictionary of metrics, updated in place
        device (str, optional): device to use. Defaults to "cpu".

    Returns:
        dict: ``metrics`` itself, with ``metrics["Sparseness"]["original"]``
        set and, when the corresponding inputs are given, ``"random"`` and
        ``"random_model"`` set as Python floats.
    """
    # Create the sub-dict on demand instead of requiring the caller to
    # pre-populate it (previously a KeyError); existing entries are reused.
    spar_metrics = metrics.setdefault("Sparseness", {})

    spars = Sparseness(model, task, device)
    # get_sparsity expects rows sorted by ascending absolute value.
    spar_metrics["original"] = spars.get_sparsity(
        _sorted_magnitudes(attributions)
    ).item()

    if rand_attrib is not None:
        spar_metrics["random"] = spars.get_sparsity(
            _sorted_magnitudes(rand_attrib)
        ).item()

    if rand_model is not None:
        randmodel_spars = Sparseness(rand_model, task, device)
        spar_metrics["random_model"] = randmodel_spars.get_sparsity(
            _sorted_magnitudes(rand_model_attributions)
        ).item()

    return metrics