Assert grad_input and inputs in DeepLift and fix the layer attribution issue for MaxPool (pytorch#390)

Summary:
Related to issue pytorch#382: assert that `grad_input` and `inputs` have the same shape. More details about the workaround and about why the issue happens can be found in the description of the assert. The error occurs when we attribute to the outputs of the layer, because of the input or output tensor returned from the forward hook.
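
For illustration, a hedged sketch of the workaround that the assert message recommends: instead of attributing to the outputs of a MaxPool layer, attribute to the inputs of the layer that follows it. The tiny model below is our own illustrative example (not one of captum's test helpers); it just mirrors the structure exercised by the new tests.

```python
import torch
import torch.nn as nn

from captum.attr import LayerDeepLift


class TinyConvNet(nn.Module):
    """Illustrative model: conv -> relu -> maxpool -> conv -> relu -> linear."""

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(2, 2, 3)
        self.relu2 = nn.ReLU()
        self.fc = nn.Linear(8, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu2(self.conv2(self.pool1(self.relu1(self.conv1(x)))))
        return self.fc(x.flatten(1))


model = TinyConvNet().eval()
inputs = torch.randn(2, 1, 10, 10)

# Rather than attributing to the *outputs* of model.pool1 (which runs into the
# MaxPool backward-hook limitation described above), attribute to the *inputs*
# of the layer that follows the pool:
dl = LayerDeepLift(model, model.conv2)
attr = dl.attribute(inputs, target=0, attribute_to_layer_input=True)
```

The new layer and neuron tests use exactly this pattern and check that the total attribution matches what is obtained by attributing to the outputs of the pooling layer itself.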

Added `forward_hook_with_return_excl_modules`, which takes the list of module types for which we don't want the forward hook to return a value. It is currently only used in DeepLift, but it can be used by any algorithm that attributes to a MaxPool layer and at the same time has a backward hook set on it.
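
As a minimal standalone sketch of the mechanism this argument controls (our own example, not captum's code): a forward hook that records a cloned activation and normally returns it, but skips the return for excluded module types such as MaxPool, so that their backward hooks keep seeing gradients of the expected shape.

```python
import torch
import torch.nn as nn

# Module types for which the forward hook must not return a replacement
# tensor; in the patch this role is played by list(SUPPORTED_NON_LINEAR.keys()).
excl_module_types = [nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d]

saved_layer = {}


def forward_hook(module, inp, out):
    # Always record a clone of the activation for later use.
    saved_layer[module] = out.clone()
    if type(module) in excl_module_types:
        return None  # do not replace the output of excluded modules
    return saved_layer[module]


pool = nn.MaxPool2d(2)
pool.register_forward_hook(forward_hook)
_ = pool(torch.randn(1, 1, 4, 4))  # hook fires; the pool's output is left untouched
```

In the patch itself, LayerDeepLift passes `forward_hook_with_return_excl_modules=list(SUPPORTED_NON_LINEAR.keys())` to `compute_layer_gradients_and_eval`, as shown in the layer_deep_lift.py hunk below.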

Added test cases for both the layer and the neuron attribution use cases.
Pull Request resolved: pytorch#390

Reviewed By: edward-io

Differential Revision: D22197030

Pulled By: NarineK

fbshipit-source-id: e6cf712103900190f46c5c1e9051519f3eaa933f
NarineK authored and facebook-github-bot committed Jun 24, 2020
1 parent 7d74ba5 commit be3a2ab
Showing 5 changed files with 122 additions and 4 deletions.
26 changes: 25 additions & 1 deletion captum/_utils/gradient.py
@@ -155,6 +155,9 @@ def _forward_layer_distributed_eval(
target_ind: TargetType = None,
additional_forward_args: Any = None,
attribute_to_layer_input: bool = False,
forward_hook_with_return_excl_modules: Union[
None, List[typing.Type[Module]]
] = None,
forward_hook_with_return: Literal[False] = False,
) -> Tuple[Dict[device, Tuple[Tensor, ...]], Literal[True, False]]:
...
@@ -168,6 +171,9 @@ def _forward_layer_distributed_eval(
target_ind: TargetType = None,
additional_forward_args: Any = None,
attribute_to_layer_input: bool = False,
forward_hook_with_return_excl_modules: Union[
None, List[typing.Type[Module]]
] = None,
*,
forward_hook_with_return: Literal[True],
) -> Tuple[Dict[device, Tuple[Tensor, ...]], Tensor, Literal[True, False]]:
@@ -181,6 +187,9 @@ def _forward_layer_distributed_eval(
target_ind: TargetType = None,
additional_forward_args: Any = None,
attribute_to_layer_input: bool = False,
forward_hook_with_return_excl_modules: Union[
None, List[typing.Type[Module]]
] = None,
forward_hook_with_return: bool = False,
) -> Union[
Tuple[Dict[device, Tuple[Tensor, ...]], Tensor, bool],
@@ -208,6 +217,7 @@ def forward_hook(module, inp, out=None):
nonlocal is_eval_tuple
eval_tsrs = inp if attribute_to_layer_input else out
is_eval_tuple = isinstance(eval_tsrs, tuple)

if not is_eval_tuple:
eval_tsrs = (eval_tsrs,)
with lock:
@@ -220,7 +230,11 @@ def forward_hook(module, inp, out=None):
eval_tsrs_to_return = tuple(eval_tsr.clone() for eval_tsr in eval_tsrs)
if not is_eval_tuple:
eval_tsrs_to_return = eval_tsrs_to_return[0]
return eval_tsrs_to_return
if (
forward_hook_with_return_excl_modules is None
or type(module) not in forward_hook_with_return_excl_modules
):
return eval_tsrs_to_return
else:
saved_layer[eval_tsrs[0].device] = tuple(
eval_tsr.clone() for eval_tsr in eval_tsrs
@@ -396,6 +410,9 @@ def compute_layer_gradients_and_eval(
device_ids: Union[None, List[int]] = None,
attribute_to_layer_input: bool = False,
output_fn: Union[None, Callable] = None,
forward_hook_with_return_excl_modules: Union[
None, List[typing.Type[Module]]
] = None,
) -> Tuple[
Tuple[Tensor, ...], Tuple[Tensor, ...], Tuple[Tensor, ...], Literal[True, False]
]:
@@ -413,6 +430,9 @@ def compute_layer_gradients_and_eval(
device_ids: Union[None, List[int]] = None,
attribute_to_layer_input: bool = False,
output_fn: Union[None, Callable] = None,
forward_hook_with_return_excl_modules: Union[
None, List[typing.Type[Module]]
] = None,
) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], Literal[True, False]]:
...

@@ -427,6 +447,9 @@ def compute_layer_gradients_and_eval(
device_ids: Union[None, List[int]] = None,
attribute_to_layer_input: bool = False,
output_fn: Union[None, Callable] = None,
forward_hook_with_return_excl_modules: Union[
None, List[typing.Type[Module]]
] = None,
) -> Union[
Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], bool],
Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], Tuple[Tensor, ...], bool],
@@ -489,6 +512,7 @@ def compute_layer_gradients_and_eval(
target_ind=target_ind,
additional_forward_args=additional_forward_args,
attribute_to_layer_input=attribute_to_layer_input,
forward_hook_with_return_excl_modules=forward_hook_with_return_excl_modules,
forward_hook_with_return=True,
)
assert output[0].numel() == 1, (
12 changes: 12 additions & 0 deletions captum/attr/_core/deep_lift.py
@@ -1036,6 +1036,18 @@ def maxpool(
list(cast(torch.Size, module.input.shape)),
),
)
if grad_input[0].shape != inputs.shape:
raise AssertionError(
"A problem occurred during maxpool modul's backward pass. "
"The gradients with respect to inputs include only a "
"subset of inputs. More details about this issue can "
"be found here: "
"https://pytorch.org/docs/stable/"
"nn.html#torch.nn.Module.register_backward_hook "
"This can happen for example if you attribute to the outputs of a "
"MaxPool. As a workaround, please, attribute to the inputs of "
"the following layer."
)

new_grad_inp = torch.where(
abs(delta_in) < eps, grad_input[0], unpool_grad_out_delta / delta_in
3 changes: 2 additions & 1 deletion captum/attr/_core/layer/layer_deep_lift.py
@@ -26,7 +26,7 @@
TargetType,
TensorOrTupleOfTensorsGeneric,
)
from ..._core.deep_lift import DeepLift, DeepLiftShap
from ..._core.deep_lift import SUPPORTED_NON_LINEAR, DeepLift, DeepLiftShap
from ..._utils.attribution import LayerAttribution
from ..._utils.common import (
_call_custom_attribution_func,
@@ -309,6 +309,7 @@ def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric,) -> Sequence:
inputs,
attribute_to_layer_input=attribute_to_layer_input,
output_fn=lambda out: chunk_output_fn(out),
forward_hook_with_return_excl_modules=list(SUPPORTED_NON_LINEAR.keys()),
)

attr_inputs = tuple(map(lambda attr: attr[0], attrs))
@@ -17,6 +17,8 @@
assertTensorTuplesAlmostEqual,
)
from ...helpers.basic_models import (
BasicModel_ConvNet,
BasicModel_ConvNet_MaxPool3d,
BasicModel_MultiLayer,
LinearMaxPoolLinearModel,
ReLULinearModel,
@@ -206,6 +208,35 @@ def test_lin_maxpool_lin_classification(self) -> None:
assertArraysAlmostEqual(cast(Tensor, attrs).detach().numpy(), expected)
assertArraysAlmostEqual(delta.detach().numpy(), expected_delta)

def test_convnet_maxpool2d_classification(self) -> None:
inputs = 100 * torch.randn(2, 1, 10, 10)

model = BasicModel_ConvNet()
model.eval()

dl = LayerDeepLift(model, model.pool1)
dl2 = LayerDeepLift(model, model.conv2)

attr = dl.attribute(inputs, target=0)
attr2 = dl2.attribute(inputs, target=0, attribute_to_layer_input=True)

self.assertTrue(cast(Tensor, attr).sum() == cast(Tuple, attr2)[0].sum())

def test_convnet_maxpool3d_classification(self) -> None:
inputs = 100 * torch.randn(2, 1, 10, 10, 10)

model = BasicModel_ConvNet_MaxPool3d()
model.eval()

dl = LayerDeepLift(model, model.pool1)
dl2 = LayerDeepLift(model, model.conv2)
# with self.assertRaises(AssertionError) doesn't run on Circle CI;
# the error is converted into a RuntimeError there

attr = dl.attribute(inputs, target=0, attribute_to_layer_input=False)
attr2 = dl2.attribute(inputs, target=0, attribute_to_layer_input=True)
self.assertTrue(cast(Tensor, attr).sum() == cast(Tuple, attr2)[0].sum())

def _relu_custom_attr_func_assert(
self,
attr_method: Union[LayerDeepLift, LayerDeepLiftShap],
@@ -11,8 +11,13 @@
from captum.attr._core.neuron.neuron_deep_lift import NeuronDeepLift, NeuronDeepLiftShap

from ...helpers.basic import BaseTest, assertTensorAlmostEqual
from ...helpers.basic_models import ReLULinearModel
from ..layer.test_layer_deeplift_basic import (
from ...helpers.basic_models import (
BasicModel_ConvNet,
BasicModel_ConvNet_MaxPool3d,
LinearMaxPoolLinearModel,
ReLULinearModel,
)
from ..layer.test_layer_deeplift import (
_create_inps_and_base_for_deeplift_neuron_layer_testing,
_create_inps_and_base_for_deepliftshap_neuron_layer_testing,
)
@@ -137,3 +142,48 @@ def custom_attr_func(
)
assertTensorAlmostEqual(self, attr[0], expected[0], 0.0)
assertTensorAlmostEqual(self, attr[1], expected[1], 0.0)

def test_lin_maxpool_lin_classification(self) -> None:
inputs = torch.ones(2, 4)
baselines = torch.tensor([[1, 2, 3, 9], [4, 8, 6, 7]]).float()

model = LinearMaxPoolLinearModel()
ndl = NeuronDeepLift(model, model.pool1)
attr = ndl.attribute(inputs, neuron_index=(0), baselines=baselines)

ndl2 = NeuronDeepLift(model, model.lin2)
attr2 = ndl2.attribute(
inputs,
neuron_index=(0),
baselines=baselines,
attribute_to_neuron_input=True,
)
assertTensorAlmostEqual(self, attr, attr2)

def test_convnet_maxpool2d_classification(self) -> None:
inputs = 100 * torch.randn(2, 1, 10, 10)
model = BasicModel_ConvNet()

ndl = NeuronDeepLift(model, model.pool1)
attr = ndl.attribute(inputs, neuron_index=(0, 0, 0))

ndl2 = NeuronDeepLift(model, model.conv2)
attr2 = ndl2.attribute(
inputs, neuron_index=(0, 0, 0), attribute_to_neuron_input=True
)

assertTensorAlmostEqual(self, attr.sum(), attr2.sum())

def test_convnet_maxpool3d_classification(self) -> None:
inputs = 100 * torch.randn(2, 1, 10, 10, 10)
model = BasicModel_ConvNet_MaxPool3d()

ndl = NeuronDeepLift(model, model.pool1)
attr = ndl.attribute(inputs, neuron_index=(0, 0, 0, 0))

ndl2 = NeuronDeepLift(model, model.conv2)
attr2 = ndl2.attribute(
inputs, neuron_index=(0, 0, 0, 0), attribute_to_neuron_input=True
)

assertTensorAlmostEqual(self, attr.sum(), attr2.sum())
