# Source code for beexai.evaluate.plot_metric

"""Radar plot for explanation metrics"""

from math import pi

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Font sizes used for the global matplotlib text settings below.
SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE = 8, 10, 16

# Global rc settings, applied as soon as this module is imported.
plt.rc("font", size=MEDIUM_SIZE)  # default text size
plt.rc("axes", titlesize=BIGGER_SIZE, labelsize=BIGGER_SIZE)  # axes title, x/y labels
plt.rc("xtick", labelsize=BIGGER_SIZE)  # x tick labels
plt.rc("ytick", labelsize=MEDIUM_SIZE)  # y tick labels
plt.rc("legend", fontsize=BIGGER_SIZE)  # legend entries
plt.rc("figure", titlesize=BIGGER_SIZE)  # figure title

# (name, linestyle) pairs; only the second element is handed to matplotlib.
# The first four are matplotlib's named styles, the rest are explicit
# (offset, (on, off, ...)) dash patterns.
linestyles = [(name, name) for name in ("solid", "dotted", "dashed", "dashdot")] + [
    ("densely dotted", (0, (1, 1))),
    ("dashed", (0, (5, 5))),
    ("densely dashed", (0, (5, 1))),
    ("dashdotted", (0, (3, 5, 1, 5))),
    ("densely dashdotted", (0, (3, 1, 1, 1))),
    ("dashdotdotted", (0, (3, 5, 1, 5, 1, 5))),
    ("densely dashdotdotted", (0, (3, 1, 1, 1, 1, 1))),
]

# Hex colours for the plotted series, looked up by series index.
color_palette = "#66C5CC #F6CF71 #F89C74 #87C55F #DCB0F2 #9EB9F3".split()

# Short axis labels for the radar chart, keyed by metric column name.
# The trailing arrow follows the key suffix: "_1-" keys get "↗" and
# "_0+" keys get "↘".
metrics_plot_rename = {
    key: f"{short} {'↗' if key.endswith('_1-') else '↘'}"
    for key, short in (
        ("FaithCorr_1-", "Faith"),
        ("Sensitivity_0+", "Sens"),
        ("Infidelity_0+", "Inf"),
        ("Comprehensiveness_1-", "Compr"),
        ("Sufficiency_0+", "Suff"),
        ("AUC_TP_0+", "AUC-TP"),
        ("Monotonicity_1-", "Mono"),
        ("Complexity_0+", "Compl"),
        ("Sparseness_1-", "Spar"),
    )
}

# Nominal (start, end) pair per metric. "_1-" keys run 0.0 -> 1.0 and
# "_0+" keys run 1.0 -> 0.0, i.e. the pair is written reversed for them.
metrics_range = {
    key: (0.0, 1.0) if key.endswith("_1-") else (1.0, 0.0)
    for key in (
        "FaithCorr_1-",
        "Sensitivity_0+",
        "Infidelity_0+",
        "Comprehensiveness_1-",
        "Sufficiency_0+",
        "AUC_TP_0+",
        "Complexity_0+",
        "Sparseness_1-",
        "Monotonicity_1-",
    )
}


def _invert(x, limits):
    """inverts a value x on a scale from
    limits[0] to limits[1]"""
    return limits[1] - (x - limits[0])


def _scale_data(data, ranges):
    """scales data[1:] to ranges[0],
    inverts if the scale is reversed"""
    for d, (y1, y2) in zip(data[1:], ranges[1:]):
        assert (y1 <= d <= y2) or (y2 <= d <= y1), f"d={d}, y1={y1}, y2={y2}"
    x1, x2 = ranges[0]
    d = data[0]
    if x1 > x2:
        d = _invert(d, (x1, x2))
        x1, x2 = x2, x1
    sdata = [d]
    for d, (y1, y2) in zip(data[1:], ranges[1:]):
        if y1 > y2:
            d = _invert(d, (y1, y2))
            y1, y2 = y2, y1
        sdata.append((d - y1) / (y2 - y1) * (x2 - x1) + x1)
    return sdata


