forked from pytorch/captum
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add input wrapper for layer methods (pytorch#534)
Summary: Pull Request resolved: pytorch#534. Introduces a utility class called `ModelInputWrapper` that wraps a model so that its inputs can be treated as separate layers. It does so by passing each input fed to `forward` through an `Identity` operation. This way it should work whether `attribute_to_layer_input` is True or False. Adds two tests: - Test whether `_forward_layer_eval` retrieves the appropriate input values - Compare regular IG with layer IG and layer-wrapped inputs. Updated tutorial and documentation. Differential Revision: D25110896 fbshipit-source-id: 1d951bf06510aab85f83d4484ed8fabb724356c5
- Loading branch information
1 parent
8f95995
commit 4052a19
Showing
6 changed files
with
386 additions
and
200 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import inspect | ||
from typing import Any | ||
|
||
import torch.nn as nn | ||
|
||
|
||
class InputIdentity(nn.Module):
    r"""
    Pass-through layer that returns whatever it is given, unchanged.

    It carries the name of a model input so that the input can be addressed
    as a layer of its own.
    """

    def __init__(self, input_name: str) -> None:
        r"""
        Args:
            input_name (str):
                Name of the input this layer stands for. Stored on the
                instance for debugging purposes only; `forward` never
                reads it.
        """
        super().__init__()
        self.input_name = input_name

    def forward(self, x):
        # Identity: hand the very same object back to the caller.
        return x
|
||
|
||
class ModelInputWrapper(nn.Module):
    r"""
    Convenience wrapper that exposes every input of a model as its own layer.

    Each argument passed to `forward` is first routed through a dedicated
    `InputIdentity` module and then handed, unmodified, to the wrapped model
    (`module_to_wrap`). Calling the wrapper is therefore no different from
    calling the original model.

    The per-input layers live in the `input_maps` ModuleDict, keyed by the
    argument name; e.g. the layer for input "x" is
    `my_wrapped_module.input_maps["x"]`.

    The purpose is to make layer attribution methods usable on inputs, so
    that inputs and ordinary layers can be mixed in one attribution call.
    This is especially useful for multimodal models that take discrete
    features (mapped to embeddings, such as text) alongside regular
    continuous feature vectors.

    Notes:
        - Because every input goes through an identity mapping, the input and
          the output of its layer are the same value. Attributing to an
          input/feature hence does not depend on whether
          attribute_to_layer_input is True or False, e.g. for
          LayerIntegratedGradients.
        - See the multimodal tutorial or the unit tests
          (test/attr/test_layer_wrapper.py) for an example.
    """

    def __init__(self, module_to_wrap: nn.Module) -> None:
        r"""
        Args:
            module_to_wrap (nn.Module):
                The model/module you want to wrap
        """
        super().__init__()
        self.module = module_to_wrap

        forward_spec = inspect.getfullargspec(module_to_wrap.forward)
        # The first reported argument is `self`; everything after it is a
        # real model input.
        self.arg_name_list = forward_spec.args[1:]

        identity_layers = {}
        for input_name in self.arg_name_list:
            identity_layers[input_name] = InputIdentity(input_name)
        self.input_maps = nn.ModuleDict(identity_layers)

    def forward(self, *args, **kwargs) -> Any:
        # Positional inputs are matched to argument names by position; any
        # positionals beyond the known names are passed through untouched.
        routed_args = list(args)
        matched_names = self.arg_name_list[: len(routed_args)]
        for position, input_name in enumerate(matched_names):
            routed_args[position] = self.input_maps[input_name](routed_args[position])

        # Keyword inputs are matched by name; a name without an identity
        # layer raises KeyError from the ModuleDict lookup.
        routed_kwargs = {
            input_name: self.input_maps[input_name](value)
            for input_name, value in kwargs.items()
        }

        return self.module(*routed_args, **routed_kwargs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import functools | ||
import inspect | ||
from typing import Callable, Dict, Tuple | ||
|
||
import torch | ||
|
||
from captum._utils.gradient import _forward_layer_eval | ||
from captum.attr import ( | ||
DeepLift, | ||
DeepLiftShap, | ||
FeatureAblation, | ||
GradientShap, | ||
InputXGradient, | ||
IntegratedGradients, | ||
LayerDeepLift, | ||
LayerDeepLiftShap, | ||
LayerFeatureAblation, | ||
LayerGradientShap, | ||
LayerGradientXActivation, | ||
LayerIntegratedGradients, | ||
) | ||
from captum.attr._utils.input_layer_wrapper import ModelInputWrapper | ||
from tests.helpers.basic import BaseTest, assertTensorTuplesAlmostEqual | ||
from tests.helpers.basic_models import ( | ||
BasicModel, | ||
BasicModel_MultiLayer_TrueMultiInput, | ||
MixedKwargsAndArgsModule, | ||
) | ||
|
||
# Table driving the generated tests. Each row is
#   (layer attribution class, equivalent input attribution class, multi_layers)
# where `multi_layers` lists the values of the `multi_layer` flag to test:
# True uses a multi-input model with one layer per input, False uses a
# single-input model with a single layer.
layer_methods_to_test_with_equiv = [
    (LayerIntegratedGradients, IntegratedGradients, [True, False]),
    (LayerGradientXActivation, InputXGradient, [True, False]),
    (LayerFeatureAblation, FeatureAblation, [False]),
    (LayerDeepLift, DeepLift, [False]),
    (LayerDeepLiftShap, DeepLiftShap, [False]),
    (LayerGradientShap, GradientShap, [False]),
    # TODO: add other algorithms here
]
|
||
|
||
class InputLayerMeta(type):
    """
    Metaclass that adds one test method per entry/flag combination listed in
    `layer_methods_to_test_with_equiv`.

    Every generated method is named
    `test_<LayerMethod>_<EquivMethod>_<multi_layer>` and calls
    `self.layer_method_with_input_layer_patches(layer_method, equiv_method,
    multi_layer)` on the test case instance.
    """

    def __new__(cls, name: str, bases: Tuple, attrs: Dict):
        for (
            layer_method,
            equiv_method,
            multi_layers,
        ) in layer_methods_to_test_with_equiv:
            for multi_layer in multi_layers:
                test_name = (
                    f"test_{layer_method.__name__}"
                    + f"_{equiv_method.__name__}_{multi_layer}"
                )
                # Build the test through a factory so that each test keeps
                # its OWN (layer_method, equiv_method, multi_layer). A lambda
                # created directly in this loop would look the loop variables
                # up when the test runs, i.e. after the loops have finished,
                # and every generated test would exercise only the last
                # combination in the table.
                attrs[test_name] = cls._make_test(
                    test_name, layer_method, equiv_method, multi_layer
                )

        return super(InputLayerMeta, cls).__new__(cls, name, bases, attrs)

    @staticmethod
    def _make_test(
        test_name: str,
        layer_method: Callable,
        equiv_method: Callable,
        multi_layer: bool,
    ) -> Callable:
        """Return a test method bound to one specific parameter combination."""

        def generated_test(self) -> None:
            self.layer_method_with_input_layer_patches(
                layer_method, equiv_method, multi_layer
            )

        # Give the function the name it is registered under so that failure
        # reports and tracebacks identify the combination.
        generated_test.__name__ = test_name
        return generated_test
|
||
|
||
class TestInputLayerWrapper(BaseTest, metaclass=InputLayerMeta):
    """
    Tests for `ModelInputWrapper`.

    Besides the test written out below, `InputLayerMeta` attaches generated
    `test_*` methods that call `layer_method_with_input_layer_patches`.
    """

    def test_forward_layer_eval_on_mixed_args_kwargs_module(self) -> None:
        """Check layer evaluation on the wrapped inputs with one and two inputs."""
        x = torch.randn(10, 5)
        y = torch.randn(10, 5)

        model = MixedKwargsAndArgsModule()

        for named_inputs in ({"x": x}, {"x": x, "y": y}):
            self.forward_eval_layer_with_inputs_helper(model, named_inputs)

    def layer_method_with_input_layer_patches(
        self,
        layer_method_class: Callable,
        equiv_method_class: Callable,
        multi_layer: bool,
    ) -> None:
        """
        Compare a layer attribution method applied to the wrapped input
        layers against its input-attribution equivalent on the same model.

        Args:
            layer_method_class: layer attribution class; constructed with the
                wrapped model and the input layer(s).
            equiv_method_class: input attribution class; constructed with the
                wrapped model only.
            multi_layer: if True, use the four-input model and pass all four
                input layers; otherwise use the single-input model and pass
                its one input layer.
        """
        if multi_layer:
            base_model = BasicModel_MultiLayer_TrueMultiInput()
            input_names = ["x1", "x2", "x3", "x4"]
        else:
            base_model = BasicModel()
            input_names = ["input"]
        wrapped_model = ModelInputWrapper(base_model)

        input_layers = [wrapped_model.input_maps[name] for name in input_names]
        layer_method = layer_method_class(
            wrapped_model, layer=input_layers if multi_layer else input_layers[0]
        )
        equivalent_method = equiv_method_class(wrapped_model)

        inputs = tuple(torch.rand(5, 3) for _ in input_names)
        baseline = tuple(torch.zeros(5, 3) for _ in input_names)

        # Only pass baselines to methods whose (undecorated) `attribute`
        # signature declares a `baselines` parameter.
        attribute_params = inspect.getfullargspec(
            equivalent_method.attribute.__wrapped__
        ).args
        call_args = [inputs, baseline] if "baselines" in attribute_params else [inputs]

        a1 = layer_method.attribute(*call_args, target=0)
        a2 = layer_method.attribute(
            *call_args, target=0, attribute_to_layer_input=True
        )

        real_attributions = equivalent_method.attribute(*call_args, target=0)

        # Normalise single results to 1-tuples so the tuple comparison
        # helper can be used in every case.
        if not isinstance(a1, tuple):
            a1, a2 = (a1,), (a2,)

        if not isinstance(real_attributions, tuple):
            real_attributions = (real_attributions,)

        assertTensorTuplesAlmostEqual(self, a1, a2)
        assertTensorTuplesAlmostEqual(self, a1, real_attributions)

    def forward_eval_layer_with_inputs_helper(self, model, inputs_to_test):
        """
        Evaluate the wrapped input layers of `model` and check that what they
        report equals the tensors in `inputs_to_test` (a name -> tensor dict),
        for several ways of splitting the inputs into args and kwargs.
        """
        # Hard-coded for simplicity, indexed by (number of inputs - 1).
        # Per input: 0 = pass positionally, 1 = pass as keyword.
        # => no 0s after the first 1 (left to right).
        use_args_or_kwargs = [
            [[0], [1]],
            [[0, 0], [0, 1], [1, 1]],
        ]

        model = ModelInputWrapper(model)

        def forward_func(*args, args_or_kwargs=None):
            # Re-split the positional inputs into args/kwargs to exercise the
            # wrapper's *args and **kwargs handling.
            new_args, new_kwargs = [], {}
            for as_kwarg, name, inp in zip(
                args_or_kwargs, inputs_to_test.keys(), args
            ):
                if as_kwarg:
                    new_kwargs[name] = inp
                else:
                    new_args.append(inp)
            return model(*new_args, **new_kwargs)

        input_layers = [model.input_maps[name] for name in inputs_to_test]
        input_values = tuple(inputs_to_test.values())

        for args_or_kwargs in use_args_or_kwargs[len(inputs_to_test) - 1]:
            with self.subTest(args_or_kwargs=args_or_kwargs):
                bound_forward = functools.partial(
                    forward_func, args_or_kwargs=args_or_kwargs
                )

                inputs = _forward_layer_eval(
                    bound_forward,
                    inputs=input_values,
                    layer=input_layers,
                )

                inputs_with_attrib_to_inp = _forward_layer_eval(
                    bound_forward,
                    inputs=input_values,
                    layer=input_layers,
                    attribute_to_layer_input=True,
                )

                for from_output, from_input, expected in zip(
                    inputs, inputs_with_attrib_to_inp, input_values
                ):
                    self.assertTrue((from_output[0] == from_input[0]).all())
                    self.assertTrue((from_output[0] == expected).all())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.