Commit
add input wrapper for layer methods (pytorch#534)
Summary:
Pull Request resolved: pytorch#534

Introduces a utility class, `ModelInputWrapper`, that wraps a model so that its inputs can be treated as separate layers.

It does this by passing each input fed to `forward` through an `Identity` operation, so attribution works whether `attribute_to_layer_input` is True or False.
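
For illustration, a minimal usage sketch (the two-input model and tensor shapes below are hypothetical, not the tutorial code):

# Minimal sketch: treat input "x" of a hypothetical two-input model as a layer.
import torch
import torch.nn as nn

from captum.attr import LayerIntegratedGradients
from captum.attr._utils.input_layer_wrapper import ModelInputWrapper


class TwoInputModel(nn.Module):
    def forward(self, x, y):
        # per-example scalar output, so no target is needed
        return (x + y).sum(dim=1)


wrapped = ModelInputWrapper(TwoInputModel())

# Because the input layer is an identity mapping, attribute_to_layer_input=True
# and False should give the same attributions for input "x".
lig = LayerIntegratedGradients(wrapped, layer=wrapped.input_maps["x"])
x, y = torch.rand(4, 3), torch.rand(4, 3)
attr_out = lig.attribute((x, y))
attr_in = lig.attribute((x, y), attribute_to_layer_input=True)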

Adds two tests:
- Test whether `_forward_layer_eval` retrieves the appropriate input values
- Compare regular IG with layer IG applied to the wrapped input layers

Updated tutorial and documentation

Differential Revision: D25110896

fbshipit-source-id: 1d951bf06510aab85f83d4484ed8fabb724356c5
miguelmartin75 authored and facebook-github-bot committed Dec 11, 2020
1 parent 8f95995 commit 4052a19
Showing 6 changed files with 386 additions and 200 deletions.
3 changes: 2 additions & 1 deletion captum/attr/_core/layer/layer_integrated_gradients.py
@@ -362,6 +362,7 @@ def flatten_tuple(tup):

        if self.device_ids is None:
            self.device_ids = getattr(self.forward_func, "device_ids", None)

        inputs_layer = _forward_layer_eval(
            self.forward_func,
            inps,
@@ -398,7 +399,7 @@ def gradient_func(
            target_ind: TargetType = None,
            additional_forward_args: Any = None,
        ) -> Tuple[Tensor, ...]:
-            if self.device_ids is None:
+            if self.device_ids is None or len(self.device_ids) == 0:
                scattered_inputs = (inputs,)
            else:
                # scatter method does not have a precise enough return type in its
76 changes: 76 additions & 0 deletions captum/attr/_utils/input_layer_wrapper.py
@@ -0,0 +1,76 @@
#!/usr/bin/env python3

import inspect
from typing import Any

import torch.nn as nn


class InputIdentity(nn.Module):
    def __init__(self, input_name: str) -> None:
        r"""
        The identity operation.

        Args:
            input_name (str):
                The name of the input this layer is associated with, for
                debugging purposes.
        """
        super().__init__()
        self.input_name = input_name

    def forward(self, x):
        return x


class ModelInputWrapper(nn.Module):
    def __init__(self, module_to_wrap: nn.Module) -> None:
        r"""
        This is a convenience class that wraps a model by first feeding each of
        the model's inputs through a separate `InputIdentity` layer (one per
        input) and then feeding the (unmodified) inputs to the underlying model
        (`module_to_wrap`). This class does not change how you feed inputs to
        your model, so feel free to use your model as you normally would.

        To access a wrapped input layer, use the `input_maps` ModuleDict, e.g.
        the module corresponding to input "x" is
        `my_wrapped_module.input_maps["x"]`.

        This is done so that layer attribution methods can be used on inputs,
        which lets you mix layers and inputs within these attribution methods.
        This is especially useful for multimodal models that take both discrete
        features (mapped to embeddings, such as text) and regular continuous
        feature vectors.

        Notes:
        - Since inputs are mapped with the identity, attributing to the
          input/feature can be done with either the input or output of the
          layer, e.g. attributing to an input/feature doesn't depend on whether
          attribute_to_layer_input is True or False for
          LayerIntegratedGradients.
        - Please refer to the multimodal tutorial or unit tests
          (tests/attr/test_input_layer_wrapper.py) for an example.

        Args:
            module_to_wrap (nn.Module):
                The model/module you want to wrap.
        """
        super().__init__()
        self.module = module_to_wrap

        # ignore self
        self.arg_name_list = inspect.getfullargspec(module_to_wrap.forward).args[1:]
        self.input_maps = nn.ModuleDict(
            {arg_name: InputIdentity(arg_name) for arg_name in self.arg_name_list}
        )

    def forward(self, *args, **kwargs) -> Any:
        args = list(args)
        for idx, (arg_name, arg) in enumerate(zip(self.arg_name_list, args)):
            args[idx] = self.input_maps[arg_name](arg)

        for arg_name in kwargs.keys():
            kwargs[arg_name] = self.input_maps[arg_name](kwargs[arg_name])

        return self.module(*tuple(args), **kwargs)
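
For clarity, a small sanity sketch (with a hypothetical `AddModule`) showing that wrapping is transparent to `forward` for both positional and keyword arguments:

# Hypothetical module; wrapping does not change outputs.
import torch
import torch.nn as nn

from captum.attr._utils.input_layer_wrapper import ModelInputWrapper


class AddModule(nn.Module):
    def forward(self, x, y=None):
        return x if y is None else x + y


model = AddModule()
wrapped = ModelInputWrapper(model)

x, y = torch.rand(2, 3), torch.rand(2, 3)
assert torch.equal(wrapped(x, y), model(x, y))          # positional arguments
assert torch.equal(wrapped(x, y=y), model(x, y=y))      # keyword argument
assert isinstance(wrapped.input_maps["y"], nn.Module)   # one identity layer per input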
16 changes: 16 additions & 0 deletions tests/attr/helpers/test_config.py
@@ -36,6 +36,7 @@
from captum.attr._core.occlusion import Occlusion
from captum.attr._core.saliency import Saliency
from captum.attr._core.shapley_value import ShapleyValueSampling
from captum.attr._utils.input_layer_wrapper import ModelInputWrapper
from tests.helpers.basic import set_all_random_seeds
from tests.helpers.basic_models import (
    BasicModel_ConvNet,
@@ -1160,4 +1161,19 @@
            "target": 0,
        },
    },
    {
        "name": "basic_layer_ig_multi_layer_multi_output_with_input_wrapper",
        "algorithms": [LayerIntegratedGradients],
        "model": ModelInputWrapper(BasicModel_MultiLayer_TrueMultiInput()),
        "layer": ["module.m1", "module.m234"],
        "attribute_args": {
            "inputs": (
                torch.randn(5, 3),
                torch.randn(5, 3),
                torch.randn(5, 3),
                torch.randn(5, 3),
            ),
            "target": 0,
        },
    },
]
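
Note the `module.` prefix in the layer names above: `ModelInputWrapper` keeps the original model under its `module` attribute, so pre-existing layers sit one attribute deeper. A small sketch of how such a dotted name resolves (assuming layers are looked up by attribute path):

# Hypothetical illustration of resolving a dotted layer name on the wrapped model.
from operator import attrgetter

from captum.attr._utils.input_layer_wrapper import ModelInputWrapper
from tests.helpers.basic_models import BasicModel_MultiLayer_TrueMultiInput

wrapped = ModelInputWrapper(BasicModel_MultiLayer_TrueMultiInput())
m1 = attrgetter("module.m1")(wrapped)  # same object as wrapped.module.m1
assert m1 is wrapped.module.m1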
167 changes: 167 additions & 0 deletions tests/attr/test_input_layer_wrapper.py
@@ -0,0 +1,167 @@
#!/usr/bin/env python3

import functools
import inspect
from typing import Callable, Dict, Tuple

import torch

from captum._utils.gradient import _forward_layer_eval
from captum.attr import (
    DeepLift,
    DeepLiftShap,
    FeatureAblation,
    GradientShap,
    InputXGradient,
    IntegratedGradients,
    LayerDeepLift,
    LayerDeepLiftShap,
    LayerFeatureAblation,
    LayerGradientShap,
    LayerGradientXActivation,
    LayerIntegratedGradients,
)
from captum.attr._utils.input_layer_wrapper import ModelInputWrapper
from tests.helpers.basic import BaseTest, assertTensorTuplesAlmostEqual
from tests.helpers.basic_models import (
    BasicModel,
    BasicModel_MultiLayer_TrueMultiInput,
    MixedKwargsAndArgsModule,
)

layer_methods_to_test_with_equiv = [
    # layer_method, equiv_method, whether or not to use multiple layers
    (LayerIntegratedGradients, IntegratedGradients, [True, False]),
    (LayerGradientXActivation, InputXGradient, [True, False]),
    (LayerFeatureAblation, FeatureAblation, [False]),
    (LayerDeepLift, DeepLift, [False]),
    (LayerDeepLiftShap, DeepLiftShap, [False]),
    (LayerGradientShap, GradientShap, [False]),
    # TODO: add other algorithms here
]


class InputLayerMeta(type):
    def __new__(cls, name: str, bases: Tuple, attrs: Dict):
        for (
            layer_method,
            equiv_method,
            multi_layers,
        ) in layer_methods_to_test_with_equiv:
            for multi_layer in multi_layers:
                test_name = (
                    f"test_{layer_method.__name__}"
                    + f"_{equiv_method.__name__}_{multi_layer}"
                )
                attrs[
                    test_name
                ] = lambda self: self.layer_method_with_input_layer_patches(
                    layer_method, equiv_method, multi_layer
                )

        return super(InputLayerMeta, cls).__new__(cls, name, bases, attrs)


class TestInputLayerWrapper(BaseTest, metaclass=InputLayerMeta):
    def test_forward_layer_eval_on_mixed_args_kwargs_module(self) -> None:
        x = torch.randn(10, 5)
        y = torch.randn(10, 5)

        model = MixedKwargsAndArgsModule()

        self.forward_eval_layer_with_inputs_helper(model, {"x": x})
        self.forward_eval_layer_with_inputs_helper(model, {"x": x, "y": y})

    def layer_method_with_input_layer_patches(
        self,
        layer_method_class: Callable,
        equiv_method_class: Callable,
        multi_layer: bool,
    ) -> None:
        model = BasicModel_MultiLayer_TrueMultiInput() if multi_layer else BasicModel()

        input_names = ["x1", "x2", "x3", "x4"] if multi_layer else ["input"]
        model = ModelInputWrapper(model)

        layers = [model.input_maps[inp] for inp in input_names]
        layer_method = layer_method_class(
            model, layer=layers if multi_layer else layers[0]
        )
        equivalent_method = equiv_method_class(model)

        inputs = tuple(torch.rand(5, 3) for _ in input_names)
        baseline = tuple(torch.zeros(5, 3) for _ in input_names)

        args = inspect.getfullargspec(equivalent_method.attribute.__wrapped__).args

        args_to_use = [inputs]
        if "baselines" in args:
            args_to_use += [baseline]

        a1 = layer_method.attribute(*args_to_use, target=0)
        a2 = layer_method.attribute(
            *args_to_use, target=0, attribute_to_layer_input=True
        )

        real_attributions = equivalent_method.attribute(*args_to_use, target=0)

        if not isinstance(a1, tuple):
            a1 = (a1,)
            a2 = (a2,)

        if not isinstance(real_attributions, tuple):
            real_attributions = (real_attributions,)

        assertTensorTuplesAlmostEqual(self, a1, a2)
        assertTensorTuplesAlmostEqual(self, a1, real_attributions)

    def forward_eval_layer_with_inputs_helper(self, model, inputs_to_test):
        # hard coding for simplicity
        # 0 if using args, 1 if using kwargs
        # => no 0s after first 1 (left to right)
        #
        # used to test utilization of args/kwargs
        use_args_or_kwargs = [
            [[0], [1]],
            [
                [0, 0],
                [0, 1],
                [1, 1],
            ],
        ]

        model = ModelInputWrapper(model)

        def forward_func(*args, args_or_kwargs=None):
            # convert to args or kwargs to test *args and **kwargs wrapping behavior
            new_args = []
            new_kwargs = {}
            for args_or_kwarg, name, inp in zip(
                args_or_kwargs, inputs_to_test.keys(), args
            ):
                if args_or_kwarg:
                    new_kwargs[name] = inp
                else:
                    new_args.append(inp)
            return model(*new_args, **new_kwargs)

        for args_or_kwargs in use_args_or_kwargs[len(inputs_to_test) - 1]:
            with self.subTest(args_or_kwargs=args_or_kwargs):
                inputs = _forward_layer_eval(
                    functools.partial(forward_func, args_or_kwargs=args_or_kwargs),
                    inputs=tuple(inputs_to_test.values()),
                    layer=[model.input_maps[name] for name in inputs_to_test.keys()],
                )

                inputs_with_attrib_to_inp = _forward_layer_eval(
                    functools.partial(forward_func, args_or_kwargs=args_or_kwargs),
                    inputs=tuple(inputs_to_test.values()),
                    layer=[model.input_maps[name] for name in inputs_to_test.keys()],
                    attribute_to_layer_input=True,
                )

                for i1, i2, i3 in zip(
                    inputs, inputs_with_attrib_to_inp, inputs_to_test.values()
                ):
                    self.assertTrue((i1[0] == i2[0]).all())
                    self.assertTrue((i1[0] == i3).all())
10 changes: 10 additions & 0 deletions tests/helpers/basic_models.py
@@ -15,6 +15,16 @@
"""


class MixedKwargsAndArgsModule(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y=None):
        if y is not None:
            return x + y
        return x


class BasicModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
(The sixth changed file is not rendered on this page.)