LayerDeepLift Hook Fix (pytorch#415)
Summary:
This avoids adding a hook to the target layer's non-linearity when using LayerDeepLift and attributing with respect to the layer output. This fixes an issue with in-place modules that was caused by skipping cloning, and adds a corresponding test case.
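For illustration, a minimal sketch of the scenario this change targets, mirroring the new test case added below. The tensor values and the MaxPool-followed-by-in-place-ReLU layout come from this PR's test; the local MaxPoolReLU class is only a stand-in for the BasicModel_MaxPool_ReLU test helper, and the call pattern assumes the usual public LayerDeepLift API.

import torch
import torch.nn as nn
from captum.attr import LayerDeepLift

class MaxPoolReLU(nn.Module):
    # Stand-in for the BasicModel_MaxPool_ReLU helper added in this PR.
    def __init__(self, inplace=True):
        super().__init__()
        self.maxpool = nn.MaxPool1d(3)
        self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        return self.relu(self.maxpool(x)).sum(dim=1)

model = MaxPoolReLU(inplace=True)
inp = torch.tensor([[[1.0, 2.0, -4.0], [-3.0, -2.0, -1.0]]])

# Attribute with respect to the *output* of the maxpool layer. With this fix,
# no DeepLift hook is registered on the target layer itself, so the in-place
# ReLU downstream can no longer corrupt the saved layer activations.
dl = LayerDeepLift(model, model.maxpool)
attributions = dl.attribute(inp)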
Pull Request resolved: pytorch#415

Reviewed By: edward-io

Differential Revision: D22241851

Pulled By: vivekmig

fbshipit-source-id: 771a7cbb3cf77438bba901237defe937a26c415c
vivekmig authored and facebook-github-bot committed Jun 29, 2020
1 parent c88f0f2 commit dd82824
Showing 5 changed files with 38 additions and 30 deletions.
25 changes: 1 addition & 24 deletions captum/_utils/gradient.py
@@ -155,9 +155,6 @@ def _forward_layer_distributed_eval(
target_ind: TargetType = None,
additional_forward_args: Any = None,
attribute_to_layer_input: bool = False,
-    forward_hook_with_return_excl_modules: Union[
-        None, List[typing.Type[Module]]
-    ] = None,
forward_hook_with_return: Literal[False] = False,
) -> Tuple[Dict[device, Tuple[Tensor, ...]], Literal[True, False]]:
...
@@ -171,9 +168,6 @@ def _forward_layer_distributed_eval(
target_ind: TargetType = None,
additional_forward_args: Any = None,
attribute_to_layer_input: bool = False,
-    forward_hook_with_return_excl_modules: Union[
-        None, List[typing.Type[Module]]
-    ] = None,
*,
forward_hook_with_return: Literal[True],
) -> Tuple[Dict[device, Tuple[Tensor, ...]], Tensor, Literal[True, False]]:
@@ -187,9 +181,6 @@ def _forward_layer_distributed_eval(
target_ind: TargetType = None,
additional_forward_args: Any = None,
attribute_to_layer_input: bool = False,
-    forward_hook_with_return_excl_modules: Union[
-        None, List[typing.Type[Module]]
-    ] = None,
forward_hook_with_return: bool = False,
) -> Union[
Tuple[Dict[device, Tuple[Tensor, ...]], Tensor, bool],
@@ -230,11 +221,7 @@ def forward_hook(module, inp, out=None):
eval_tsrs_to_return = tuple(eval_tsr.clone() for eval_tsr in eval_tsrs)
if not is_eval_tuple:
eval_tsrs_to_return = eval_tsrs_to_return[0]
-            if (
-                forward_hook_with_return_excl_modules is None
-                or type(module) not in forward_hook_with_return_excl_modules
-            ):
-                return eval_tsrs_to_return
+            return eval_tsrs_to_return
else:
saved_layer[eval_tsrs[0].device] = tuple(
eval_tsr.clone() for eval_tsr in eval_tsrs
@@ -410,9 +397,6 @@ def compute_layer_gradients_and_eval(
device_ids: Union[None, List[int]] = None,
attribute_to_layer_input: bool = False,
output_fn: Union[None, Callable] = None,
-    forward_hook_with_return_excl_modules: Union[
-        None, List[typing.Type[Module]]
-    ] = None,
) -> Tuple[
Tuple[Tensor, ...], Tuple[Tensor, ...], Tuple[Tensor, ...], Literal[True, False]
]:
@@ -430,9 +414,6 @@ def compute_layer_gradients_and_eval(
device_ids: Union[None, List[int]] = None,
attribute_to_layer_input: bool = False,
output_fn: Union[None, Callable] = None,
-    forward_hook_with_return_excl_modules: Union[
-        None, List[typing.Type[Module]]
-    ] = None,
) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], Literal[True, False]]:
...

@@ -447,9 +428,6 @@ def compute_layer_gradients_and_eval(
device_ids: Union[None, List[int]] = None,
attribute_to_layer_input: bool = False,
output_fn: Union[None, Callable] = None,
-    forward_hook_with_return_excl_modules: Union[
-        None, List[typing.Type[Module]]
-    ] = None,
) -> Union[
Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], bool],
Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], Tuple[Tensor, ...], bool],
@@ -512,7 +490,6 @@ def compute_layer_gradients_and_eval(
target_ind=target_ind,
additional_forward_args=additional_forward_args,
attribute_to_layer_input=attribute_to_layer_input,
-            forward_hook_with_return_excl_modules=forward_hook_with_return_excl_modules,
forward_hook_with_return=True,
)
assert output[0].numel() == 1, (
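As background on why the hook above now returns the cloned tensors unconditionally (the removed exclusion list used to skip that for supported non-linearities): without the clone, downstream computation shares storage with the saved layer output, and an in-place module silently rewrites it. A small standalone illustration in plain PyTorch, not Captum code:

import torch
import torch.nn as nn

saved = {}

def forward_hook(module, inp, out):
    # Without .clone() the saved value aliases the layer's output storage;
    # the fixed Captum hook effectively always hands the clone downstream.
    saved["aliased"] = out
    saved["cloned"] = out.clone()

pool = nn.MaxPool1d(3)
relu = nn.ReLU(inplace=True)
pool.register_forward_hook(forward_hook)

x = torch.tensor([[[1.0, 2.0, -4.0], [-3.0, -2.0, -1.0]]])
relu(pool(x))               # the in-place ReLU rewrites the pooled tensor
print(saved["aliased"])     # tensor([[[2.], [0.]]]): the -1.0 was clobbered in place
print(saved["cloned"])      # tensor([[[2.], [-1.]]]): the clone kept the true value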
8 changes: 6 additions & 2 deletions captum/attr/_core/deep_lift.py
@@ -496,8 +496,12 @@ def _can_register_hook(self, module: Module) -> bool:
or not self._is_non_linear(module)
)

-    def _register_hooks(self, module: Module) -> None:
-        if not self._can_register_hook(module):
+    def _register_hooks(
+        self, module: Module, attribute_to_layer_input: bool = True
+    ) -> None:
+        if not self._can_register_hook(module) or (
+            not attribute_to_layer_input and module is self.layer  # type: ignore
+        ):
return
# adds forward hook to leaf nodes that are non-linear
forward_handle = module.register_forward_hook(self._forward_hook)
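For context, these hooks are installed by walking the whole model with Module.apply (see the layer_deep_lift.py change below), so a per-module guard is the natural place to exempt the target layer. A throwaway sketch of that traversal; the nn.Sequential model here is hypothetical and not from this PR:

import torch.nn as nn

model = nn.Sequential(nn.MaxPool1d(3), nn.ReLU(inplace=True))
target_layer = model[0]  # the layer attributions are computed for

visited = []
model.apply(lambda m: visited.append((type(m).__name__, m is target_layer)))
print(visited)
# [('MaxPool1d', True), ('ReLU', False), ('Sequential', False)]
# Exactly one visited module is the target layer; when attributing to its
# output, the new guard in _register_hooks skips it instead of hooking it.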
11 changes: 7 additions & 4 deletions captum/attr/_core/layer/layer_deep_lift.py
@@ -26,7 +26,7 @@
TargetType,
TensorOrTupleOfTensorsGeneric,
)
-from ..._core.deep_lift import SUPPORTED_NON_LINEAR, DeepLift, DeepLiftShap
+from ..._core.deep_lift import DeepLift, DeepLiftShap
from ..._utils.attribution import LayerAttribution
from ..._utils.common import (
_call_custom_attribution_func,
@@ -283,7 +283,11 @@ def attribute(
try:
main_model_hooks = self._hook_main_model()

-            self.model.apply(self._register_hooks)
+            self.model.apply(
+                lambda mod: self._register_hooks(
+                    mod, attribute_to_layer_input=attribute_to_layer_input
+                )
+            )

additional_forward_args = _format_additional_forward_args(
additional_forward_args
@@ -298,7 +302,7 @@
additional_forward_args,
)

-            def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric,) -> Sequence:
+            def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric) -> Sequence:
if isinstance(out, Tensor):
return out.chunk(2)
return tuple(out_sub.chunk(2) for out_sub in out)
@@ -309,7 +313,6 @@ def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric,) -> Sequence:
inputs,
attribute_to_layer_input=attribute_to_layer_input,
output_fn=lambda out: chunk_output_fn(out),
-                forward_hook_with_return_excl_modules=list(SUPPORTED_NON_LINEAR.keys()),
)

attr_inputs = tuple(map(lambda attr: attr[0], attrs))
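A side note on chunk_output_fn above: DeepLift runs inputs and baselines through the model as a single doubled batch, and .chunk(2) splits each layer tensor back into its input half and baseline half. A tiny standalone illustration with made-up values:

import torch

# Layer activations for a doubled batch: rows 0-1 from inputs, rows 2-3 from baselines.
doubled = torch.arange(12.0).reshape(4, 3)
input_acts, baseline_acts = doubled.chunk(2)
print(input_acts.shape, baseline_acts.shape)  # torch.Size([2, 3]) torch.Size([2, 3])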
14 changes: 14 additions & 0 deletions tests/attr/layer/test_layer_deeplift.py
@@ -19,6 +19,7 @@
from ...helpers.basic_models import (
BasicModel_ConvNet,
BasicModel_ConvNet_MaxPool3d,
+    BasicModel_MaxPool_ReLU,
BasicModel_MultiLayer,
LinearMaxPoolLinearModel,
ReLULinearModel,
@@ -92,6 +93,19 @@ def test_relu_deeplift_with_custom_attr_func(self) -> None:
attr_method = LayerDeepLift(model, model.l3)
self._relu_custom_attr_func_assert(attr_method, inputs, baselines, [[2.0]])

+    def test_inplace_maxpool_relu_with_custom_attr_func(self) -> None:
+        model = BasicModel_MaxPool_ReLU(inplace=True)
+        inp = torch.tensor([[[1.0, 2.0, -4.0], [-3.0, -2.0, -1.0]]])
+        dl = LayerDeepLift(model, model.maxpool)
+
+        def custom_att_func(mult, inp, baseline):
+            assertTensorAlmostEqual(self, mult[0], [[1.0, 0.0]])
+            assertTensorAlmostEqual(self, inp[0], [[2.0, -1.0]])
+            assertTensorAlmostEqual(self, baseline[0], [[0.0, 0.0]])
+            return mult
+
+        dl.attribute(inp, custom_attribution_func=custom_att_func)
+
def test_linear_layer_deeplift_batch(self) -> None:
model = ReLULinearModel(inplace=True)
_, baselines = _create_inps_and_base_for_deeplift_neuron_layer_testing()
10 changes: 10 additions & 0 deletions tests/helpers/basic_models.py
@@ -164,6 +164,16 @@ def forward(self, inputs, sparse_list):
).sum()


+class BasicModel_MaxPool_ReLU(nn.Module):
+    def __init__(self, inplace=False):
+        super().__init__()
+        self.maxpool = nn.MaxPool1d(3)
+        self.relu = nn.ReLU(inplace=inplace)
+
+    def forward(self, x):
+        return self.relu(self.maxpool(x)).sum(dim=1)
+
+
class TanhDeepLiftModel(nn.Module):
r"""
Same as the ReLUDeepLiftModel, but with activations
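A quick sanity check, not part of the PR, of what the new helper computes on the test input used above; the shapes and values follow directly from the module definition:

import torch
import torch.nn as nn

maxpool = nn.MaxPool1d(3)
relu = nn.ReLU(inplace=True)

x = torch.tensor([[[1.0, 2.0, -4.0], [-3.0, -2.0, -1.0]]])  # shape (1, 2, 3)
pooled = maxpool(x)            # shape (1, 2, 1), values [[[2.0], [-1.0]]]
out = relu(pooled).sum(dim=1)  # shape (1, 1), value [[2.0]]
print(out)
# Only the 2.0 element survives the ReLU, which matches the expected
# multipliers [[1.0, 0.0]] in the new LayerDeepLift test.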