class ComplexRadar:
    """Radar chart in which every variable has its own radial scale.

    One polar axes per variable is stacked at the same figure position.
    Axes ``i`` shows the radial tick labels for ``ranges[i]`` along the
    spoke of variable ``i``. All data is drawn on the first axes, after
    being rescaled onto ``ranges[0]`` by ``_scale_data``.

    Attributes:
        angle: spoke angles in radians, with the first angle repeated at
            the end so that plotted polygons close.
        ranges: the per-variable ``(start, end)`` ranges given at
            construction.
        ax: the first polar axes, on which all data is drawn.
    """

    def __init__(self, fig, variables, ranges, n_ordinate_levels=6):
        """Create the stacked polar axes on ``fig``.

        Args:
            fig: matplotlib figure that receives the axes.
            variables: labels of the spokes, one per variable.
            ranges: one ``(start, end)`` pair per variable; a pair may be
                written in reverse.
            n_ordinate_levels: number of radial grid levels per spoke.

        Raises:
            ValueError: if ``variables`` is empty.
        """
        n_vars = len(variables)
        if n_vars == 0:
            raise ValueError("ComplexRadar needs at least one variable")

        # Spoke angles in degrees. linspace guarantees exactly one angle per
        # variable, whereas arange with a float step may yield one too many.
        angles = np.linspace(0, 360, n_vars, endpoint=False)
        axes = [
            fig.add_axes([0.1, 0.1, 0.9, 0.9], polar=True, label=f"axes{i}")
            for i in range(n_vars)
        ]

        # NOTE(review): `angles` is in degrees but is compared with radian
        # thresholds here, so nearly every label ends up right/bottom.
        # Behaviour is kept as is because changing it alters the rendered
        # figure; confirm the intended alignment before converting.
        for label, angle_rad in zip(axes[0].get_xticklabels(), angles):
            ha = "left" if angle_rad <= pi / 2 else "right"
            va = "top" if pi < angle_rad <= (3 * pi / 2) else "bottom"
            label.set_verticalalignment(va)
            label.set_horizontalalignment(ha)

        _, tick_texts = axes[0].set_thetagrids(angles, labels=variables)
        for txt, angle in zip(tick_texts, angles):
            txt.set_rotation(angle - 90)

        # The overlay axes only contribute their radial tick labels.
        for ax in axes[1:]:
            ax.patch.set_visible(False)
            # Pass a real boolean: a string such as "off" is truthy and is
            # not interpreted as "hide the grid" by current matplotlib.
            ax.grid(False)
            ax.xaxis.set_visible(False)

        for i, ax in enumerate(axes):
            grid = np.linspace(*ranges[i], num=n_ordinate_levels)
            gridlabel = [f"{round(x, 2)}" for x in grid]
            gridlabel[0] = ""  # no label at the first (innermost) level
            ax.set_rgrids(grid, labels=gridlabel, angle=angles[i])
            ax.set_ylim(*ranges[i])

        self.angle = np.deg2rad(np.r_[angles, angles[0]])
        self.ranges = ranges
        self.ax = axes[0]

    def _closed_path(self, data):
        """Scale ``data`` onto ``ranges[0]`` and repeat its first point."""
        sdata = _scale_data(data, self.ranges)
        return np.r_[sdata, sdata[0]]

    def plot(self, data, *args, **kw):
        """Draw one series as a closed line; extras go to ``Axes.plot``."""
        self.ax.plot(self.angle, self._closed_path(data), *args, **kw)

    def fill(self, data, *args, **kw):
        """Draw one series as a filled polygon; extras go to ``Axes.fill``."""
        self.ax.fill(self.angle, self._closed_path(data), *args, **kw)

    def plot_multiple(self, data, methods, *args, **kw):
        """Draw several series as closed lines and add a legend.

        Args:
            data: iterable of series, each with one value per variable.
            methods: legend entries, in the same order as ``data``.
            *args, **kw: forwarded to ``Axes.plot`` for every series.

        Line styles and colours are taken from ``linestyles`` and
        ``color_palette`` by series index; both wrap around when there are
        more series than entries instead of raising ``IndexError``.
        """
        for i, series in enumerate(data):
            _, dash = linestyles[i % len(linestyles)]
            color = color_palette[i % len(color_palette)]
            self.ax.plot(
                self.angle,
                self._closed_path(series),
                *args,
                linestyle=dash,
                linewidth=3,
                c=color,
                **kw,
            )
        self.ax.tick_params(axis="both", which="major", pad=18)
        self.ax.legend(methods, loc="lower right", bbox_to_anchor=(1.4, 0.1))

    def fill_multiple(self, data, *args, **kw):
        """Draw several series as filled polygons.

        Colours come from ``color_palette`` by series index and wrap around
        when there are more series than colours. ``*args`` and ``**kw`` are
        forwarded to ``Axes.fill``.
        """
        for i, series in enumerate(data):
            color = color_palette[i % len(color_palette)]
            self.ax.fill(self.angle, self._closed_path(series), *args, color=color, **kw)
