Source code for beexai.evaluate.metrics.complexity

from typing import Callable

import torch

from beexai.evaluate.metrics.metrics import CustomMetric
from beexai.utils.time_seed import time_function


class Complexity(CustomMetric):
    """Implementation of the complexity metric.

    Measures how spread out an explanation is. For every instance, the
    fractional contribution of each feature (its absolute attribution divided
    by the total absolute attribution of the instance) is computed. The
    Shannon entropy of those fractions is then taken. Low values mean a few
    features carry most of the attribution. High values mean the attribution
    is spread over many features.

    References:
        - `Evaluating and Aggregating Feature-based Model Explanations
          <https://arxiv.org/abs/2005.00631>`

    Attributes:
        model (callable): model to explain
        task (str): task to perform
        device (str): device to use

    Methods:
        get_cmpl: computes the batch-averaged complexity of attributions
    """

    def __total_contribution__(self, attribution: torch.Tensor) -> torch.Tensor:
        """Compute the total absolute contribution of each instance.

        Args:
            attribution (torch.Tensor): attributions, one row per instance

        Returns:
            torch.Tensor: sum of absolute attributions along dim 1
        """
        return torch.sum(torch.abs(attribution), axis=1)

    def __fractional_contribution__(
        self, attribution: torch.Tensor, feature_i: int
    ) -> torch.Tensor:
        """Compute the fractional contribution of a given feature.

        Args:
            attribution (torch.Tensor): attributions, one row per instance
            feature_i (int): index of the feature (column) of interest

        Returns:
            torch.Tensor: |attribution[:, feature_i]| divided by the total
            absolute attribution of each instance
        """
        total_contrib = self.__total_contribution__(attribution)
        # 1e-8 avoids a division by zero for all-zero attribution rows.
        return torch.abs(attribution[:, feature_i]) / (total_contrib + 1e-8)

    def get_cmpl(self, attribution: torch.Tensor) -> torch.Tensor:
        """Compute the complexity of the attributions, averaged over instances.

        For each instance, p_j = |a_j| / sum_k |a_k| is computed for every
        feature j. The entropy -sum_j p_j * log(p_j) is divided by the number
        of features, and the per-instance scores are averaged over the batch.

        Args:
            attribution (torch.Tensor): attributions of shape
                (n_instances, n_features)

        Returns:
            torch.Tensor: 0-dim tensor holding the batch-averaged complexity
        """
        self.check_shape(attribution, attribution)
        n_features = attribution.shape[1]
        # Compute the per-instance total once for the whole batch. Calling
        # __fractional_contribution__ per feature would recompute it d times.
        total_contrib = self.__total_contribution__(attribution).unsqueeze(1)
        frac_contrib = torch.abs(attribution) / (total_contrib + 1e-8)
        # The 1e-8 inside the log keeps zero fractions finite: 0 * log(1e-8) == 0.
        entropy = -(frac_contrib * torch.log(frac_contrib + 1e-8)).sum(dim=1)
        complexity = entropy / n_features
        return torch.mean(complexity, dim=0)
@time_function
def compute_complex(
    model: Callable,
    rand_model: Callable,
    task: str,
    attributions: torch.Tensor,
    rand_attrib: torch.Tensor,
    randmodel_attributions: torch.Tensor,
    metrics: dict,
    device: str = "cpu",
) -> dict:
    """Compute the complexity of three sets of attributions.

    Scores the base model's attributions, the random-explainer attributions
    and the random model's attributions. The scores are stored under
    ``metrics["Complexity"]`` with the keys ``"original"``, ``"random"`` and
    ``"random_model"``.

    Args:
        model (callable): base model
        rand_model (callable): reference model (random model); when None the
            "random_model" score is skipped
        task (str): task to perform
        attributions (torch.Tensor): feature attributions of the base model
        rand_attrib (torch.Tensor): random attributions; when None the
            "random" score is skipped
        randmodel_attributions (torch.Tensor): attributions of the random
            model, used only when ``rand_model`` is not None
        metrics (dict): dictionary of metrics, updated in place
        device (str, optional): device to use. Defaults to "cpu"

    Returns:
        dict: the same ``metrics`` dict, with the "Complexity" entry filled in
    """
    # Create the nested entry if the caller has not pre-populated it, rather
    # than failing with a KeyError.
    complexity_metrics = metrics.setdefault("Complexity", {})

    comp = Complexity(model, task, device)
    complexity_metrics["original"] = comp.get_cmpl(attributions).item()

    if rand_attrib is not None:
        complexity_metrics["random"] = comp.get_cmpl(rand_attrib).item()

    if rand_model is not None:
        randmodel_comp = Complexity(rand_model, task, device)
        randmodel_score = randmodel_comp.get_cmpl(randmodel_attributions)
        complexity_metrics["random_model"] = randmodel_score.item()

    return metrics