{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6bbe15d5-29f8-4a15-b35f-c19ee16da6aa",
   "metadata": {},
   "source": [
    "# TP Explainable AI at SET\n",
    "\n",
    "\n",
    "This tutorial aims to provide an overview on the most popular techniques of explainable AI (xAI). As we saw during the presentation, we can broadly divide those techniques in two kinds:\n",
    "1. _Post-hoc_ explanation methods, that are used to analyze existing models\n",
    "2.  _by-design_ explainable models, programs that embed explanations into their decision process\n",
    "\n",
    "For the _post-hoc_ methods, we will use the [Captum](https://captum.ai/) library. Part of this tutorial is adapted from the CAPTUM [original tutorial on CIFAR10](https://captum.ai/tutorials/CIFAR_TorchVision_Interpret).\n",
    "For the by-design model, we will use the [CaBRNet](https://git.frama-c.com/pub/cabrnet) library, developped at CEA."
  {
   "cell_type": "markdown",
   "id": "5421bf3a-01ad-4817-8a6d-5a3078d0569c",
   "metadata": {},
   "source": [
    "## Legal notice and acknowledgements\n",
    "This tutorial is under license [CC-BY-NC](https://creativecommons.org/licenses/by-nc/4.0/deed.fr).\n",
    "\n",
    "The main author is Julien Girard-Satabin.\n",
    "Thank Alban Grastien, Aymeric Varasse and Romain Xu-Darme for their valuable feedback. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e029cea-12d7-4a54-9e68-7a9f2df14521",
   "metadata": {},
   "source": [
    "## Preliminaries\n",
    "\n",
    "### Environment setup\n",
    "\n",
    "Install all dependencies in a dedicated virtual environment. A `setup.sh` script is provided at the root of the session repository. This section ensures that the downloaded packages are correctly setup, and that the pretrained models behave as expected.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39d401d6-eb75-4147-be9d-1ec863b3d857",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import ipyplot\n",
    "import yaml\n",
    "from IPython.core.display import SVG\n",
    "from tqdm import tqdm\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "import captum\n",
    "from captum.attr import visualization as viz\n",
    "\n",
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms.v2 as transforms\n",
    "\n",
    "import cabrnet \n",
    "from zenodo_get import zenodo_get\n",
    "from IPython.display import IFrame, Image, SVG, display"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "975e7c31-ad96-431c-a4c9-f9ea097f3f27",
   "metadata": {},
   "source": [
    "We will use for this session a reduced image set of the dataset [CUB200](http://www.vision.caltech.edu/datasets/cub_200_2011/). This is to avoid unecessary training time and inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38e1ad69-ee58-4bb3-b913-c5edc13ea8a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "mean=[0.485, 0.456, 0.406]\n",
    "std=[0.229, 0.224, 0.225]\n",
    "\n",
    "transform = transforms.Compose(\n",
    "    [transforms.ToImage(), transforms.ToDtype(torch.float32, scale=True),\n",
    "     transforms.Resize((224,224),antialias=True),\n",
    "     transforms.Normalize(mean=mean,std=std),\n",
    "    ])\n",
    "tinyCub = torchvision.datasets.ImageFolder(root=\"./data/test_tiny\", transform=transform)# The model was trained on normalized images to improve its performance. Therefore, normalization must also be applied on test images.\n",
    "def denormalize(x):\n",
    "    # Reverse the normalization operation to recover original image\n",
    "    return x * torch.tensor(std).view(-1,1,1) + torch.tensor(mean).view(-1,1,1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "746d2d8b-dfd5-498b-94cc-30ff03d21c76",
   "metadata": {},
   "source": [
    " We will also load a pretrained model on Cub200 (a ResNet 50) for Post-Hoc explanations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c81c3fd-7e10-43f4-8ef8-02759d5329db",
   "metadata": {},
   "outputs": [],
   "source": [
    "modelPostHoc = torch.load('./models/r50_CUB200_i448.pth',map_location='cpu')\n",
    "# Put the model in evaluation mode \n",
    "# Very important to avoid side-effects such as unwanted parameter modification in the model (e.g. batch normalization)\n",
    "modelPostHoc.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f92f76da-41cc-4f13-a98c-f72c1e033672",
   "metadata": {},
   "source": [
    "### Sanity checks\n",
    "\n",
    "We will begin by loading some images from the dataset, pass them through the model and see if the predictions are correct. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66f33348-848b-4d83-a539-03abebc7f786",
   "metadata": {},
   "outputs": [],
   "source": [
    "def imshow(img):\n",
    "    img = denormalize(img)\n",
    "    npimg = img.numpy()\n",
    "    plt.imshow(np.transpose(npimg, (1, 2, 0)))\n",
    "    plt.show()\n",
    "\n",
    "batch_size = 5\n",
    "loader = torch.utils.data.DataLoader(tinyCub, batch_size=batch_size)\n",
    "classes = list(map(lambda x: x.split(\".\")[1], tinyCub.classes))\n",
    "[imgs, targets] = next(iter(loader))\n",
    "res = modelPostHoc(imgs)\n",
    "imshow(torchvision.utils.make_grid(imgs,nrow=batch_size))\n",
    "print(f\"Ground truth predictions:  {' ' .join('%2s' % targets[j].item()+ ' ' + classes[targets[j]] for j in range(batch_size) )}\")\n",
    "_, predicted = torch.max(res, 1)\n",
    "print(f\"Predicted classes:  {' ' .join('%2s' % predicted[j].item()+ ' ' + classes[predicted[j]] for j in range(batch_size) )}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bbe2d0aa-aea6-43b2-be32-3cfe8064dacb",
   "metadata": {},
   "source": [
    "Finally, we will compute the average precision on the dataset. We should have an accuracy of about 69%."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69df3f34-67cd-4c7e-8645-1d5542208998",
   "metadata": {},
   "outputs": [],
   "source": [
    "for idx, (img, target) in enumerate(tqdm(loader)):\n",
    "    _, predicted = torch.max(modelPostHoc(img), 1)\n",
    "    batch_acc = (torch.sum((predicted==target))*True).item()\n",
    "print(f\"Accuracy: {acc/len(tinyCub)*100:.2f}%\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "426864ec-c9c2-421f-bced-9cf60fefae52",
   "metadata": {},
   "source": [
    "## Post Hoc explanation methods\n",
    "\n",
    "All the following approaches aim to do _feature attribution_: given a sample $x$ with features $x^i$, a program $f$ and a prediction $y$, the aim is to answer the following question: \"which $x^i$ contributed the most to $f(x)=y$? \n",
    "\n",
    "![](post-hoc.png)\n",
    "\n"
   ]
  },
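  {
   "cell_type": "markdown",
   "id": "aa11bb22-0001-4abc-9def-0123456789ab",
   "metadata": {},
   "source": [
    "To make this question concrete, here is a minimal sketch on a toy linear model (an illustration, not part of the original session): for $f(x) = w \cdot x$, the gradient with respect to each feature $x^i$ is simply $w_i$, so gradient-based attribution directly recovers the weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa11bb22-0002-4abc-9def-0123456789ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy feature attribution on a linear model f(x) = w . x (illustrative sketch)\n",
    "w = torch.tensor([2.0, -1.0, 0.0])\n",
    "x = torch.tensor([1.0, 1.0, 1.0], requires_grad=True)\n",
    "y = torch.dot(w, x)\n",
    "y.backward()\n",
    "print(x.grad)  # tensor([ 2., -1.,  0.]): feature 0 contributes the most to f(x)=y"
   ]
  },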
  {
   "cell_type": "markdown",
   "id": "f27de05e-769d-48b6-9f5d-9b97b37ba170",
   "metadata": {},
   "source": [
    "### Saliency maps\n",
    "\n",
    "We will first apply the simplest attribution method: [backpropagating the gradient](https://arxiv.org/abs/1312.6034) of $y$ on the chosen sample $x$:\n",
    "$\\frac{\\partial{f(x)}}{\\partial{x}}$\n",
    "It is done automatically with most of modern deep learning frameworks.\n",
    "Note that you can change the `sign` parameter to `\"all\"` to see the sign variations for all following methods. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3dfa9fa-7e75-42bc-ad52-c534f91f93b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "input, target = tinyCub[0]\n",
    "input = input.unsqueeze(0)\n",
    "input.requires_grad = True # Indicate that gradients should be propagated back to this tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2afb95e-6bea-411f-865c-d83fbe1a32f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "original_image = np.transpose(denormalize(input[0]).detach().cpu().numpy(),(1, 2, 0))\n",
    "saliency = captum.attr.Saliency(modelPostHoc)\n",
    "grads = saliency.attribute(inputs=input, target=target, abs=False)\n",
    "grads = np.transpose(grads.squeeze(0).cpu().detach().numpy(), (1, 2, 0))\n",
    "_ = viz.visualize_image_attr(None, original_image, \n",
    "                      method=\"original_image\", title=\"Original Image\")\n",
    "_ = viz.visualize_image_attr(grads, original_image, method=\"blended_heat_map\", sign=\"absolute_value\", \n",
    "                             outlier_perc=5, show_colorbar=True, \n",
    "                             title=\"Overlayed Gradient Magnitudes\")"
   ]
  },
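  {
   "cell_type": "markdown",
   "id": "aa11bb22-0003-4abc-9def-0123456789ab",
   "metadata": {},
   "source": [
    "Under the hood, this amounts to a single backward pass. Here is a minimal raw-autograd sketch (the exact Captum internals may differ), reusing `modelPostHoc`, `input` and `target` from above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa11bb22-0004-4abc-9def-0123456789ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw-autograd sketch of a saliency map\n",
    "modelPostHoc.zero_grad()\n",
    "input.grad = None  # discard any previously accumulated gradient\n",
    "out = modelPostHoc(input)\n",
    "out[0, target].backward()  # backpropagate the score of the target class\n",
    "raw_grads = input.grad.detach().clone()\n",
    "print(raw_grads.shape, raw_grads.abs().max())"
   ]
  },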
  {
   "cell_type": "markdown",
   "id": "bc186f4e-7725-4b98-abfa-cee5eed947f9",
   "metadata": {},
   "source": [
    "We see that the gradients focus a lot on the head, wings and tail; but also on the top corners of the image. Altough it may describe how the neural network take its decision, it may not match the human decision process to classify an albatross."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d0bffdf-27c7-4930-8ac7-c94cdc1e0c3f",
   "metadata": {},
   "source": [
    "### Saliency maps with SmoothGrads\n",
    "\n",
    "Given $x$, [SmoothGrads](https://arxiv.org/abs/1706.03825) aims to compute an average of the gradients in a neighborhood $x^{*}$ to reduce the influence of sharp, local variations. An approximation of this averaged gradient can be computed by the following:\n",
    "\n",
    "$$\n",
    "\\nabla_{x^{*}}y \\approx \\frac{1}{n}\\sum_0^{n}\\nabla_xf(x+\\mathcal{N}(0,\\sigma))\n",
    "$$\n",
    "\n",
    "There are two parameters here:\n",
    "1. $\\sigma$: the standard deviation of the gaussian sampling\n",
    "2. $n$: the number of samples computed by smoothgrad \n",
    "\n",
    "Experiment by changing those parameters and calling the `attribute` method (it may take long if you increase the number of samples: start by increments of 5).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b0462e3-bb46-49bf-aced-20569e63468b",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_samples = 10 # Number of perturbed samples per step\n",
    "sigma = 0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c271261-78ee-4e1e-97b1-c43c0500e085",
   "metadata": {},
   "outputs": [],
   "source": [
    "saliency = captum.attr.Saliency(modelPostHoc)\n",
    "nt = captum.attr.NoiseTunnel(saliency)\n",
    "attrs = nt.attribute(inputs=input, target=target, nt_type='smoothgrad_sq', nt_samples=n_samples, stdevs=sigma)\n",
    "attrs= np.transpose(attrs.squeeze(0).cpu().detach().numpy(), (1, 2, 0))\n",
    "_ = viz.visualize_image_attr(attrs, original_image, method=\"blended_heat_map\", sign=\"absolute_value\", \n",
    "                             outlier_perc=10, show_colorbar=True, \n",
    "                             title=\"Overlayed Gradient Magnitudes \\n with SmoothGrad Squared\")"
   ]
  },
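  {
   "cell_type": "markdown",
   "id": "aa11bb22-0005-4abc-9def-0123456789ab",
   "metadata": {},
   "source": [
    "To make the formula above concrete, here is a minimal hand-rolled version of the SmoothGrad average (a sketch only: Captum's `NoiseTunnel` additionally handles squaring and other variants), reusing `modelPostHoc`, `input` and `target` from above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa11bb22-0006-4abc-9def-0123456789ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand-rolled SmoothGrad average (sketch of the formula above)\n",
    "def smoothgrad_sketch(model, x, target, n=10, sigma=0.1):\n",
    "    acc = torch.zeros_like(x)\n",
    "    for _ in range(n):\n",
    "        noisy = (x + sigma * torch.randn_like(x)).detach().requires_grad_(True)\n",
    "        out = model(noisy)\n",
    "        out[0, target].backward()\n",
    "        acc += noisy.grad\n",
    "    return acc / n\n",
    "\n",
    "avg_grads = smoothgrad_sketch(modelPostHoc, input, target, n=n_samples, sigma=sigma)\n",
    "print(avg_grads.shape)"
   ]
  },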
  {
   "cell_type": "markdown",
   "id": "f504bb5e-6a3a-4a64-aec8-fe62b38e16aa",
   "metadata": {},
   "source": [
    "With averaged gradients, the interpretation seems much less noisy. With a sufficiently high number of samples and a low standard deviation, the gradient seems to vary a lot around the head."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "121f7611-0f58-4b9c-b895-f8ca7142944a",
   "metadata": {},
   "source": [
    "### Integrated gradients\n",
    "\n",
    "The previous approaches have limitations. Namely, they exist some situations where the gradient of different values is the same. \n",
    "To tackle this issue, [integrated Gradients](https://arxiv.org/abs/1703.01365) computes a linear approximation of the gradient on the line between an baseline image $x^{'}$ and the image $x$.\n",
    "\n",
    "$$\n",
    "IG_i = (x_i - x^{'}_i) \\int_{\\alpha=0}^{1} \\nabla_{x_i}\n",
    "f(x^{'}+\\alpha(x-x^{'}))d\\alpha\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2b779ba-05e4-4cfe-925c-4668644441d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_steps = 20 # Number of Integrated Gradients steps\n",
    "n_samples = 10  # Number of samples\n",
    "sigma = 0.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23b17d92-3630-454f-accd-6dd33435faa1",
   "metadata": {},
   "outputs": [],
   "source": [
    "ig = captum.attr.IntegratedGradients(modelPostHoc)\n",
    "attributions, delta = ig.attribute(inputs=input,  baselines=input*0, n_steps=n_steps, target=target, return_convergence_delta=True)\n",
    "attributions = np.transpose(attributions.squeeze().cpu().detach().numpy(), (1, 2, 0))\n",
    "print('Approximation delta: ', abs(delta))\n",
    "_ = viz.visualize_image_attr(attributions, original_image, method=\"blended_heat_map\",sign=\"absolute_value\",\n",
    "                          show_colorbar=True, title=\"Overlayed Integrated Gradients\")"
   ]
  },
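  {
   "cell_type": "markdown",
   "id": "aa11bb22-0007-4abc-9def-0123456789ab",
   "metadata": {},
   "source": [
    "In practice, the integral above is approximated by a Riemann sum over points along the path from the baseline to the image. Here is a minimal sketch of that approximation (Captum additionally estimates the convergence delta shown above), reusing `modelPostHoc`, `input` and `target`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa11bb22-0008-4abc-9def-0123456789ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Riemann-sum approximation of Integrated Gradients (sketch)\n",
    "def ig_sketch(model, x, target, steps=20):\n",
    "    baseline = torch.zeros_like(x)\n",
    "    total = torch.zeros_like(x)\n",
    "    for k in range(1, steps + 1):\n",
    "        point = (baseline + (k / steps) * (x - baseline)).detach().requires_grad_(True)\n",
    "        model(point)[0, target].backward()\n",
    "        total += point.grad\n",
    "    # Average the gradients along the path and scale by (x - baseline)\n",
    "    return (x - baseline).detach() * total / steps\n",
    "\n",
    "ig_attr = ig_sketch(modelPostHoc, input, target, steps=n_steps)\n",
    "print(ig_attr.shape)"
   ]
  },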
  {
   "cell_type": "markdown",
   "id": "1c3dac7c-cfe8-4a36-9ff1-71757ae99cdc",
   "metadata": {},
   "source": [
    "The Integrated Gradients display how much variations (in term of gradient) exist between a white image and the actual image. Note that the heatmaps are now located on the wings.\n",
    "We will now combine Integrated Gradients with SmoothGrads."
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bef64173-1077-4ded-8871-055ea14b628f",
   "metadata": {},
   "outputs": [],
   "source": [
    "ig = captum.attr.IntegratedGradients(modelPostHoc)\n",
    "nt = captum.attr.NoiseTunnel(ig)\n",
    "attributions_smoothgrad = nt.attribute(inputs=input, baselines=input * 0, target=target, nt_type='smoothgrad_sq', n_steps=n_steps, nt_samples=n_samples, stdevs=sigma)\n",
    "attributions_smoothgrad = np.transpose(attributions_smoothgrad.squeeze(0).cpu().detach().numpy(), (1, 2, 0))\n",
    "_ = viz.visualize_image_attr(attributions_smoothgrad, original_image, method=\"blended_heat_map\", sign=\"absolute_value\", \n",
    "                             outlier_perc=10, show_colorbar=True, \n",
    "                             title=\"Overlayed Integrated Gradients \\n with SmoothGrad Squared\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "113772f6-dae3-4622-9019-0cc96ea8556c",
   "metadata": {},
   "source": [
    "We note that integrated gradients with smoothgrads provide much more focused variations."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e606715-69ee-490d-a3ff-61f3890a42b9",
   "metadata": {},
   "source": [
    "You can change the image and rerun the experiments to see how those two approaches vary.\n",
    "\n",
    "\n",
    "Overall, we note that with these three approaches we obtain seemingly similar results. But the following questions remain:\n",
    "\n",
    "* why an explanation method chose this particular zone of the image\n",
    "* how can we state that one explanation method is more representative of the network behaviour than the other\n",
    "* what do we do of the explanations?"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2cb7053-d9f3-4d55-9ed0-4fc3f54b94ad",
   "metadata": {},
   "source": [
    "## Explainable by design: ProtoTree with the CaBRNet library\n",
    "\n",
    "\n",
    "We will now look at another class of interpretability models: interpretable by-design models. We will focus on ProtoTrees. \n",
    "We will study the [ProtoTree](https://arxiv.org/abs/2012.02046) architecture. \n",
    "\n",
    "\n",
    "Some discussion about ProtoTree, namely the parameters we will consider:\n",
    "* tree depth\n",
    "* the effect of pruning\n",
    "\n",
    "\n",
    "![](prototree.png)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6ee56d4-648a-4398-84ef-d16564f0dc28",
   "metadata": {},
   "source": [
    "### Preliminary\n",
    "\n",
    "We downloaded the model and the corresponding generated prototypes. For this session, we also provided pre-made configuration files.\n",
    "First, instanciate the model and the config files."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "240c99cb-d79a-4d15-b89e-0d9fa56c6eae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helper functions to quickly change yaml files\n",
    "\n",
    "def find_dict_with_key(adict,key):\n",
    "    stack = [adict]\n",
    "    while stack:\n",
    "        d = stack.pop()\n",
    "        if not d:\n",
    "            raise ValueError(f\"Key {key} not found in YAML, be sure to check it exists\")\n",
    "        elif key in d:\n",
    "            return d\n",
    "        for k,v in d.items():\n",
    "            if isinstance(v, dict):\n",
    "                stack.append(v)\n",
    "\n",
    "def replace_by(f,k,ov,nv):\n",
    "    with open(f, \"r+\") as read:\n",
    "        y = yaml.safe_load(read)\n",
    "        d = find_dict_with_key(y,k)\n",
    "        if d[k] == ov:\n",
    "            d[k] = nv\n",
    "        with open(f+\".modified.yml\",\"w+\") as write:\n",
    "            yaml.dump(y,write)\n",
    "            return(f+\".modified.yml\")"
   "execution_count": null,
   "id": "de765e35-4708-47da-9a8f-de4d4c77593e",
   "metadata": {},
   "source": [
    "# Instanciation of paths \n",
    "from cabrnet.generic.model import ProtoClassifier\n",
    "root_cabrnet_config=os.path.join(\"models\",\"cabrnet_seti\")\n",
    "root_model=os.path.join(root_cabrnet_config,\"model\")\n",
    "root_protos=os.path.join(root_cabrnet_config,\"prototypes\")\n",
    "root_out=os.path.join(\"outs\")\n",
    "\n",
    "# Configuration files, we remove the spurious default weight\n",
    "path_to_model_config = replace_by(os.path.join(root_model,\"model.yml\"),\"weights\",\"examples/pretrained_conv_extractors/resnet50_inat.pth\",None)\n",
    "\n",
    "path_to_state_dict=os.path.join(root_model,\"model_state.pth\")\n",
    "path_to_pruned_state_dict=os.path.join(root_model,\"pruned_model_state.pth\")\n",
    "model = ProtoClassifier.build_from_config(config_file=path_to_model_config,state_dict_path=path_to_state_dict)\n",
    "\n",
    "img_path =os.path.join(\"data\",\"cub_test_tiny\",\"001.Black_footed_Albatross\",\"Black_Footed_Albatross_0051_796103.jpg\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75f8e4b5-7a84-4039-a7ce-22844c828bd4",
   "metadata": {},
   "source": [
    "We loaded a pretrained ProtoTree using CaBRNet, as well as two configuration files. Let us look at `model.yml`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fb1cf1d-d883-44ba-95af-5a097d7c6667",
   "metadata": {},
   "source": [
    "!cat $path_to_model_config"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47c1808b-ef55-414c-991f-1d4104f01228",
   "metadata": {},
   "source": [
    "This file defines the architecture of a ProtoTree. Consider the _classifier_ section. Among several parameters, we define  `depth`: it is the depth of the soft decision tree used in ProtoTree. The higher this parameter, the deeper the tree will be (and thus higher the number of prototypes). Here, 9 was chosen after cross-validation on this dataset. We will examine the influence of changing the depth on another model.\n",
    "\n",
    "Note that we did not put anything under the \"weights\" section, as we are loading an already pretrained model through the `model_state.pth`."
   ]
  },
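  {
   "cell_type": "markdown",
   "id": "aa11bb22-0009-4abc-9def-0123456789ab",
   "metadata": {},
   "source": [
    "As a quick check of the relation between depth and the number of prototypes: a ProtoTree is a full binary tree in which each internal node carries one prototype, so a tree of depth $d$ has $2^d - 1$ prototypes (and $2^d$ leaves) before pruning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa11bb22-0010-4abc-9def-0123456789ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prototype (internal node) and leaf counts for a full binary tree of depth d\n",
    "for d in [5, 7, 9]:\n",
    "    print(f\"depth={d}: {2**d - 1} prototypes, {2**d} leaves\")"
   ]
  },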
  {
   "cell_type": "markdown",
   "id": "8c4b5e92-fb01-4ebc-8c52-808356d40d0b",
   "metadata": {},
   "source": [
    "#### Evaluate the ProtoTree performance\n",
    "\n",
    "The snippet of code below calls the CaBRNet `evaluate` method on the model to perform a basic inference and collect some stats. This should take less than one minute."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01ca872f-5156-4373-af90-10cd5ed07e54",
   "metadata": {},
    "stats = model.evaluate(dataloader=loader, device='cpu', verbose=True)\n",
    "for name, value in stats.items():\n",
    "    print(f\"{name}: {value:.3f}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b373602f-9743-4975-857a-1f920ce8b1a3",
   "metadata": {},
   "source": [
    "The accuracy should be above $0.98$. For this test set, the ProtoTree has a similar performance compared to a classical model. It brings the additionnal benefit of being interpretable, as we will see now. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "433d32db-a9a3-48a6-bb86-88bdb3adb33d",
   "metadata": {},
   "source": [
    "#### Explain local\n",
    "\n",
    "We will first examine the inference pipeline of a ProtoTree. We will need\n",
    "* a specific image with the same preprocessing used during the ProtoTree's training\n",
    "* a model\n",
    "* a way to visualize the similarity computed at each node\n",
    "\n",
    "We have a pre-configured configuration file visualizer under `path_to_visu_config`.\n",
    "Let us consider the `test_patch` section. \n",
    "\n",
    "We are interested in the `retrace` key. It describes the function that is used to visualize the patch that correspond to the prototype. Among the various options, we have:\n",
    "* `cubic_upsampling` that takes the following parameters:\n",
    "  * `normalize` (bool, default False)\n",
    "  * `single_location` (bool, default True)\n",
    "* `smoothgrad` that takes the following parameters:\n",
    "  * `polarity` (string, default \"absolute\")\n",
    "  * `gaussian_ksize` (int, default 5)\n",
    "  * `normalize` (bool, default False)\n",
    "  * `grad_x_input` (bool, default False)\n",
    "* `randgrad` with the same arguments as `smoothgrads`"
   "execution_count": null,
   "id": "9962fbbc-1837-4841-b092-c02cd2593024",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "path_to_visu_config=os.path.join(root_protos,\"visualization.yml\")\n",
    "!cat $path_to_visu_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f728b98e-ccc7-46cc-8edf-d34be6181147",
   "metadata": {},
   "source": [
    "from cabrnet.generic.model import SimilarityVisualizer\n",
    "!rm -rf $root_out/test_patches # removing existing folder\n",
    "path_to_visu_config = replace_by(os.path.join(root_protos,\"visualization.yml\"), \"type\",\"cubic_upsampling\",\"smoothgrad\")\n",
    "!cat $path_to_visu_config\n",
    "visualizer = SimilarityVisualizer.build_from_config(config_file=path_to_visu_config,target=\"test_patch\")\n",
    "model.explain(prototype_dir_path=root_protos,output_dir_path=root_out,img_path=img_path,preprocess=transform,device=\"cpu\",visualizer=visualizer)\n",
    "imgs = [Image(filename=os.path.join(root_out,\"test_patches\",i)) for i in os.listdir(os.path.join(root_out,\"test_patches\"))]\n",
    "display(*imgs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9987db33-9f1d-483b-b8be-5ef47e63e196",
   "metadata": {},
   "source": [
    "#### Explain global\n",
    "\n",
    "Given extracted prototypes, provide the inference of a ProtoTree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e91c919-9f4c-407b-a2db-733e297abd4f",
   "metadata": {},
   "source": [
    "model = ProtoClassifier.build_from_config(config_file=path_to_model_config,state_dict_path=path_to_state_dict)\n",
    "model.explain_global(prototype_dir_path=root_protos,output_dir_path=root_out)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "143ac5a9-e328-45b3-8c70-50ecef8794c8",
   "metadata": {},
   "source": [
    "IFrame(os.path.join(root_out,\"global_explanation.pdf\"), width=800, height=200)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1d360db4-17dc-422e-af45-3f12334227d3",
   "metadata": {},
   "source": [
    "### On the effect of pruning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c166b54-591b-4132-9887-95500c9904a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: \n",
    "# * load the model with no pruning\n",
    "# * use protolib.model.prune() with several threshold\n",
    "# * provide global explanation with several thresholds "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7236bda-0073-45ec-bed9-437b1fa2b19f",
   "metadata": {},
   "source": [
    "\n",
    "<div style=\"color:red\"> TODO: by-design models are a bit more cumbersome to train and use, but they provide an easier to grasp decision process </div>\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba630aea-1636-48c4-85ec-3a652b0aed2c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "setixaitp",
Julien Girard-Satabin's avatar
Julien Girard-Satabin committed
   "language": "python",
Julien Girard-Satabin's avatar
Julien Girard-Satabin committed
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
Julien Girard-Satabin's avatar
Julien Girard-Satabin committed
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}