Merge pull request #22 from understandable-machine-intelligence-lab/fixes

Fixes
annahedstroem authored Oct 12, 2021
2 parents 1a1a64e + 0a28aa6 commit a4d9b07
Showing 12 changed files with 118 additions and 3,518 deletions.
125 changes: 84 additions & 41 deletions README.md
@@ -5,7 +5,7 @@
</p>

--------------
<!--<img src="quantus.png" alt="drawing" width="200"/>-->
<!--<img src="quantus.png" alt="drawing" width="600"/>-->

<!--**A library that helps you understand your XAI explanations..**-->
<!--
@@ -44,7 +44,7 @@ The library contains implementations of the following evaluation metrics:
* **[Top-K Intersection](https://arxiv.org/abs/2104.14995) (Theiner et al., 2021)**: computes the intersection between a ground truth mask and the binarized explanation at the top k feature locations
* **[Relevance Rank Accuracy](https://arxiv.org/abs/2003.07258) (Arras et al., 2021)**: measures the ratio of highly attributed pixels that fall within the ground-truth mask to the size of the ground-truth mask
* **[Relevance Mass Accuracy](https://arxiv.org/abs/2003.07258) (Arras et al., 2021)**: measures the ratio of the positive attribution mass inside the ground-truth mask to the overall positive attribution mass
* **[AUC](https://doi.org/10.1016/j.patrec.2005.10.010) (Arras et al., 2021)**: compares the ranking between attributions and a given ground-truth mask
* **[AUC](https://doi.org/10.1016/j.patrec.2005.10.010) (Fawcett, 2006)**: compares the ranking between attributions and a given ground-truth mask
* *Complexity:*
* **[Sparseness](https://arxiv.org/abs/1810.06583) (Chalasani et al., 2020)**: uses the Gini index to measure whether only highly attributed features are truly predictive of the model output
* **[Complexity](https://arxiv.org/abs/2005.00631) (Bhatt et al., 2020)**: computes the entropy of each feature's fractional contribution to the total magnitude of the attribution
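
To make the last two criteria concrete, here is a rough numpy sketch of the underlying quantities (a Gini index and an entropy over absolute attribution values). This is an illustrative toy calculation only, not the library's implementation, and the function names are ours.

```python
import numpy as np

def gini_sparseness(a):
    """Toy Gini index over absolute attribution values (higher = sparser)."""
    a = np.sort(np.abs(a).flatten()) + 1e-12  # sort ascending, avoid division by zero
    n = a.size
    idx = np.arange(1, n + 1)
    return np.sum((2 * idx - n - 1) * a) / (n * np.sum(a))

def entropy_complexity(a):
    """Toy entropy of each feature's fractional contribution (lower = simpler)."""
    p = np.abs(a).flatten()
    p = p / (p.sum() + 1e-12)
    return -np.sum(p * np.log(p + 1e-12))

a = np.random.uniform(size=(28, 28))  # a dummy attribution map
print(gini_sparseness(a), entropy_complexity(a))
```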
@@ -94,57 +94,95 @@ import quantus
import torch
import torchvision

# Load a pre-trained classification model.
model = torchvision.models.resnet18(pretrained=True)
model.eval()
# Load a pre-trained LeNet classification model (architecture at quantus/helpers/models).
from quantus.helpers.models import LeNet
model = LeNet()
model.load_state_dict(torch.load("tutorials/assets/mnist"))

# Load datasets and make loaders.
test_set = torchvision.datasets.Caltech256(root='./sample_data',
download=True,
transform=torchvision.transforms.Compose([torchvision.transforms.Resize(256),
torchvision.transforms.CenterCrop((224, 224)),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]))
test_loader = torch.utils.data.DataLoader(test_set, batch_size=12)
test_set = torchvision.datasets.MNIST(root='./sample_data', download=True,
                                      transform=torchvision.transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_set, batch_size=24)

# Load a batch of inputs and outputs to use for evaluation.
x_batch, y_batch = next(iter(test_loader))
x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy()

# Enable GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
```

Next, we generate some explanations for the test set samples that we wish to evaluate using the `Quantus` library.
Next, we generate some explanations for the test set samples that we wish to evaluate using the quantus library.

```python
import captum
from captum.attr import Saliency, IntegratedGradients
import numpy as np

a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1)
a_batch_intgrad = IntegratedGradients(model).attribute(inputs=x_batch, target=y_batch, baselines=torch.zeros_like(inputs)).sum(axis=1)
# Generate Saliency and Integrated Gradients attributions for the first batch of the test set.
a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1).cpu().numpy()
a_batch_intgrad = IntegratedGradients(model).attribute(inputs=x_batch, target=y_batch, baselines=torch.zeros_like(x_batch)).sum(axis=1).cpu().numpy()

# Save x_batch and y_batch as numpy arrays that will be used to call metric instances.
x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy()

# Quick assert.
assert all(isinstance(obj, np.ndarray) for obj in [x_batch, y_batch, a_batch_saliency, a_batch_intgrad])

# You can use any function e.g., quantus.explain (not necessarily captum) to generate your explanations.
```
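
As the comment above notes, the attributions do not have to come from Captum. A minimal sketch of using the built-in `quantus.explain` instead is shown below; the positional arguments and keywords used here are assumptions based on how `explain_func` is parameterised in the metric calls further down, so consult `help(quantus.explain)` for the exact signature.

```python
# Illustrative only: argument names are assumed, check help(quantus.explain) before use.
a_batch_saliency_q = quantus.explain(model, x_batch, y_batch,
                                     **{"method": "Saliency", "device": device,
                                        "img_size": 28, "normalise": False, "abs": False})
```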
To evaluate explanations, there are two options.
<p align="center">
<img src="tutorials/assets/mnist_example.png" alt="drawing" width="450"/>
</p>

1) Either evaluate your explanations in a one-liner - by calling the instance of the metric class.
On their own, the Saliency and Integrated Gradients explanations may look fairly uninterpretable: since we lack a ground truth for what the explanations should look like, it is hard to draw conclusions about the explanatory evidence that we see.

````python
To evaluate the explanations quantitatively, we can apply Quantus. As a starting point, we may be interested in measuring how sensitive the explanations are to very slight perturbations of the input. For this, we apply the max-sensitivity metric (Yeh et al., 2019) to evaluate our explanations. With Quantus, there are two options.

metric_sensitivity = quantus.MaxSensitivity()
scores = metric_sensitivity(model=model,
                            x_batch=x_batch,
                            y_batch=y_batch,
                            a_batch=a_batch_saliency,
                            **{"explain_func": quantus.explain, "device": device, "img_size": 224, "normalize": True})
````
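
For intuition, the following is a rough numpy sketch of the idea behind max-sensitivity: sample a few perturbations within a small radius around the input, re-explain, and record the largest change in the explanation. The helper names (`explain_fn`, `radius`, `nr_samples`) are ours and this is not the library's implementation.

```python
import numpy as np

def max_sensitivity(explain_fn, x, nr_samples=10, radius=0.1):
    """Toy max-sensitivity: largest (2-norm) change in the explanation under
    uniform input perturbations of magnitude at most `radius` (illustrative only)."""
    a = explain_fn(x)
    worst = 0.0
    for _ in range(nr_samples):
        x_perturbed = x + np.random.uniform(-radius, radius, size=x.shape)
        a_perturbed = explain_fn(x_perturbed)
        worst = max(worst, np.linalg.norm(a_perturbed - a))
    return worst
```

A robust explanation method keeps this value small.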
1) Either evaluate the explanations in a one-liner - by calling the instance of the metric class.

```python
# Return max-sensitivity scores in a one-liner - by calling the metric instance.
scores_saliency = quantus.MaxSensitivity(**{
    "nr_samples": 10,
    "perturb_radius": 0.1,
    "norm_numerator": quantus.fro_norm,
    "norm_denominator": quantus.fro_norm,
    "perturb_func": quantus.uniform_sampling,
    "similarity_func": quantus.difference,
})(model=model,
   x_batch=x_batch,
   y_batch=y_batch,
   a_batch=a_batch_saliency,
   **{"explain_func": quantus.explain, "method": "Saliency", "device": device,
      "img_size": 28, "normalise": False, "abs": False})
```

We also score the Integrated Gradients explanations.
```python
scores_intgrad = quantus.MaxSensitivity(**{
    "nr_samples": 10,
    "perturb_radius": 0.1,
    "norm_numerator": quantus.fro_norm,
    "norm_denominator": quantus.fro_norm,
    "perturb_func": quantus.uniform_sampling,
    "similarity_func": quantus.difference,
})(model=model,
   x_batch=x_batch,
   y_batch=y_batch,
   a_batch=a_batch_intgrad,
   **{"explain_func": quantus.explain, "method": "IntegratedGradients",
      "device": device, "img_size": 28, "normalise": False, "abs": False})
```

2) Or use `quantus.evaluate()`, a high-level function that allows you to evaluate multiple XAI methods on several metrics at once.

```python
import numpy as np
import pandas as pd

metrics = {"Faithfulness correlation": quantus.FaithfulnessCorrelation(**{"subset_size": 32}),
"max-Sensitivity": quantus.MaxSensitivity()}
metrics = {"max-Sensitivity": quantus.MaxSensitivity(**{"nr_samples": 10,
"perturb_radius": 0.1,
"norm_numerator": quantus.fro_norm,
"norm_denominator": quantus.fro_norm,
"perturb_func": quantus.uniform_sampling,
"similarity_func": quantus.difference})}

xai_methods = {"Saliency": a_batch_saliency,
"IntegratedGradients": a_batch_intgrad}
@@ -155,14 +193,28 @@ results = quantus.evaluate(evaluation_metrics=metrics,
x_batch=x_batch,
y_batch=y_batch,
agg_func=np.mean,
**{"device": device, "img_size": 224, "normalize": True})

# Summarise in a dataframe.
**{"explain_func": quantus.explain, "device": device,
"img_size": 28, "normalise": False, "abs": False})
# Summarise results in a dataframe.
df = pd.DataFrame(results)
df
```

Other miscellaneous functionality of `Quantus` library.
As a result, the max-Sensitivity scores are 0.41 (0.15) for Saliency and 0.17 (0.05) for Integrated Gradients. Lower scores are considered better, which means that in this experimental setting
the Integrated Gradients explanations can be considered more robust than the Saliency explanations. To replicate this example, please see the notebook under `/tutorials/getting_started.ipynb`.
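
The parenthesised value is presumably the standard deviation across samples, alongside the mean produced by `agg_func=np.mean`. If you keep the raw per-sample scores from option 1, a summary along these lines reproduces that style of reporting (assuming `scores_saliency` and `scores_intgrad` hold array-like per-sample values):

```python
# Summarise per-sample max-sensitivity scores as mean (standard deviation).
for name, scores in [("Saliency", scores_saliency), ("IntegratedGradients", scores_intgrad)]:
    print(f"{name}: {np.mean(scores):.2f} ({np.std(scores):.2f})")
```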

### Other examples

More examples are located in the `/tutorials` folder. For example,

* Compare explanation methods on different evaluation criteria (`/tutorials/basic_example_all_metrics.ipynb`)
* Measure sensitivity of hyperparameter choice (`/tutorials/sensitivity_parameterisation.ipynb`)
* Understand how the sensitivity of explanations changes as a model learns (`/tutorials/model_training_explanation_robustness.ipynb`)
<!--* Investigate to what extent metrics belonging to the same category score explanations similarly (check out: `/tutorials/category_reliability.ipynb`)-->

... and more.

Other miscellaneous functionality of the Quantus library.

````python
# Interpret scores.
@@ -171,19 +223,10 @@ sensitivity_scorer.interpret_scores
# Understand what hyperparameters to tune.
sensitivity_scorer.get_params

# To list available metrics
# To list available metrics.
quantus.available_metrics
````

See more examples and use cases in the `/tutorials` folder. For example

* Compare explanation methods on different evaluation criteria (check out: `/tutorials/basic_example_all_metrics.ipynb`)
* Measure sensitivity of hyperparameter choice (check out: `/tutorials/hyperparameter_sensitivity.ipynb`)
* Understand how sensitivity of explanations change when a model is learning (check out: `/tutorials/model_training_explanation_sensitvitiy.ipynb`)
<!--* Investigate to what extent metrics belonging to the same category score explanations similarly (check out: `/tutorials/category_reliability.ipynb`)-->

... and more.

<!--
## Feature list
2 changes: 1 addition & 1 deletion quantus/__init__.py
@@ -1,3 +1,3 @@
from .helpers import *
from .metrics import *
from .methods import *
from .methods import *
1 change: 1 addition & 0 deletions quantus/helpers/__init__.py
@@ -8,3 +8,4 @@
from .similar_func import *
from .explanation_func import *
from .warn_func import *
from .models import *
28 changes: 28 additions & 0 deletions quantus/helpers/models.py
@@ -0,0 +1,28 @@
import torch


class LeNet(torch.nn.Module):
"""Network architecture from: https://github.com/ChawDoe/LeNet5-MNIST-PyTorch."""
def __init__(self):
super().__init__()
self.conv_1 = torch.nn.Conv2d(1, 6, 5)
self.pool_1 = torch.nn.MaxPool2d(2, 2)
self.relu_1 = torch.nn.ReLU()
self.conv_2 = torch.nn.Conv2d(6, 16, 5)
self.pool_2 = torch.nn.MaxPool2d(2, 2)
self.relu_2 = torch.nn.ReLU()
self.fc_1 = torch.nn.Linear(256, 120)
self.relu_3 = torch.nn.ReLU()
self.fc_2 = torch.nn.Linear(120, 84)
self.relu_4 = torch.nn.ReLU()
self.fc_3 = torch.nn.Linear(84, 10)

def forward(self, x):
x = self.pool_1(self.relu_1(self.conv_1(x)))
x = self.pool_2(self.relu_2(self.conv_2(x)))
x = x.view(x.shape[0], -1)
x = self.relu_3(self.fc_1(x))
x = self.relu_4(self.fc_2(x))
x = self.fc_3(x)
return x
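
As a quick aside (not part of this commit), the hard-coded 256 in `fc_1` comes from the 16 x 4 x 4 feature map that a 28 x 28 MNIST input produces after the two conv/pool stages. A minimal sanity check, reusing the `torch` import and `LeNet` class above:

```python
# Illustrative check only: a random MNIST-shaped batch should pass through cleanly.
if __name__ == "__main__":
    dummy = torch.randn(4, 1, 28, 28)   # batch of 4 single-channel 28x28 images
    logits = LeNet()(dummy)
    assert logits.shape == (4, 10)      # 16 * 4 * 4 = 256 features feed fc_1
```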

Binary file added tutorials/assets/mnist
Binary file added tutorials/assets/mnist_example.png