Source code for beexai.explanation.plot_attr

"""Plotting functions for feature attributions"""

from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import seaborn as sns


def bar_plot(
    attributions: np.ndarray, feature_names: Optional[list] = None, mean: bool = False
) -> None:
    """Plot a horizontal bar plot for the given attributions.

    Each bar is labelled with its value (five decimals) and bars with a
    negative attribution are drawn in red.

    Args:
        attributions (np.ndarray): attributions, either 1-D with one value per
            feature, or N-D whose last axis indexes features (e.g. one row per
            sample). N-D input is averaged over every axis but the last, since
            a bar plot shows a single value per feature.
        feature_names (Optional[list], optional): name of the features. Defaults
            to None. If None is provided, they will be named as "Feature i".
        mean (bool, optional): whether to average the attributions over the
            sample axis. Defaults to False. Averaging is a no-op for 1-D input.

    Raises:
        ValueError: if ``feature_names`` does not contain exactly one name per
            feature.
    """
    attributions = np.asarray(attributions, dtype=float)
    if mean or attributions.ndim > 1:
        # Flatten every leading (sample) axis so the mean is taken per feature.
        # For 1-D input this is a single row, so the values are unchanged
        # instead of collapsing to a scalar.
        attributions = attributions.reshape(-1, attributions.shape[-1]).mean(axis=0)

    # Size default names from the feature axis (after reduction), not the
    # sample axis of the raw input.
    n_features = attributions.shape[0]
    if feature_names is None:
        feature_names = [f"Feature {i}" for i in range(n_features)]
    elif len(feature_names) != n_features:
        raise ValueError(
            f"Got {len(feature_names)} feature names for {n_features} "
            "attribution values."
        )

    _, ax = plt.subplots(figsize=(10, 10))
    bars = ax.barh(feature_names, attributions)
    ax.set_xlabel("Attributions")
    ax.set_ylabel("Features")
    ax.set_title("Feature attributions")
    ax.bar_label(bars, fmt="%.5f")
    for plot_bar, value in zip(bars, attributions):
        if value < 0:
            plot_bar.set_color("r")
    ax.margins(0.1)
    plt.show()


def plot_waterfall(
    attributions: np.ndarray,
    feature_names: Optional[list] = None,
    mean: bool = False,
) -> None:
    """Plot a horizontal plotly waterfall chart for the given attributions.

    Every feature is a "relative" step. Decreasing steps are red, increasing
    steps are green, and each step is annotated with "name: value" (five
    decimals).

    Args:
        attributions (np.ndarray): attributions, either 1-D with one value per
            feature, or N-D whose last axis indexes features (e.g. one row per
            sample). N-D input is averaged over every axis but the last.
        feature_names (Optional[list], optional): name of the features. Defaults
            to None. If None is provided, they will be named as "Feature i".
        mean (bool, optional): whether to average the attributions over the
            sample axis. Defaults to False. Averaging is a no-op for 1-D input.

    Raises:
        ValueError: if ``feature_names`` does not contain exactly one name per
            feature.
    """
    attributions = np.asarray(attributions, dtype=float)
    if mean or attributions.ndim > 1:
        # Average per feature over all leading axes. 1-D input becomes a
        # single row, so it is unchanged instead of collapsing to a 0-d scalar
        # (which would break the zip below).
        attributions = attributions.reshape(-1, attributions.shape[-1]).mean(axis=0)

    # Size default names from the feature axis, not the raw sample axis.
    n_features = attributions.shape[0]
    if feature_names is None:
        feature_names = [f"Feature {i}" for i in range(n_features)]
    elif len(feature_names) != n_features:
        raise ValueError(
            f"Got {len(feature_names)} feature names for {n_features} "
            "attribution values."
        )

    fig = go.Figure(
        go.Waterfall(
            orientation="h",
            # All steps are relative, so no "total" bar is emitted; the
            # totals style below only applies if such a bar is ever added.
            measure=["relative"] * n_features,
            x=attributions,
            textposition="outside",
            text=[
                f"{name}: {value:.5f}"
                for name, value in zip(feature_names, attributions)
            ],
            y=feature_names,
            decreasing={"marker": {"color": "red"}},
            increasing={"marker": {"color": "green"}},
            totals={"marker": {"color": "blue"}},
            connector={"line": {"color": "black"}},
        )
    )
    fig.update_layout(title="Feature attributions")
    fig.show()


