{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Get explainability metrics\n", "\n", "This notebook is used to get explainability metrics for the model based on the previously trained model and the computed attributions. The configuration used is the same as the [starter](starter.ipynb) and [explain](explain.ipynb) notebooks." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load model and data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "sys.path.append(\"../\")\n", "\n", "import torch\n", "import numpy as np\n", "\n", "from beexai.dataset.dataset import Dataset\n", "from beexai.dataset.load_data import load_data\n", "from beexai.evaluate.metrics.get_results import get_all_metrics\n", "from beexai.explanation.explaining import CaptumExplainer\n", "from beexai.training.train import Trainer\n", "from beexai.utils.time_seed import set_seed\n", "from beexai.utils.sampling import stratified_sampling\n", "\n", "set_seed(42)\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "DATA_NAME = \"kickstarter\"\n", "MODEL_NAME = \"NeuralNetwork\"\n", "CONFIG_PATH = f\"config/{DATA_NAME}.yml\"\n", "data_test, target_col, task, _ = load_data(from_cleaned=True, config_path=CONFIG_PATH)\n", "scale_params = {\n", " \"x_num_scaler_name\": \"quantile_normal\",\n", " \"x_cat_encoder_name\": \"ordinalencoder\",\n", " \"y_scaler_name\": \"labelencoder\", # change to minmax or another float scaler for regression\n", " \"cat_not_to_onehot\": [\"name\"],\n", "}\n", "data = Dataset(data_test, target_col)\n", "X_train, X_test, y_train, y_test = data.get_train_test(\n", " test_size=0.2, scaler_params=scale_params\n", ")\n", "NUM_LABELS = data.get_classes_num(task)\n", "\n", "NN_PARAMS = {\"input_dim\": X_train.shape[1], \"output_dim\": NUM_LABELS}\n", "trainer = Trainer(MODEL_NAME, task, NN_PARAMS, device)\n", "trainer.load_model(f\"../output/models/{DATA_NAME}/{MODEL_NAME}.pt\")\n", "\n", "TEST_SIZE = 100\n", "X_test, y_test = stratified_sampling(X_test, y_test, TEST_SIZE, task)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train Dummy\n", "\n", "Train a dummy model on shuffled labels to compare with the real model explanations" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "DUMMY_TRAIN_SIZE = 100\n", "X_train_sampled, y_train_sampled = stratified_sampling(\n", " X_train, y_train, DUMMY_TRAIN_SIZE, task\n", ")\n", "rand_trainer = Trainer(MODEL_NAME, task, NN_PARAMS, device)\n", "X_perm, y_perm = X_train_sampled.values, np.random.permutation(y_train_sampled.values)\n", "rand_trainer.train(X_perm, y_perm)\n", "rand_trainer.model.eval()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Get metrics\n", "\n", "`get_all_metrics` method automatically computes XAI metrics for each instance. Some metrics can be customized by specifying hyperparameters, those used in our paper are set by default.\n", "\n", "We also provide the possibility of comparison with random attributions and attributions obtained with the dummy model.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "METHOD = \"IntegratedGradients\"\n", "exp = CaptumExplainer(\n", " trainer.model, task=task, method=METHOD, sklearn=False, device=device\n", ")\n", "exp.init_explainer()\n", "attributions = exp.compute_attributions(\n", " X_test, DATA_NAME, MODEL_NAME, METHOD, \"../output/explain/\", use_abs=False\n", ")\n", "\n", "rand_exp = CaptumExplainer(\n", " rand_trainer.model, task=task, method=METHOD, sklearn=False, device=device\n", ")\n", "rand_exp.init_explainer()\n", "rand_attributions = rand_exp.compute_attributions(\n", " X_test, DATA_NAME, MODEL_NAME, METHOD, \"../output/explain/\", use_abs=False\n", ")\n", "\n", "all_preds = trainer.model.predict(X_test.values)\n", "\n", "get_all_metrics(\n", " X_test,\n", " all_preds,\n", " trainer.model,\n", " exp,\n", " rand_trainer.model,\n", " rand_exp,\n", " print_plot=True,\n", " auc_metric=\"accuracy\",\n", " device=device,\n", " use_random=True,\n", " use_ref=True,\n", " attributions=attributions,\n", " attributions_ref=rand_attributions,\n", ")" ] } ], "metadata": { "kernelspec": { "display_name": "venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 2 }