def get_dec_exponent(x):
    """Split ``x`` into the mantissa and exponent of its ``e`` notation.

    ``x`` is rendered with the ``e`` format and the text is cut at the
    ``e``. The part before it is returned as a float and the part after
    it as an int, so the mantissa is rounded to the digits that format
    prints.

    Returns:
        Tuple ``(mantissa, exponent)``.
    """
    parts = format(x, "e").split("e")
    return float(parts[0]), int(parts[1])
def plot_metric(
    df_path,
    metrics_plot=[
        "FaithCorr_1-",
        "Sensitivity_0+",
        "Infidelity_0+",
        "Comprehensiveness_1-",
        "Sufficiency_0+",
        "AUC_TP_0+",
        "Monotonicity_1-",
        "Complexity_0+",
        "Sparseness_1-",
    ],
    methods_plot=[
        "Lime",
        "ShapleyValueSampling",
        "KernelShap",
        "DeepLift",
        "IntegratedGradients",
        "Saliency",
    ],
    plot_nn=True,
    save_path=None,
    alpha=0.2,
) -> None:
    """Plot the radar chart for the metrics.

    The CSV at ``df_path`` must have a ``metrics`` column holding the method
    names and one column per metric. Metrics whose values barely differ
    between methods (standard deviation below 0.005) are rescaled by a
    power of ten, and the factor is appended to the axis label.

    Args:
        df_path: path for the metrics dataframe
        metrics_plot: list of metrics to plot; every entry must be a key of
            ``metrics_range`` and ``metrics_plot_rename``
        methods_plot: list of methods to plot
        plot_nn: whether to plot metric for Neural Network. When False the
            columns named ``<metric>.1`` are read instead, which is the
            name pandas gives to the second of two identically named
            columns (presumably the other model's results -- confirm
            against the code that writes the CSV)
        save_path: path to save the plot; the plot is shown instead when
            this is None
        alpha: transparency of the plot

    Raises:
        KeyError: if a requested method, metric or column is missing.
    """
    # The defaults are shared list objects: never hand them out or touch
    # them, work on private copies instead.
    metrics_plot = list(metrics_plot)
    methods_plot = list(methods_plot)

    metric_df = pd.read_csv(df_path)
    metric_df.index = metric_df["metrics"]
    metric_df = metric_df.loc[methods_plot]

    # One row per metric, one column per method.
    suffix = "" if plot_nn else ".1"
    values = np.array(
        [
            [float(v) for v in metric_df[name + suffix].dropna().tolist()]
            for name in metrics_plot
        ]
    )

    axis_labels = []
    ranges = []
    for i, name in enumerate(metrics_plot):
        start, end = metrics_range[name]
        label = metrics_plot_rename[name]

        spread = np.std(values[i])
        if spread < 0.005:
            # Blow tiny values up so that tick labels, which are rounded
            # to two decimals, remain distinguishable.
            _, exponent = get_dec_exponent(spread)
            values[i] = values[i] * 10 ** (-exponent - 1)
            label = label + " (10e" + str(exponent) + ")"
        axis_labels.append(label)

        max_val = values[i].max()
        min_val = values[i].min()
        margin = 2.2 * (max_val - min_val)
        # A single method, or identical values for every method, leaves no
        # spread to derive a margin from; a zero-width range would make the
        # later rescaling divide by zero, so use a fixed margin instead.
        if values.shape[1] <= 1 or margin == 0:
            margin = 0.1
        # Keep the orientation of the nominal range: the best observed
        # value always sits at the range's end point.
        if start < end:
            ranges.append((max_val - margin, max_val))
        else:
            ranges.append((min_val + margin, min_val))

    fig1 = plt.figure(figsize=(10, 10))
    radar = ComplexRadar(fig1, axis_labels, ranges)
    per_method = values.transpose()
    radar.plot_multiple(per_method, methods_plot)
    radar.fill_multiple(per_method, alpha=alpha)
    if save_path is not None:
        plt.savefig(save_path, bbox_inches="tight")
    else:
        plt.show()