Switch from register_full_backward_hooks to tensor hooks #979

Closed
wants to merge 11 commits
Changes from 3 commits
56 changes: 44 additions & 12 deletions captum/_utils/common.py
@@ -660,20 +660,52 @@ def _get_module_from_name(model: Module, layer_name: str) -> Any:

def _register_backward_hook(
module: Module, hook: Callable, attr_obj: Any
) -> torch.utils.hooks.RemovableHandle:
) -> List[torch.utils.hooks.RemovableHandle]:
# Special case for supporting output attributions for neuron methods
# This can be removed after deprecation of neuron output attributions
# for NeuronDeepLift, NeuronDeconvolution, and NeuronGuidedBackprop
# in v0.6.0
if (
hasattr(attr_obj, "skip_new_hook_layer")
and attr_obj.skip_new_hook_layer == module
):
return module.register_backward_hook(hook)
grad_out = {}

if torch.__version__ >= "1.9":
# Only supported for torch >= 1.9
return module.register_full_backward_hook(hook)
else:
# Fallback for previous versions of PyTorch
return module.register_backward_hook(hook)
def forward_hook(
module: Module,
inp: Union[Tensor, Tuple[Tensor, ...]],
out: Union[Tensor, Tuple[Tensor, ...]],
) -> None:
nonlocal grad_out
grad_out = {}

def output_tensor_hook(output_grad: Tensor) -> None:
grad_out[output_grad.device] = output_grad

if isinstance(out, tuple):
assert (
len(out) == 1
), "Backward hooks not supported for module with >1 output"
out[0].register_hook(output_tensor_hook)
else:
out.register_hook(output_tensor_hook)

def pre_hook(module, inp):
def input_tensor_hook(input_grad: Tensor) -> Tensor:
if len(grad_out) == 0:
return
hook_out = hook(module, input_grad, grad_out[input_grad.device])

if hook_out is not None:
return hook_out[0] if isinstance(hook_out, tuple) else hook_out

if isinstance(inp, tuple):
assert (
len(inp) == 1
), "Backward hooks not supported for module with >1 input"
inp[0].register_hook(input_tensor_hook)
return inp[0].clone()
else:
inp.register_hook(input_tensor_hook)
return inp.clone()

return [
module.register_forward_pre_hook(pre_hook),
module.register_forward_hook(forward_hook),
]
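
The hunk above is the core of the PR: instead of module.register_full_backward_hook (or the deprecated register_backward_hook), a forward pre-hook and a forward hook attach Tensor.register_hook to the module's input and output tensors, and the input tensor's hook replays the user-supplied hook with the captured output gradient. Below is a minimal standalone sketch of the same pattern (not Captum's exact code; `register_tensor_backward_hook` is a hypothetical name):

```python
# Standalone sketch of the tensor-hook pattern introduced above (not Captum's
# exact implementation; register_tensor_backward_hook is a hypothetical name).
# A forward pre-hook and a forward hook attach Tensor hooks to the module's
# input and output, emulating a module backward hook.
import torch
import torch.nn as nn


def register_tensor_backward_hook(module: nn.Module, hook):
    grad_out = {}

    def forward_hook(module, inp, out):
        nonlocal grad_out
        grad_out = {}
        out = out[0] if isinstance(out, tuple) else out

        def output_tensor_hook(output_grad):
            # Capture the gradient flowing out of the module, keyed by device.
            grad_out[output_grad.device] = output_grad

        out.register_hook(output_tensor_hook)

    def pre_hook(module, inp):
        inp = inp[0] if isinstance(inp, tuple) else inp

        def input_tensor_hook(input_grad):
            if len(grad_out) == 0:
                return
            # The hook's return value overrides the gradient of the module input.
            return hook(module, input_grad, grad_out[input_grad.device])

        inp.register_hook(input_tensor_hook)
        # Run the module on a clone so the hook stays on the original tensor.
        return inp.clone()

    return [
        module.register_forward_pre_hook(pre_hook),
        module.register_forward_hook(forward_hook),
    ]


relu = nn.ReLU()
handles = register_tensor_backward_hook(relu, lambda m, gi, go: gi * 2.0)
x = torch.randn(4, requires_grad=True)
relu(x).sum().backward()
print(x.grad)  # ReLU gradient of the input, doubled by the hook
for h in handles:
    h.remove()
```

Note how the pre-hook returns inp.clone(), mirroring the diff above: the module then runs on the clone while the tensor hook stays attached to the original input tensor.
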
8 changes: 7 additions & 1 deletion captum/_utils/gradient.py
@@ -24,7 +24,7 @@


def apply_gradient_requirements(
inputs: Tuple[Tensor, ...], warn: bool = True
inputs: Tuple[Tensor, ...], warn: bool = True, skip_non_tensor: bool = False
) -> List[bool]:
"""
Iterates through tuple on input tensors and sets requires_grad to be true on
@@ -37,6 +37,10 @@ def apply_gradient_requirements(
), "Inputs should be wrapped in a tuple prior to preparing for gradients"
grad_required = []
for index, input in enumerate(inputs):
if skip_non_tensor and not isinstance(input, torch.Tensor):
grad_required.append(None)
continue

assert isinstance(input, torch.Tensor), "Given input is not a torch.Tensor"
grad_required.append(input.requires_grad)
inputs_dtype = input.dtype
@@ -813,6 +817,8 @@ def _compute_jacobian_wrt_params_with_sample_wise_trick(
parameters of the i-th layer, for the j-th member of the minibatch.
"""
with torch.autograd.set_grad_enabled(True):
inputs = tuple(inp.clone() for inp in inputs)
apply_gradient_requirements(inputs, skip_non_tensor=True)
sample_grad_wrapper = SampleGradientWrapper(model)
try:
sample_grad_wrapper.add_hooks()
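The new skip_non_tensor flag lets apply_gradient_requirements step over non-tensor entries (recording None in the returned list) instead of failing the assert, and _compute_jacobian_wrt_params_with_sample_wise_trick now clones the inputs before calling it so the callers' tensors keep their original requires_grad flags. A rough sketch of the flag's behavior (simplified; the real helper also emits warnings and special-cases non-floating-point dtypes):

```python
# Rough sketch of the new skip_non_tensor behavior (simplified; not the full
# captum helper, which also warns and checks complex / integer dtypes).
from typing import Any, List, Optional, Tuple

import torch


def apply_gradient_requirements_sketch(
    inputs: Tuple[Any, ...], skip_non_tensor: bool = False
) -> List[Optional[bool]]:
    grad_required: List[Optional[bool]] = []
    for inp in inputs:
        if skip_non_tensor and not isinstance(inp, torch.Tensor):
            # Keep indices aligned with `inputs` by recording a placeholder.
            grad_required.append(None)
            continue
        assert isinstance(inp, torch.Tensor), "Given input is not a torch.Tensor"
        grad_required.append(inp.requires_grad)
        if inp.is_floating_point() and not inp.requires_grad:
            inp.requires_grad_(True)
    return grad_required


mixed_inputs = (torch.rand(2, 3), "not-a-tensor", torch.arange(3))
print(apply_gradient_requirements_sketch(mixed_inputs, skip_non_tensor=True))
# -> [False, None, False]; only the floating-point tensor now requires grad
```
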
2 changes: 1 addition & 1 deletion captum/_utils/sample_gradient.py
@@ -117,7 +117,7 @@ def _register_module_hooks(self, module: torch.nn.Module):
self.forward_hooks.append(
module.register_forward_hook(self._forward_hook_fn)
)
self.backward_hooks.append(
self.backward_hooks.extend(
_register_backward_hook(module, self._backward_hook_fn, None)
)

146 changes: 24 additions & 122 deletions captum/attr/_core/deep_lift.py
@@ -43,34 +43,6 @@
from torch.utils.hooks import RemovableHandle


# Check if module backward hook can safely be used for the module that produced
# this inputs / outputs mapping
def _check_valid_module(inputs_grad_fn, outputs) -> bool:
def is_output_cloned(output_fn, input_grad_fn) -> bool:
"""
Checks if the output has been cloned. This happens especially in case of
layer deeplift.
"""
return (
output_fn[0].next_functions is not None
and output_fn[0].next_functions[0][0] == input_grad_fn
)

curr_fn = outputs.grad_fn
first_next = curr_fn.next_functions[0]
try:
# if `inputs` in the input to the network then the grad_fn is None and
# for that input backward_hook isn't computed. That's the reason why we
# need to check on `inputs_grad_fns[first_next[1]]` being None.
return (
inputs_grad_fn is None
or first_next[0] == inputs_grad_fn
or is_output_cloned(first_next, inputs_grad_fn)
)
except IndexError:
return False


class DeepLift(GradientAttribution):
r"""
Implements DeepLIFT algorithm based on the following paper:
@@ -112,10 +84,7 @@ def __init__(
r"""
Args:

model (nn.Module): The reference to PyTorch model instance. Model cannot
contain any in-place nonlinear submodules; these are not
supported by the register_full_backward_hook PyTorch API
starting from PyTorch v1.9.
model (nn.Module): The reference to PyTorch model instance.
multiply_by_inputs (bool, optional): Indicates whether to factor
model inputs' multiplier in the final attribution scores.
In the literature this is also known as local vs global
@@ -430,25 +399,6 @@ def _forward_pre_hook(
"""
inputs = _format_tensor_into_tuples(inputs)
module.input = inputs[0].clone().detach()
module.input_grad_fns = inputs[0].grad_fn # type: ignore

def tensor_backward_hook(grad):
if module.saved_grad is None:
raise RuntimeError(
"""Module {} was detected as not supporting correctly module
backward hook. You should modify your hook to ignore the given
grad_inputs (recompute them by hand if needed) and save the
newly computed grad_inputs in module.saved_grad. See MaxPool1d
as an example.""".format(
module
)
)
return module.saved_grad

# the hook is set by default but it will be used only for
# failure cases and will be removed otherwise
handle = inputs[0].register_hook(tensor_backward_hook)
module.input_hook = handle

def _forward_hook(
self,
@@ -462,30 +412,13 @@ def _forward_hook(
"""
outputs = _format_tensor_into_tuples(outputs)
module.output = outputs[0].clone().detach()
if not _check_valid_module(module.input_grad_fns, outputs[0]):
warnings.warn(
"""An invalid module {} is detected. Saved gradients will
be used as the gradients of the module's input tensor.
See MaxPool1d as an example.""".format(
module
)
)
module.is_invalid = True # type: ignore
module.saved_grad = None # type: ignore
self.forward_handles.append(cast(RemovableHandle, module.input_hook))
else:
module.is_invalid = False # type: ignore
# removing the hook if there is no failure case
cast(RemovableHandle, module.input_hook).remove()
del module.input_hook
del module.input_grad_fns

def _backward_hook(
self,
module: Module,
grad_input: Union[Tensor, Tuple[Tensor, ...]],
grad_output: Union[Tensor, Tuple[Tensor, ...]],
):
grad_input: Tensor,
grad_output: Tensor,
) -> Tensor:
r"""
`grad_input` is the gradient of the neuron with respect to its input
`grad_output` is the gradient of the neuron with respect to its output
@@ -506,15 +439,14 @@ def _backward_hook(
"Please, ensure that module is being used only once in the "
"network.".format(module)
)
multipliers = tuple(
SUPPORTED_NON_LINEAR[type(module)](
module,
module.input,
module.output,
grad_input,
grad_output,
eps=self.eps,
)

multipliers = SUPPORTED_NON_LINEAR[type(module)](
module,
module.input,
module.output,
grad_input,
grad_output,
eps=self.eps,
)
# remove all the properties that we set for the inputs and output
del module.input
@@ -545,10 +477,10 @@ def _register_hooks(
# adds forward hook to leaf nodes that are non-linear
forward_handle = module.register_forward_hook(self._forward_hook)
pre_forward_handle = module.register_forward_pre_hook(self._forward_pre_hook)
backward_handle = _register_backward_hook(module, self._backward_hook, self)
backward_handles = _register_backward_hook(module, self._backward_hook, self)
self.forward_handles.append(forward_handle)
self.forward_handles.append(pre_forward_handle)
self.backward_handles.append(backward_handle)
self.backward_handles.extend(backward_handles)

def _remove_hooks(self, extra_hooks_to_remove: List[RemovableHandle]) -> None:
for handle in extra_hooks_to_remove:
@@ -625,9 +557,7 @@ def __init__(self, model: Module, multiply_by_inputs: bool = True) -> None:
r"""
Args:

model (nn.Module): The reference to PyTorch model instance. Model cannot
contain any in-place nonlinear submodules; these are not
supported by the register_full_backward_hook PyTorch API.
model (nn.Module): The reference to PyTorch model instance.
multiply_by_inputs (bool, optional): Indicates whether to factor
model inputs' multiplier in the final attribution scores.
In the literature this is also known as local vs global
@@ -939,26 +869,18 @@ def nonlinear(
grad_input: Tensor,
grad_output: Tensor,
eps: float = 1e-10,
):
) -> Tensor:
r"""
grad_input: (dLoss / dprev_layer_out, dLoss / wij, dLoss / bij)
grad_output: (dLoss / dlayer_out)
https://github.com/pytorch/pytorch/issues/12331
"""
delta_in, delta_out = _compute_diffs(inputs, outputs)

new_grad_inp = list(grad_input)

# supported non-linear modules take only single tensor as input hence accessing
# only the first element in `grad_input` and `grad_output`
new_grad_inp[0] = torch.where(
abs(delta_in) < eps, new_grad_inp[0], grad_output[0] * delta_out / delta_in
new_grad_inp = torch.where(
abs(delta_in) < eps, grad_input, grad_output * delta_out / delta_in
)

# If the module is invalid, save the newly computed gradients
# The original_grad_input will be overridden later in the Tensor hook
if module.is_invalid:
module.saved_grad = new_grad_inp[0]
return new_grad_inp


Expand All @@ -972,15 +894,14 @@ def softmax(
):
delta_in, delta_out = _compute_diffs(inputs, outputs)

new_grad_inp = list(grad_input)
grad_input_unnorm = torch.where(
abs(delta_in) < eps, new_grad_inp[0], grad_output[0] * delta_out / delta_in
abs(delta_in) < eps, grad_input, grad_output * delta_out / delta_in
)
# normalizing
n = grad_input[0].numel()
n = grad_input.numel()

# updating only the first half
new_grad_inp[0] = grad_input_unnorm - grad_input_unnorm.sum() * 1 / n
new_grad_inp = grad_input_unnorm - grad_input_unnorm.sum() * 1 / n
return new_grad_inp


@@ -1071,7 +992,7 @@ def maxpool(
module.ceil_mode,
True,
)
grad_output_updated = grad_output[0]
grad_output_updated = grad_output
unpool_grad_out_delta, unpool_grad_out_ref_delta = torch.chunk(
unpool_func(
grad_output_updated * delta_out,
@@ -1087,20 +1008,7 @@
unpool_grad_out_delta = unpool_grad_out_delta + unpool_grad_out_ref_delta
unpool_grad_out_delta = torch.cat(2 * [unpool_grad_out_delta])

# If the module is invalid, we need to recompute the grad_input
if module.is_invalid:
original_grad_input = grad_input
grad_input = (
unpool_func(
grad_output_updated,
indices,
module.kernel_size,
module.stride,
module.padding,
list(cast(torch.Size, module.input.shape)),
),
)
if grad_input[0].shape != inputs.shape:
if grad_input.shape != inputs.shape:
raise AssertionError(
"A problem occurred during maxpool modul's backward pass. "
"The gradients with respect to inputs include only a "
@@ -1116,13 +1024,7 @@
new_grad_inp = torch.where(
abs(delta_in) < eps, grad_input[0], unpool_grad_out_delta / delta_in
)
# If the module is invalid, save the newly computed gradients
# The original_grad_input will be overridden later in the Tensor hook
if module.is_invalid:
module.saved_grad = new_grad_inp
return original_grad_input
else:
return (new_grad_inp,)
return new_grad_inp


def _compute_diffs(inputs: Tensor, outputs: Tensor) -> Tuple[Tensor, Tensor]:
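With the tensor hooks above, grad_input and grad_output arrive at these rules as single tensors rather than tuples, so nonlinear (and likewise softmax and maxpool) can apply DeepLIFT's rescale rule with a single torch.where: keep the ordinary gradient where the input barely differs from its reference, otherwise use grad_output * Δout / Δin. A small numeric illustration of that rule (illustrative values only, not Captum code; inputs stack the actual batch over the reference batch, as _compute_diffs expects):

```python
# Numeric illustration of the rescale rule applied by `nonlinear` above
# (illustrative values, not Captum code).
import torch

eps = 1e-10
inputs = torch.tensor([2.0, -1.0, -3.0, -3.0])   # actual batch stacked over reference batch
outputs = torch.relu(inputs)                     # the nonlinearity's outputs for both halves

x, x_ref = inputs.chunk(2)
y, y_ref = outputs.chunk(2)
delta_in = torch.cat(2 * [x - x_ref])            # as in _compute_diffs
delta_out = torch.cat(2 * [y - y_ref])

grad_output = torch.ones_like(inputs)            # dLoss/doutput (e.g. Loss = output.sum())
grad_input = (inputs > 0).float() * grad_output  # ordinary ReLU gradient

multipliers = torch.where(
    abs(delta_in) < eps, grad_input, grad_output * delta_out / delta_in
)
print(multipliers)  # tensor([0.4, 0.0, 0.4, 0.0]): (relu(2) - relu(-3)) / (2 - (-3)) = 0.4
```
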
12 changes: 4 additions & 8 deletions captum/attr/_core/guided_backprop_deconvnet.py
@@ -79,8 +79,8 @@ def attribute(

def _register_hooks(self, module: Module):
if isinstance(module, torch.nn.ReLU):
hook = _register_backward_hook(module, self._backward_hook, self)
self.backward_hooks.append(hook)
hooks = _register_backward_hook(module, self._backward_hook, self)
self.backward_hooks.extend(hooks)

def _backward_hook(
self,
@@ -121,9 +121,7 @@ def __init__(self, model: Module) -> None:
r"""
Args:

model (nn.Module): The reference to PyTorch model instance. Model cannot
contain any in-place ReLU submodules; these are not
supported by the register_full_backward_hook PyTorch API.
model (nn.Module): The reference to PyTorch model instance.
"""
ModifiedReluGradientAttribution.__init__(
self, model, use_relu_grad_output=False
@@ -234,9 +232,7 @@ def __init__(self, model: Module) -> None:
r"""
Args:

model (nn.Module): The reference to PyTorch model instance. Model cannot
contain any in-place ReLU submodules; these are not
supported by the register_full_backward_hook PyTorch API.
model (nn.Module): The reference to PyTorch model instance.
"""
ModifiedReluGradientAttribution.__init__(self, model, use_relu_grad_output=True)

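Since _register_backward_hook now returns two handles per module (forward pre-hook plus forward hook), call sites like the one above switch from append to extend and later remove every handle in the list. A hedged usage sketch of that pattern, assuming a captum build that includes this change; the ReLU rule is a simplified guided-backprop-style override, not Captum's exact _backward_hook:

```python
# Hedged usage sketch: register the tensor-based backward hooks on every ReLU
# and collect the returned handle lists with extend(). Assumes a captum build
# containing this PR; the hook below is a simplified guided-backprop-style rule.
import torch
import torch.nn as nn
import torch.nn.functional as F

from captum._utils.common import _register_backward_hook


def guided_relu_hook(module, grad_input, grad_output):
    # Zero out negative gradients flowing back through the ReLU input.
    return F.relu(grad_input)


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))
backward_hooks = []
for module in model.modules():
    if isinstance(module, nn.ReLU):
        backward_hooks.extend(_register_backward_hook(module, guided_relu_hook, None))

x = torch.randn(2, 4, requires_grad=True)
model(x).sum().backward()
print(x.grad)  # input gradients computed with the guided override applied at the ReLU

for handle in backward_hooks:
    handle.remove()
```
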
5 changes: 1 addition & 4 deletions captum/attr/_core/guided_grad_cam.py
@@ -51,10 +51,7 @@ def __init__(
r"""
Args:

model (nn.Module): The reference to PyTorch model instance. Model cannot
contain any in-place ReLU submodules; these are not
supported by the register_full_backward_hook PyTorch API
starting from PyTorch v1.9.
model (nn.Module): The reference to PyTorch model instance.
layer (torch.nn.Module): Layer for which GradCAM attributions are computed.
Currently, only layers with a single tensor output are
supported.