Merge pull request #591 from vivekmig/v0.3.1
Captum v0.3.1
vivekmig authored Jan 21, 2021
2 parents 26cad25 + d0920f6 commit ea1461b
Showing 67 changed files with 2,949 additions and 950 deletions.
4 changes: 3 additions & 1 deletion .circleci/config.yml
@@ -102,7 +102,9 @@ commands:
steps:
- run:
name: "Simple PIP install"
command: python -m pip install -e .[dev]
command: |
python -m pip install --upgrade pip
python -m pip install -e .[dev]
py_3_7_setup:
description: "Set python version to 3.7 and install pip and pytest"
1 change: 1 addition & 0 deletions .gitignore
@@ -114,3 +114,4 @@ website/static/js/*
!website/static/js/code_block_buttons.js
website/static/_sphinx-sources/
node_modules
captum/insights/attr_vis/widget/static
14 changes: 8 additions & 6 deletions README.md
@@ -7,7 +7,7 @@
[![CircleCI](https://circleci.com/gh/pytorch/captum.svg?style=shield)](https://circleci.com/gh/pytorch/captum)

Captum is a model interpretability and understanding library for PyTorch.
Captum means comprehension in latin and contains general purpose implementations
Captum means comprehension in Latin and contains general purpose implementations
of integrated gradients, saliency maps, smoothgrad, vargrad and others for
PyTorch models. It has quick integration for models built with domain-specific
libraries such as torchvision, torchtext, and others.
@@ -175,12 +175,12 @@ Convergence Delta: tensor([2.3842e-07, -4.7684e-07])
The algorithm outputs an attribution score for each input element and a
convergence delta. The lower the absolute value of the convergence delta the better
is the approximation. If we choose not to return delta,
we can simply not provide `return_convergence_delta` input
we can simply not provide the `return_convergence_delta` input
argument. The absolute value of the returned deltas can be interpreted as an
approximation error for each input sample.
It can also serve as a proxy of how accurate the integral approximation for given
inputs and baselines is.
If the approximation error is large, we can try larger number of integral
If the approximation error is large, we can try a larger number of integral
approximation steps by setting `n_steps` to a larger value. Not all algorithms
return approximation error. Those which do, though, compute it based on the
completeness property of the algorithms.
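A minimal sketch of this usage (a small stand-in model is used here, since the README's `ToyModel` definition lies outside this diff excerpt; shapes are illustrative only):

```python
import torch
import torch.nn as nn
from captum.attr import IntegratedGradients

# Stand-in for the README's ToyModel: 3 input features, 2 outputs.
model = nn.Sequential(nn.Linear(3, 3), nn.ReLU(), nn.Linear(3, 2))
model.eval()

input = torch.rand(2, 3)
baseline = torch.zeros(2, 3)

ig = IntegratedGradients(model)

# Request the convergence delta along with the attributions.
attributions, delta = ig.attribute(
    input, baselines=baseline, target=0, return_convergence_delta=True
)

# If abs(delta) is large, retry with more integral approximation steps.
attributions = ig.attribute(input, baselines=baseline, target=0, n_steps=200)
```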
@@ -224,7 +224,7 @@ in order to get per example average delta.


Below is an example of how we can apply `DeepLift` and `DeepLiftShap` on the
`ToyModel` described above. Current implementation of DeepLift supports only
`ToyModel` described above. The current implementation of DeepLift supports only the
`Rescale` rule.
For more details on alternative implementations, please see the [DeepLift paper](https://arxiv.org/abs/1704.02685).
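As a rough sketch of applying both methods (again with a stand-in model in place of the README's `ToyModel`, which is not shown in this excerpt):

```python
import torch
import torch.nn as nn
from captum.attr import DeepLift, DeepLiftShap

# Stand-in model with a ReLU nonlinearity, to which the Rescale rule applies.
model = nn.Sequential(nn.Linear(3, 3), nn.ReLU(), nn.Linear(3, 2))
model.eval()

input = torch.rand(2, 3)
baseline = torch.zeros(2, 3)

# DeepLift attributes against a single baseline per input.
dl = DeepLift(model)
attributions, delta = dl.attribute(
    input, baselines=baseline, target=0, return_convergence_delta=True
)

# DeepLiftShap averages DeepLift attributions over a distribution of baselines.
dl_shap = DeepLiftShap(model)
baseline_dist = torch.randn(6, 3) * 0.001
attributions, delta = dl_shap.attribute(
    input, baselines=baseline_dist, target=0, return_convergence_delta=True
)
```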

@@ -286,7 +286,7 @@ In order to smooth and improve the quality of the attributions we can run
to smoothen the attributions by aggregating them for multiple noisy
samples that were generated by adding gaussian noise.

Here is an example how we can use `NoiseTunnel` with `IntegratedGradients`.
Here is an example of how we can use `NoiseTunnel` with `IntegratedGradients`.

```python
ig = IntegratedGradients(model)
@@ -338,7 +338,7 @@ It is an extension of path integrated gradients for hidden layers and holds the
completeness property as well.

It doesn't attribute the contribution scores to the input features
but shows the importance of each neuron in selected layer.
but shows the importance of each neuron in the selected layer.
```python
lc = LayerConductance(model, model.lin1)
attributions, delta = lc.attribute(input, baselines=baseline, target=0, return_convergence_delta=True)
@@ -412,6 +412,8 @@ See the [CONTRIBUTING](CONTRIBUTING.md) file for how to help out.
## Talks and Papers
The slides of our presentation from NeurIPS 2019 can be found [here](docs/presentations/Captum_NeurIPS_2019_final.key)

The slides of our presentation from KDD 2020 tutorial can be found [here](https://pytorch-tutorial-assets.s3.amazonaws.com/Captum_KDD_2020.pdf)

## References of Algorithms

* `IntegratedGradients`, `LayerIntegratedGradients`: [Axiomatic Attribution for Deep Networks, Mukund Sundararajan et al. 2017](https://arxiv.org/abs/1703.01365) and [Did the Model Understand the Question?, Pramod K. Mudrakarta, et al. 2018](https://arxiv.org/abs/1805.05492)
2 changes: 1 addition & 1 deletion captum/__init__.py
@@ -1,3 +1,3 @@
#!/usr/bin/env python3

__version__ = "0.3.0"
__version__ = "0.3.1"
63 changes: 63 additions & 0 deletions captum/_utils/common.py
@@ -249,6 +249,20 @@ def _expand_target(
return target


def _expand_feature_mask(
    feature_mask: Union[Tensor, Tuple[Tensor, ...]], n_samples: int
):
    is_feature_mask_tuple = _is_tuple(feature_mask)
    feature_mask = _format_tensor_into_tuples(feature_mask)
    feature_mask_new = tuple(
        feature_mask_elem.repeat_interleave(n_samples, dim=0)
        if feature_mask_elem.size(0) > 1
        else feature_mask_elem
        for feature_mask_elem in feature_mask
    )
    return _format_output(is_feature_mask_tuple, feature_mask_new)


def _expand_and_update_baselines(
inputs: Tuple[Tensor, ...],
n_samples: int,
@@ -317,6 +331,18 @@ def _expand_and_update_target(n_samples: int, kwargs: dict):
kwargs["target"] = target


def _expand_and_update_feature_mask(n_samples: int, kwargs: dict):
    if "feature_mask" not in kwargs:
        return

    feature_mask = kwargs["feature_mask"]
    if feature_mask is None:
        return

    feature_mask = _expand_feature_mask(feature_mask, n_samples)
    kwargs["feature_mask"] = feature_mask


@typing.overload
def _format_output(
is_inputs_tuple: Literal[True], output: Tuple[Tensor, ...]
@@ -354,6 +380,43 @@ def _format_output(
return output if is_inputs_tuple else output[0]


@typing.overload
def _format_outputs(
    is_multiple_inputs: Literal[False], outputs: List[Tuple[Tensor, ...]]
) -> Union[Tensor, Tuple[Tensor, ...]]:
    ...


@typing.overload
def _format_outputs(
    is_multiple_inputs: Literal[True], outputs: List[Tuple[Tensor, ...]]
) -> List[Union[Tensor, Tuple[Tensor, ...]]]:
    ...


@typing.overload
def _format_outputs(
    is_multiple_inputs: bool, outputs: List[Tuple[Tensor, ...]]
) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
    ...


def _format_outputs(
    is_multiple_inputs: bool, outputs: List[Tuple[Tensor, ...]]
) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
    assert isinstance(outputs, list), "Outputs must be a list"
    assert is_multiple_inputs or len(outputs) == 1, (
        "outputs should contain multiple inputs or have a single output"
        f"however the number of outputs is: {len(outputs)}"
    )

    return (
        [_format_output(len(output) > 1, output) for output in outputs]
        if is_multiple_inputs
        else _format_output(len(outputs[0]) > 1, outputs[0])
    )


def _run_forward(
forward_func: Callable,
inputs: Union[Tensor, Tuple[Tensor, ...]],
21 changes: 16 additions & 5 deletions captum/_utils/models/linear_model/train.py
@@ -1,4 +1,5 @@
import time
import warnings
from typing import Any, Callable, Dict, List, Optional

import torch
@@ -286,10 +287,11 @@ def sklearn_train_linear_model(
except ImportError:
raise ValueError("sklearn is not available. Please install sklearn >= 0.23")

assert (
sklearn.__version__ >= "0.23.0"
), "Must have sklearn version 0.23.0 or higher to use "
"sample_weight in Lasso regression."
if not sklearn.__version__ >= "0.23.0":
warnings.warn(
"Must have sklearn version 0.23.0 or higher to use "
"sample_weight in Lasso regression."
)

num_batches = 0
xs, ys, ws = [], [], []
@@ -323,7 +325,16 @@ def sklearn_train_linear_model(
sklearn_model = reduce(
lambda val, el: getattr(val, el), [sklearn] + sklearn_trainer.split(".")
)(**construct_kwargs)
sklearn_model.fit(x, y, sample_weight=w, **fit_kwargs)
try:
sklearn_model.fit(x, y, sample_weight=w, **fit_kwargs)
except TypeError:
sklearn_model.fit(x, y, **fit_kwargs)
warnings.warn(
"Sample weight is not supported for the provided linear model!"
" Trained model without weighting inputs. For Lasso, please"
" upgrade sklearn to a version >= 0.23.0."
)

t2 = time.time()

# Convert weights to pytorch
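The try/except fallback added to `train.py` above follows a standard pattern; a standalone sketch of the same idea with plain scikit-learn (hypothetical data, not Captum's wrapper):

```python
import warnings

import numpy as np
from sklearn.linear_model import Lasso

x = np.random.rand(100, 5)
y = np.random.rand(100)
w = np.random.rand(100)  # per-sample weights

model = Lasso(alpha=0.01)
try:
    # Lasso only accepts sample_weight in sklearn >= 0.23.
    model.fit(x, y, sample_weight=w)
except TypeError:
    model.fit(x, y)
    warnings.warn("sample_weight not supported; trained without weighting inputs.")
```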
1 change: 1 addition & 0 deletions captum/attr/_core/feature_ablation.py
@@ -346,6 +346,7 @@ def attribute(
eval_diff = (
initial_eval - modified_eval.reshape((-1, num_outputs))
).reshape((-1, num_outputs) + (len(inputs[i].shape) - 1) * (1,))
eval_diff = eval_diff.to(total_attrib[i].device)
if self.use_weights:
weights[i] += current_mask.float().sum(dim=0)
total_attrib[i] += (eval_diff * current_mask.to(attrib_type)).sum(
2 changes: 1 addition & 1 deletion captum/attr/_core/gradient_shap.py
@@ -274,7 +274,7 @@ def attribute(
nt, # self
inputs,
nt_type="smoothgrad",
n_samples=n_samples,
nt_samples=n_samples,
stdevs=stdevs,
draw_baseline_from_distrib=True,
baselines=baselines,
12 changes: 7 additions & 5 deletions captum/attr/_core/kernel_shap.py
@@ -9,6 +9,7 @@
from captum._utils.models.linear_model import SkLearnLinearRegression
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._core.lime import Lime
from captum.attr._utils.common import lime_n_perturb_samples_deprecation_decorator
from captum.log import log_usage


@@ -72,14 +73,15 @@ def __init__(self, forward_func: Callable) -> None:
)

@log_usage()
@lime_n_perturb_samples_deprecation_decorator
def attribute( # type: ignore
self,
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
additional_forward_args: Any = None,
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
n_perturb_samples: int = 25,
n_samples: int = 25,
perturbations_per_eval: int = 1,
return_input_shape: bool = True,
) -> TensorOrTupleOfTensorsGeneric:
@@ -213,9 +215,9 @@ def attribute( # type: ignore
If None, then a feature mask is constructed which assigns
each scalar within a tensor as a separate feature.
Default: None
n_perturb_samples (int, optional): The number of samples of the original
n_samples (int, optional): The number of samples of the original
model used to train the surrogate interpretable model.
Default: `50` if `n_perturb_samples` is not provided.
Default: `50` if `n_samples` is not provided.
perturbations_per_eval (int, optional): Allows multiple samples
to be processed simultaneously in one call to forward_fn.
Each forward pass will contain a maximum of
@@ -266,7 +268,7 @@ def attribute( # type: ignore
>>> ks = KernelShap(net)
>>> # Computes attribution, with each of the 4 x 4 = 16
>>> # features as a separate interpretable feature
>>> attr = ks.attribute(input, target=1, n_perturb_samples=200)
>>> attr = ks.attribute(input, target=1, n_samples=200)
>>> # Alternatively, we can group each 2x2 square of the inputs
>>> # as one 'interpretable' feature and perturb them together.
@@ -299,7 +301,7 @@ def attribute( # type: ignore
target=target,
additional_forward_args=additional_forward_args,
feature_mask=feature_mask,
n_perturb_samples=n_perturb_samples,
n_samples=n_samples,
perturbations_per_eval=perturbations_per_eval,
return_input_shape=return_input_shape,
)
2 changes: 1 addition & 1 deletion captum/attr/_core/layer/layer_gradient_shap.py
@@ -305,7 +305,7 @@ def attribute(
nt, # self
inputs,
nt_type="smoothgrad",
n_samples=n_samples,
nt_samples=n_samples,
stdevs=stdevs,
draw_baseline_from_distrib=True,
baselines=baselines,