def plot_swarm(
    x_in: pd.DataFrame, attributions: np.ndarray, feature_names: Optional[list] = None
) -> None:
    """Plot a swarm (strip) plot for the given attributions.

    One stacked axis is drawn per feature. Each point is one sample, placed at
    its attribution value and coloured by that sample's input value for the
    feature. A single shared "coolwarm" colour scale, spanning the plotted
    input values, is shown in the colorbar.

    Args:
        x_in (pd.DataFrame): input data, one row per sample. It is not
            modified.
        attributions (np.ndarray): attributions of shape
            (n_samples, n_features), aligned row-for-row with ``x_in``.
        feature_names (Optional[list], optional): names of the features to
            plot; each must be a column of ``x_in`` and they are taken in the
            same order as the columns of ``attributions``. Defaults to None.
            If None is provided, they will be named as "Feature i" and the
            first n_features columns of ``x_in`` are used by position.

    Raises:
        ValueError: if ``attributions`` is not 2-D, has no features, or its
            shape does not match the selected input columns.
    """
    attributions = np.asarray(attributions, dtype=float)
    if attributions.ndim != 2 or attributions.shape[1] == 0:
        raise ValueError(
            "attributions must have shape (n_samples, n_features) with at "
            f"least one feature, got {attributions.shape}."
        )
    n_features = attributions.shape[1]

    if feature_names is None:
        feature_names = [f"Feature {i}" for i in range(n_features)]
        # The default names are display labels only; read inputs by position.
        values = x_in.iloc[:, :n_features].to_numpy(dtype=float)
    else:
        if len(feature_names) != n_features:
            raise ValueError(
                f"Got {len(feature_names)} feature names for {n_features} "
                "attribution columns."
            )
        # Build a float copy instead of casting x_in's columns in place, so
        # the caller's DataFrame is left untouched.
        values = x_in[list(feature_names)].to_numpy(dtype=float)

    if values.shape != attributions.shape:
        raise ValueError(
            f"Input values of shape {values.shape} do not match attributions "
            f"of shape {attributions.shape}."
        )

    # One colour scale shared by every subplot and the colorbar.
    norm = plt.Normalize(vmin=np.nanmin(values), vmax=np.nanmax(values))
    cmap = sns.color_palette("coolwarm", as_cmap=True)

    # squeeze=False keeps a 2-D array of axes even for a single feature.
    fig, axs = plt.subplots(n_features, 1, figsize=(10, 10), squeeze=False)
    axes = axs[:, 0]

    for ax, feature, feature_attr, feature_values in zip(
        axes, feature_names, attributions.T, values.T
    ):
        df = pd.DataFrame(
            {
                "key": [feature] * len(feature_attr),
                "a": feature_attr,
                "c": feature_values,
            }
        )
        # Colormaps expect inputs in [0, 1], so normalise before the lookup.
        colors = {cval: cmap(norm(cval)) for cval in df["c"].unique()}
        sns.stripplot(
            x="a",
            y="key",
            hue="c",
            data=df,
            palette=colors,
            ax=ax,
            orient="h",
            jitter=0.2,
        )
        # The per-value hue legend is replaced by the shared colorbar.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
        ax.set_ylabel("")
        ax.set_xlabel("Attribution values")

    # The norm must live on the mappable itself; a separate norm= argument to
    # colorbar is not used when a mappable is supplied.
    fig.colorbar(
        plt.cm.ScalarMappable(norm=norm, cmap=cmap),
        ax=list(axes),
        orientation="vertical",
    )
    plt.show()