{
"cells": [
{
"cell_type": "markdown",
"id": "78de9b8a-9382-4955-bd6c-e4067a2d160f",
"metadata": {},
"source": [
"# TP Explainable AI at SET\n",
"\n",
"\n",
"This tutorial aims to provide an overview on the most popular techniques of explainable AI (xAI). As we saw during the presentation, we can broadly divide those techniques in two kinds:\n",
"1. _Post-hoc_ explanation methods, that are used to analyze existing models\n",
"2. _by-design_ explainable models, programs that embed explanations into their decision process\n",
"For the _post-hoc_ methods, we will use the [Captum](https://captum.ai/) library. Part of this tutorial is adapted from the CAPTUM [original tutorial on CIFAR10](https://captum.ai/tutorials/CIFAR_TorchVision_Interpret).\n",
"For the by-design model, we will use the [CaBRNet](https://git.frama-c.com/pub/cabrnet) library, developped at CEA."
]
},
{
"cell_type": "markdown",
"id": "9e029cea-12d7-4a54-9e68-7a9f2df14521",
"metadata": {},
"source": [
"## Preliminaries\n",
"\n",
"### Environment setup\n",
"\n",
"Install all dependencies in a dedicated virtual environment. A `setup.sh` script is provided at the root of the session repository. This section ensures that the downloaded packages are correctly setup, and that the pretrained models behave as expected.\n"
"id": "39d401d6-eb75-4147-be9d-1ec863b3d857",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import os\n",
"import ipyplot\n",
"from IPython.core.display import SVG\n",
"\n",
"%matplotlib inline\n",
"\n",
"import captum\n",
"from captum.attr import visualization as viz\n",
"\n",
"import torchvision\n",
"import torchvision.transforms.v2 as transforms\n",
"\n",
"import cabrnet \n",
"from zenodo_get import zenodo_get\n",
"from IPython.display import IFrame, Image, display"
]
},
{
"cell_type": "markdown",
"id": "975e7c31-ad96-431c-a4c9-f9ea097f3f27",
"metadata": {},
"source": [
"We will use for this session a reduced image set of the dataset [CUB200](http://www.vision.caltech.edu/datasets/cub_200_2011/). This is to avoid unecessary training time and inference."
]
},
{
"cell_type": "code",
"id": "38e1ad69-ee58-4bb3-b913-c5edc13ea8a9",
"metadata": {},
"outputs": [],
"source": [
"transform = transforms.Compose(\n",
" [transforms.ToImage(), transforms.ToDtype(torch.float32, scale=True),\n",
" transforms.Resize((224,224),antialias=True),\n",
" transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),\n",
" ])\n",
"tinyCub = torchvision.datasets.ImageFolder(root=\"./data/cub_train_tiny\", transform=transform)"
]
},
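{
"cell_type": "markdown",
"id": "0f74a2e1-5c1d-4b6e-9a0d-aa7400000001",
"metadata": {},
"source": [
"As a quick optional check (a minimal sketch), we can inspect a single preprocessed sample: `ImageFolder` returns `(image, label)` pairs, and after the transform each image should be a normalized float tensor of shape `3x224x224`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0f74a2e1-5c1d-4b6e-9a0d-aa7400000002",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sanity check on the preprocessed dataset\n",
"img, label = tinyCub[0]\n",
"print(img.shape, img.dtype)    # expected: torch.Size([3, 224, 224]) torch.float32\n",
"print(tinyCub.classes[label])  # folder name of the ground-truth class"
]
},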
{
"cell_type": "markdown",
"id": "746d2d8b-dfd5-498b-94cc-30ff03d21c76",
"metadata": {},
"source": [
" We will also load a pretrained model on Cub200 (a ResNet 50) for Post-Hoc explanations."
"id": "3c81c3fd-7e10-43f4-8ef8-02759d5329db",
"metadata": {},
"outputs": [],
"source": [
"modelPostHoc = torch.load('./models/r50_CUB200_i448.pth',map_location='cpu')"
]
},
{
"cell_type": "markdown",
"id": "f92f76da-41cc-4f13-a98c-f72c1e033672",
"metadata": {},
"source": [
"### Sanity checks\n",
"\n",
"We will begin by loading some images from the dataset, pass them through the model and see if the predictions are correct. "
"execution_count": null,
"id": "66f33348-848b-4d83-a539-03abebc7f786",
"metadata": {},
"outputs": [],
" img = img / 4.3 + 0.4 # hackish unnormalization\n",
" npimg = img.numpy()\n",
" plt.imshow(np.transpose(npimg, (1, 2, 0)))\n",
" plt.show()\n",
"\n",
"loader = torch.utils.data.DataLoader(tinyCub, batch_size=16)\n",
"classes = list(map(lambda x: x.split(\".\")[1], tinyCub.classes))\n",
"[imgs, targets] = next(iter(loader))\n",
"res = modelPostHoc(imgs)\n",
"imshow(torchvision.utils.make_grid(imgs,nrow=4))\n",
"print(f\"Ground truth predictions: {' ' .join('%2s' % targets[j].item()+ ' ' + classes[targets[j]] for j in range(5) )}\")\n",
"print(f\"Predicted classes: {' ' .join('%2s' % predicted[j].item()+ ' ' + classes[predicted[j]] for j in range(5) )}\")"
]
},
{
"cell_type": "markdown",
"id": "bbe2d0aa-aea6-43b2-be32-3cfe8064dacb",
"metadata": {},
"source": [
"Finally, we will compute the average precision on the dataset. We should have an accuracy of about 61%."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "69df3f34-67cd-4c7e-8645-1d5542208998",
"metadata": {},
"outputs": [],
"source": [
"acc = 0\n",
"for idx, (img, target) in enumerate(loader):\n",
" _, predicted = torch.max(modelPostHoc(img), 1)\n",
" batch_acc = (torch.sum((predicted==target))*True).item()/16\n",
" acc += batch_acc\n",
"print(f\"Accuracy: {acc/idx*100:.2f}%\")"
"id": "426864ec-c9c2-421f-bced-9cf60fefae52",
"metadata": {},
"source": [
"## Post Hoc explanation methods\n",
"\n",
"All the following approaches aim to do _feature attribution_: given a sample $x$ with features $x^i$, a program $f$ and a prediction $y$, the aim is to answer the following question: \"which $x^i$ contributed the most to $f(x)=y$? \n",
"\n",
"\n",
"\n"
]
},
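{
"cell_type": "markdown",
"id": "1a45b3c2-6d7e-4f8a-9b0c-145000000001",
"metadata": {},
"source": [
"To make this concrete, consider a toy linear model $f(x) = w \\cdot x$ (a minimal sketch, unrelated to the bird classifier): the gradient of the output with respect to each feature $x^i$ is simply the corresponding weight $w_i$, so the features multiplied by large weights receive large attributions."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a45b3c2-6d7e-4f8a-9b0c-145000000002",
"metadata": {},
"outputs": [],
"source": [
"# Toy illustration of gradient-based feature attribution on a linear model\n",
"w = torch.tensor([3.0, 0.0, -1.0])                      # f(x) = w . x\n",
"x = torch.tensor([1.0, 5.0, 2.0], requires_grad=True)\n",
"y = torch.dot(w, x)\n",
"y.backward()\n",
"print(x.grad)  # tensor([ 3.,  0., -1.]): each feature's attribution is its weight"
]
},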
{
"cell_type": "markdown",
"id": "f27de05e-769d-48b6-9f5d-9b97b37ba170",
"metadata": {},
"source": [
"We will first apply the simplest attribution method: [backpropagating the gradient](https://arxiv.org/abs/1312.6034) of $y$ on the chosen sample $x$:\n",
"$\\frac{\\partial{f(x)}}{\\partial{x}}$\n",
"It is done automatically with most of modern deep learning frameworks.\n",
"Note that you can change the `sign` parameter to `\"all\"` to see the sign variations for all following methods. "
"execution_count": null,
"id": "e3dfa9fa-7e75-42bc-ad52-c534f91f93b9",
"source": [
"input = imgs[0].unsqueeze(0)\n",
"target = targets[0]\n",
"input.require_grads = True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f2afb95e-6bea-411f-865c-d83fbe1a32f1",
"metadata": {},
"outputs": [],
"source": [
"original_image = np.transpose((imgs[0].cpu().detach().numpy() / 4.3) + 0.4, (1, 2, 0))\n",
"saliency = captum.attr.Saliency(modelPostHoc)\n",
"grads = saliency.attribute(inputs=input, target=0, abs=False)\n",
"grads = np.transpose(grads.squeeze(0).cpu().detach().numpy(), (1, 2, 0))\n",
"_ = viz.visualize_image_attr(None, original_image, \n",
" method=\"original_image\", title=\"Original Image\")\n",
"_ = viz.visualize_image_attr(grads, original_image, method=\"blended_heat_map\", sign=\"absolute_value\", \n",
" outlier_perc=5, show_colorbar=True, \n",
" title=\"Overlayed Gradient Magnitudes\")"
]
},
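{
"cell_type": "markdown",
"id": "2b80c4d3-7e8f-4a9b-8c1d-180000000001",
"metadata": {},
"source": [
"For reference, the raw-gradient saliency computed above with Captum can also be obtained directly with autograd. The sketch below (reusing `input` and `modelPostHoc`) should closely match the output of `saliency.attribute(..., abs=False)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2b80c4d3-7e8f-4a9b-8c1d-180000000002",
"metadata": {},
"outputs": [],
"source": [
"# Manual raw-gradient saliency (sketch): gradient of the class-0 logit w.r.t. the input\n",
"x = input.clone().detach().requires_grad_(True)\n",
"score = modelPostHoc(x)[0, 0]  # logit of class 0, as in the Captum call above\n",
"score.backward()\n",
"manual_grads = x.grad.squeeze(0)\n",
"print(manual_grads.shape)      # same shape as the input image"
]
},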
{
"cell_type": "markdown",
"id": "bc186f4e-7725-4b98-abfa-cee5eed947f9",
"metadata": {},
"source": [
"We see that the gradients focus a lot on the neck and the tail, but also on the top corners of the image. Altough it may describe how the neural network take its decision, it may not match the human decision process to classify a duck."
]
},
{
"cell_type": "markdown",
"id": "6d0bffdf-27c7-4930-8ac7-c94cdc1e0c3f",
"metadata": {},
"source": [
"### Saliency maps with SmoothGrads\n",
"Given $x$, [SmoothGrads](https://arxiv.org/abs/1706.03825) aims to compute an average of the gradients in a neighborhood $x^{*}$ to reduce the influence of sharp, local variations. An approximation of this averaged gradient can be computed by the following:\n",
"$$\n",
"\\nabla_{x^{*}}y \\approx \\frac{1}{n}\\sum_0^{n}\\nabla_xf(x+\\mathcal{N}(0,\\sigma))\n",
"$$\n",
"There are two parameters here:\n",
"1. $\\sigma$: the standard deviation of the gaussian sampling\n",
"2. $n$: the number of samples computed by smoothgrad \n",
"Experiment by changing those parameters and calling the `attribute` method (it may take long if you increase the number of samples: start by increments of 5).\n"
]
},
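{
"cell_type": "markdown",
"id": "3c23d5e4-8f9a-4b0c-9d2e-232000000001",
"metadata": {},
"source": [
"Before using Captum's `NoiseTunnel`, the averaging formula above can be sketched by hand: draw $n$ noisy copies of the input, backpropagate each, and average the gradients. The following is only an illustration with a small $n$; the next cells rely on the Captum implementation."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3c23d5e4-8f9a-4b0c-9d2e-232000000002",
"metadata": {},
"outputs": [],
"source": [
"# Hand-rolled SmoothGrad sketch illustrating the averaging formula above\n",
"def smoothgrad_manual(model, x, target, n=5, sigma=0.1):\n",
"    grads = torch.zeros_like(x)\n",
"    for _ in range(n):\n",
"        noisy = (x + sigma * torch.randn_like(x)).requires_grad_(True)\n",
"        score = model(noisy)[0, target]\n",
"        score.backward()\n",
"        grads += noisy.grad\n",
"    return grads / n\n",
"\n",
"avg_grads = smoothgrad_manual(modelPostHoc, input.detach(), target=0)\n",
"print(avg_grads.abs().mean())  # average gradient magnitude over the neighborhood"
]
},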
{
"cell_type": "code",
"execution_count": null,
"id": "3b0462e3-bb46-49bf-aced-20569e63468b",
"metadata": {},
"outputs": [],
"source": [
"n_samples = 50\n",
"sigma = 0.1"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c271261-78ee-4e1e-97b1-c43c0500e085",
"metadata": {},
"outputs": [],
"source": [
"saliency = captum.attr.Saliency(modelPostHoc)\n",
"nt = captum.attr.NoiseTunnel(saliency)\n",
"attrs = nt.attribute(inputs=input, target=0, nt_type='smoothgrad_sq', nt_samples=n_samples, stdevs=sigma)\n",
"attrs= np.transpose(attrs.squeeze(0).cpu().detach().numpy(), (1, 2, 0))\n",
"_ = viz.visualize_image_attr(attrs, original_image, method=\"blended_heat_map\", sign=\"absolute_value\", \n",
" outlier_perc=10, show_colorbar=True, \n",
" title=\"Overlayed Gradient Magnitudes \\n with SmoothGrad Squared\")"
"id": "f504bb5e-6a3a-4a64-aec8-fe62b38e16aa",
"With averaged gradients, the interpretation seems much less noisy. With a sufficiently high number of samples and a low standard deviation, the gradient seems to vary a lot around the neck to the tail, with some specks on the corner of the image and the beak."
]
},
{
"cell_type": "markdown",
"id": "121f7611-0f58-4b9c-b895-f8ca7142944a",
"metadata": {},
"source": [
"### Integrated gradients\n",
"\n",
"The previous approaches have limitations. Namely, they exist some situations where the gradient of different values is the same. \n",
"To tackle this issue, [integrated Gradients](https://arxiv.org/abs/1703.01365) computes a linear approximation of the gradient on the line between an baseline image $x^{'}$ and the image $x$.\n",
"\n",
"$$\n",
"IG_i = (x_i - x^{'}_i) \\int_{\\alpha=0}^{1} \\nabla_{x_i}\n",
"f(x^{'}+\\alpha(x-x^{'}))d\\alpha\n",
"$$"
"execution_count": null,
"id": "23b17d92-3630-454f-accd-6dd33435faa1",
"metadata": {},
"source": [
"ig = captum.attr.IntegratedGradients(modelPostHoc)\n",
"attributions, delta = ig.attribute(inputs=input, baselines=input*0, target=0, return_convergence_delta=True)\n",
"attributions = np.transpose(attributions.squeeze().cpu().detach().numpy(), (1, 2, 0))\n",
"print('Approximation delta: ', abs(delta))\n",
"_ = viz.visualize_image_attr(attributions, original_image, method=\"blended_heat_map\",sign=\"absolute_value\",\n",
" show_colorbar=True, title=\"Overlayed Integrated Gradients\")"
]
},
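{
"cell_type": "markdown",
"id": "4d87e6f5-9a0b-4c1d-8e3f-287000000001",
"metadata": {},
"source": [
"In practice the integral above is approximated by a Riemann sum over interpolation steps between the baseline and the input. The following hand-written sketch uses the same zero baseline as the Captum call above; it is only meant to illustrate the formula."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4d87e6f5-9a0b-4c1d-8e3f-287000000002",
"metadata": {},
"outputs": [],
"source": [
"# Riemann-sum sketch of Integrated Gradients with a zero baseline\n",
"def integrated_gradients_manual(model, x, target, steps=10):\n",
"    baseline = torch.zeros_like(x)\n",
"    total_grads = torch.zeros_like(x)\n",
"    for k in range(1, steps + 1):\n",
"        point = (baseline + (k / steps) * (x - baseline)).requires_grad_(True)\n",
"        score = model(point)[0, target]\n",
"        score.backward()\n",
"        total_grads += point.grad\n",
"    return (x - baseline) * total_grads / steps\n",
"\n",
"ig_manual = integrated_gradients_manual(modelPostHoc, input.detach(), target=0)\n",
"print(ig_manual.abs().sum())  # total attribution mass (roughly comparable to Captum's result)"
]
},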
{
"cell_type": "markdown",
"id": "1c3dac7c-cfe8-4a36-9ff1-71757ae99cdc",
"metadata": {},
"source": [
"The Integrated Gradients display how much variations (in term of gradient) exist between a white image and the actual image.\n",
"We will now combine Integrated Gradients with SmoothGrads."
"execution_count": null,
"id": "3011d1d3-8fbd-4a22-b6df-4fafdbc9f563",
"metadata": {},
"outputs": [],
"source": [
"n_samples = 50\n",
"sigma = 0.1"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bef64173-1077-4ded-8871-055ea14b628f",
"metadata": {},
"source": [
"ig = captum.attr.IntegratedGradients(modelPostHoc)\n",
"nt = captum.attr.NoiseTunnel(ig)\n",
"attributions_smoothgrad = nt.attribute(inputs=input, baselines=input * 0, target=1, nt_type='smoothgrad_sq', nt_samples=n_samples, stdevs=sigma)\n",
"attributions_smoothgrad = np.transpose(attributions_smoothgrad.squeeze(0).cpu().detach().numpy(), (1, 2, 0))\n",
"_ = viz.visualize_image_attr(attributions_smoothgrad, original_image, method=\"blended_heat_map\", sign=\"absolute_value\", \n",
" outlier_perc=10, show_colorbar=True, \n",
" title=\"Overlayed Integrated Gradients \\n with SmoothGrad Squared\")"
]
},
{
"cell_type": "markdown",
"id": "a01eae6b-e595-4d4d-bb2c-10c861f8a6fe",
"metadata": {},
"source": [
"We note that integrated gradients with smoothgrads provide much more focused variations.\n",
"\n",
"Overall, we note that with these three approaches we obtain seemingly similar results. But the following questions remain:\n",
"\n",
"* why an explanation method chose this particular zone of the image\n",
"* how can we state that one explanation method is more representative of the network behaviour than the other\n",
"* what do we do of the explanations?"
]
},
{
"cell_type": "markdown",
"id": "c2cb7053-d9f3-4d55-9ed0-4fc3f54b94ad",
"metadata": {},
"source": [
"## Explainable by design: ProtoTree with the CaBRNet library\n",
"\n",
"\n",
"We will now look at another class of interpretability models: interpretable by-design models. We will focus on ProtoTrees. \n",
"We will study the [ProtoTree](https://arxiv.org/abs/2012.02046) architecture. \n",
"\n",
"\n",
"Some discussion about ProtoTree, namely the parameters we will consider:\n",
"* tree depth\n",
"* the effect of pruning\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "f6ee56d4-648a-4398-84ef-d16564f0dc28",
"metadata": {},
"source": [
"### Preliminary\n",
"\n",
"We downloaded the model and the corresponding generated prototypes. For this session, we also provided pre-made configuration files.\n",
"First, instanciate the model and the config files."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "de765e35-4708-47da-9a8f-de4d4c77593e",
"metadata": {},
"source": [
"# Instanciation of paths \n",
"from cabrnet.generic.model import ProtoClassifier\n",
"root_cabrnet_config=os.path.join(\"models\",\"cabrnet\",\"cabrnet\")\n",
"root_model=os.path.join(root_cabrnet_config,\"model\")\n",
"root_protos=os.path.join(root_cabrnet_config,\"prototypes\")\n",
"root_out=os.path.join(\"outs\")\n",
"\n",
"# Configuration files, we change one faulty line\n",
"path_to_model_config=os.path.join(root_model,\"model.yml\")\n",
"path_to_visu_config=os.path.join(root_protos,\"visualization.yml\")\n",
"\n",
"path_to_state_dict=os.path.join(root_model,\"model_state.pth\")\n",
"\n",
"img_path =os.path.join(\"data\",\"cub_train_tiny\",\"001.Black_footed_Albatross\",\"Black_Footed_Albatross_0051_796103.jpg\")\n",
"\n",
"model = ProtoClassifier.build_from_config(config_file=path_to_model_config,state_dict_path=path_to_state_dict)"
]
},
{
"cell_type": "markdown",
"id": "75f8e4b5-7a84-4039-a7ce-22844c828bd4",
"metadata": {},
"source": [
"We loaded a pretrained ProtoTree using CaBRNet, as well as two configuration files. Let us look at `model.yml`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9fb1cf1d-d883-44ba-95af-5a097d7c6667",
"metadata": {},
"source": [
"!cat $path_to_model_config"
]
},
{
"cell_type": "markdown",
"id": "47c1808b-ef55-414c-991f-1d4104f01228",
"metadata": {},
"source": [
"This file defines the architecture of a ProtoTree. Consider the _classifier_ section. Among several parameters, we define `depth`: it is the depth of the soft decision tree used in ProtoTree. The higher this parameter, the deeper the tree will be (and thus higher the number of prototypes). Here, 9 was chosen after cross-validation on this dataset. We will examine the influence of changing the depth on another model.\n",
"Note that we did not put anything under the \"weights\" section, as we are loading an already pretrained model through the `model_state.pth`."
]
},
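{
"cell_type": "markdown",
"id": "5e13f7a6-0b1c-4d2e-9f4a-413000000001",
"metadata": {},
"source": [
"The configuration can also be inspected programmatically. The sketch below assumes the YAML layout shown by the `cat` command above, with a `classifier` section containing a `depth` entry; adapt the keys if the printed file differs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5e13f7a6-0b1c-4d2e-9f4a-413000000002",
"metadata": {},
"outputs": [],
"source": [
"import yaml  # PyYAML\n",
"\n",
"# Inspect the ProtoTree configuration (sketch: keys assumed from the file printed above)\n",
"with open(path_to_model_config) as f:\n",
"    model_config = yaml.safe_load(f)\n",
"print(model_config.get(\"classifier\", {}).get(\"depth\", \"no depth entry found\"))"
]
},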
{
"cell_type": "markdown",
"id": "8c4b5e92-fb01-4ebc-8c52-808356d40d0b",
"metadata": {},
"source": [
"#### Evaluate the ProtoTree performance\n",
"\n",
"The snippet of code below calls the CaBRNet `evaluate` method on the model to perform a basic inference and collect some stats. This should take less than one minute."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "01ca872f-5156-4373-af90-10cd5ed07e54",
"metadata": {},
"source": [
"stats = model.evaluate(dataloader=loader, device='cpu', progress_bar_position=0)\n",
"for name, value in stats.items():\n",
" print(f\"{name}: {value:.3f}\")\n"
"id": "b373602f-9743-4975-857a-1f920ce8b1a3",
"metadata": {},
"source": [
"The accuracy should be above $0.98$. For this test set, the ProtoTree has a similar performance compared to a classical model. It brings the additionnal benefit of being interpretable, as we will see now. "
]
},
{
"cell_type": "markdown",
"id": "433d32db-a9a3-48a6-bb86-88bdb3adb33d",
"metadata": {},
"source": [
"#### Explain local\n",
"\n",
"We will first examine the inference pipeline of a ProtoTree. We will need\n",
"* a specific image with the same preprocessing used during the ProtoTree's training\n",
"* a model\n",
"* a way to visualize the similarity computed at each node\n",
"\n",
"We have a pre-configured configuration file visualizer under `path_to_visu_config`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9962fbbc-1837-4841-b092-c02cd2593024",
"metadata": {
"scrolled": true
},
"!cat $path_to_visu_config"
{
"cell_type": "markdown",
"id": "19c12bec-fa0a-47b0-90be-4afdeaf27b97",
"metadata": {},
"source": [
"<div style=\"color:red\"> TODO: explain briefly all parameters and provide a configuration function to change the test_patch viz </div>\n"
]
},
"execution_count": null,
"id": "f728b98e-ccc7-46cc-8edf-d34be6181147",
"metadata": {},
"source": [
"from cabrnet.generic.model import SimilarityVisualizer\n",
"!rm -rf $root_out/test_patches # removing existing folder\n",
"visualizer = SimilarityVisualizer.build_from_config(config_file=path_to_visu_config,target=\"test_patch\")\n",
"model.explain(prototype_dir_path=root_protos,output_dir_path=root_out,img_path=img_path,preprocess=transform,device=\"cpu\",visualizer=visualizer)\n",
"imgs = [Image(filename=os.path.join(root_out,\"test_patches\",i)) for i in os.listdir(os.path.join(root_out,\"test_patches\"))]\n",
"display(*imgs)"
]
},
{
"cell_type": "markdown",
"id": "9987db33-9f1d-483b-b8be-5ef47e63e196",
"metadata": {},
"source": [
"#### Explain global\n",
"\n",
"Given extracted prototypes, provide the inference of a ProtoTree"
]
},
{
"cell_type": "code",
"id": "5e91c919-9f4c-407b-a2db-733e297abd4f",
"metadata": {},
"source": [
"model = ProtoClassifier.build_from_config(config_file=path_to_model_config,state_dict_path=path_to_state_dict)\n",
"model.explain_global(prototype_dir_path=root_protos,output_dir_path=root_out)\n"
]
},
{
"cell_type": "code",
"id": "143ac5a9-e328-45b3-8c70-50ecef8794c8",
"metadata": {},
"source": [
"IFrame(os.path.join(root_out,\"global_explanation.pdf\"), width=800, height=200)"
]
},
{
"cell_type": "markdown",
"id": "1d360db4-17dc-422e-af45-3f12334227d3",
"metadata": {},
"source": [
"### On the effect of pruning"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6c166b54-591b-4132-9887-95500c9904a1",
"metadata": {},
"outputs": [],
"source": [
"# TODO: \n",
"# * load the model with no pruning\n",
"# * use protolib.model.prune() with several threshold\n",
"# * provide global explanation with several thresholds "
]
},
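{
"cell_type": "markdown",
"id": "6f29a8b7-1c2d-4e3f-8a5b-529000000001",
"metadata": {},
"source": [
"As a starting point for this exercise, here is a possible skeleton. The path to the unpruned checkpoint is not provided here, and the name and signature of the pruning method (written below as `prune(pruning_threshold=...)`) are assumptions to be adapted to the actual CaBRNet API."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f29a8b7-1c2d-4e3f-8a5b-529000000002",
"metadata": {},
"outputs": [],
"source": [
"# Possible skeleton (sketch): the unpruned checkpoint path and the pruning call are assumptions\n",
"path_to_unpruned_state_dict = \"...\"  # TODO: set to the unpruned model checkpoint\n",
"for threshold in [0.0, 0.01, 0.05]:\n",
"    unpruned = ProtoClassifier.build_from_config(\n",
"        config_file=path_to_model_config, state_dict_path=path_to_unpruned_state_dict)\n",
"    unpruned.prune(pruning_threshold=threshold)  # hypothetical method name/signature\n",
"    unpruned.explain_global(prototype_dir_path=root_protos,\n",
"                            output_dir_path=os.path.join(root_out, f\"pruned_{threshold}\"))"
]
},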
{
"cell_type": "markdown",
"id": "c7236bda-0073-45ec-bed9-437b1fa2b19f",
"metadata": {},
"source": [
"\n",
"<div style=\"color:red\"> TODO: by-design models are a bit more cumbersome to train and use, but they provide an easier to grasp decision process </div>\n"
]
},
{
"cell_type": "markdown",
"id": "22304e49-a9d8-46aa-97c0-2853541a44b1",
"metadata": {},
"source": [
"### "
]
},
{
"cell_type": "markdown",
"id": "59aa9578-1d35-485c-9a76-4238d6f508cc",
"metadata": {},
"source": [
"### "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ba630aea-1636-48c4-85ec-3a652b0aed2c",
"metadata": {},
"outputs": [],
"source": []
"display_name": "setixaitp",
"name": "setixaitp"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}