LayerDeepLift fails when used on a MaxPooling layer? #382

Closed
Holt59 opened this issue May 14, 2020 · 8 comments

Comments


Holt59 commented May 14, 2020

I am trying to use LayerDeepLift on multiple layers of a VGG16 model from torchvision.models. It works for all layers except MaxPooling2D layers.

The following (layer 23 is a MaxPool2d layer):

import torchvision
import captum.attr

# torch_im: a preprocessed input image tensor, e.g. of shape (3, 224, 224)
model = torchvision.models.vgg16(pretrained=True)
u = captum.attr.LayerDeepLift(
    model, list(model.features.children())[23]).attribute(
        torch_im[None, ...], target=156)[0]

Raises the following:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-66-668b5d33db17> in <module>
----> 1 u = captum.attr.LayerDeepLift(model, list(model.features.children())[23]).attribute(torch_im[None, ...], target=156)[0]

i:\languages\python\envs\deel-torch\lib\site-packages\captum\attr\_core\layer\layer_deep_lift.py in attribute(self, inputs, baselines, target, additional_forward_args, return_convergence_delta, attribute_to_layer_input, custom_attribution_func)
    306             inputs,
    307             attribute_to_layer_input=attribute_to_layer_input,
--> 308             output_fn=lambda out: chunk_output_fn(out),
    309         )
    310

i:\languages\python\envs\deel-torch\lib\site-packages\captum\attr\_utils\gradient.py in compute_layer_gradients_and_eval(forward_fn, layer, inputs, target_ind, additional_forward_args, gradient_neuron_index, device_ids, attribute_to_layer_input, output_fn)
    517             for layer_tensor in saved_layer[device_id]
    518         )
--> 519         saved_grads = torch.autograd.grad(torch.unbind(output), grad_inputs)
    520         saved_grads = [
    521             saved_grads[i : i + num_tensors]

i:\languages\python\envs\deel-torch\lib\site-packages\torch\autograd\__init__.py in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused)
    155     return Variable._execution_engine.run_backward(
    156         outputs, grad_outputs, retain_graph, create_graph,
--> 157         inputs, allow_unused)
    158
    159

i:\languages\python\envs\deel-torch\lib\site-packages\captum\attr\_core\deep_lift.py in _backward_hook(self, module, grad_input, grad_output, eps)
    461         multipliers = tuple(
    462             SUPPORTED_NON_LINEAR[type(module)](
--> 463                 module, module.input, module.output, grad_input, grad_output, eps=eps
    464             )
    465         )

i:\languages\python\envs\deel-torch\lib\site-packages\captum\attr\_core\deep_lift.py in maxpool2d(module, inputs, outputs, grad_input, grad_output, eps)
    920         grad_input,
    921         grad_output,
--> 922         eps=eps,
    923     )
    924

i:\languages\python\envs\deel-torch\lib\site-packages\captum\attr\_core\deep_lift.py in maxpool(module, pool_func, unpool_func, inputs, outputs, grad_input, grad_output, eps)
   1002
   1003     new_grad_inp = torch.where(
-> 1004         abs(delta_in) < eps, grad_input[0], unpool_grad_out_delta / delta_in
   1005     )
   1006     # If the module is invalid, save the newly computed gradients

RuntimeError: The size of tensor a (28) must match the size of tensor b (14) at non-singleton dimension 3

It works on all layers except the MaxPool2d layers of vgg16.features (it works with the average pooling layer).

I am not sure whether this is a restriction of DeepLift or an error in the implementation.

Also, when the error occurs, the model seems to be left in a broken state: reusing it leads to IndexError: tuple index out of range, even with a brand-new captum.attr.LayerDeepLift instance.


NarineK commented May 14, 2020

Hi @Holt59, thank you for the question. That's interesting! I'll debug it.
In terms of the model state: those are most likely hooks that aren't getting removed. We recently fixed this so that the hooks are removed in all cases; for that you might need the GitHub version.
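For reference, the development version can be installed straight from GitHub with a standard pip-from-git command (not spelled out in this thread): pip install git+https://github.com/pytorch/captum.git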

@bilalsal

Hi @Holt59.
While @NarineK looks into the main issue, I thought I'd refer you to #370 regarding reusing the model after an error occurs.
The new error is likely caused by dangling hooks and is fixed in master, as explained in the comments there.
Hope this helps.
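As a stopgap, a minimal sketch (not from this thread) for clearing any hooks left dangling after a failed attribution run; it relies on private nn.Module attributes, so upgrading to a Captum build that removes its hooks on failure is the better fix:

# Clear any forward/backward hooks left behind on the model's modules.
for module in model.modules():
    module._forward_hooks.clear()
    module._forward_pre_hooks.clear()
    module._backward_hooks.clear()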


NarineK commented May 16, 2020

@Holt59, I've been debugging this issue. There seem to be some inconsistencies in the backward pass. In the meantime, as a workaround, attributing to the inputs of the MaxPool2d layer works; by default we attribute to the outputs of the layer.

model = torchvision.models.vgg16(pretrained=True)
u = captum.attr.LayerDeepLift(
    model, list(model.features.children())[23]).attribute(
        torch_im[None, ...], target=156, attribute_to_layer_input=True)[0]


NarineK commented May 16, 2020

Actually, as a workaround, what you were doing is equivalent to attributing to the inputs of the next layer, since the output of layer 23 is the input of layer 24:

u = captum.attr.LayerDeepLift(
    model, list(model.features.children())[24]).attribute(
        torch_im[None, ...], target=156, attribute_to_layer_input=True)[0]


NarineK commented May 24, 2020

@Holt59, did that workaround work for you?


Holt59 commented May 25, 2020

The workaround seems to work, but I cannot use it in my code base as is, since I am computing attributions for multiple layers and do not always know which layer follows a given MaxPool. That's not a big issue, though: I'm not particularly interested in the MaxPool layers and can simply leave them out.
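A minimal sketch of that approach (not from this thread): loop over the feature layers and skip MaxPool2d until the fix lands, assuming torch_im is the same preprocessed image tensor used in the snippets above.

import torch.nn as nn
import torchvision
import captum.attr

model = torchvision.models.vgg16(pretrained=True)

# Compute LayerDeepLift attributions for every feature layer except MaxPool2d.
attributions = {}
for idx, layer in enumerate(model.features.children()):
    if isinstance(layer, nn.MaxPool2d):
        continue  # skip until the MaxPool fix is available
    attributions[idx] = captum.attr.LayerDeepLift(model, layer).attribute(
        torch_im[None, ...], target=156)[0]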


NarineK commented May 28, 2020

@Holt59, PR #390 will fix the problem with MaxPool. To give more context: the problem happened because we return a cloned output tensor in the forward_hook, which makes the MaxPool modules "complex". Since there is a bug in PyTorch related to backward_hook on complex modules, namely that the returned input gradients represent only a subset of the inputs, the multipliers could not be computed correctly.

More details about the issue can be found here: https://pytorch.org/docs/stable/nn.html#torch.nn.Module.register_backward_hook

Another point I wanted to bring up: in VGG the modules might get reused (you might want to check that). We want to make sure that this isn't happening for the layer algorithms and DeepLift.
If the activations do get reused, you can simply redefine the architecture (that's easy to do). More info about it can be found here:

#378 (comment)
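To illustrate that last point, a minimal sketch, assuming the concern is activation modules being shared between positions or applied in-place: give every ReLU in the network its own fresh, non-inplace instance.

import torch.nn as nn
import torchvision

model = torchvision.models.vgg16(pretrained=True)

# Replace each ReLU with a distinct, non-inplace instance so no activation
# module is shared and no tensor is modified in place.
for container in (model.features, model.classifier):
    for i, module in enumerate(container):
        if isinstance(module, nn.ReLU):
            container[i] = nn.ReLU(inplace=False)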

facebook-github-bot pushed a commit that referenced this issue Jun 24, 2020
…n issue for MaxPool (#390)

Summary:
Related to issue #382: asserts that `grad_inputs` and `inputs` have the same shape. More detail about the workaround and about why the issue happens can be found in the description of the assert. The error occurs when we attribute to the outputs of the layer, because of the input or output tensor returned in the forward hook.

Added `forward_hook_with_return_excl_modules` that contains the list of modules for which we don't want to have a return in the forward_hook. This is only used in DeepLift and can be used for any algorithm that attributes to maxpool and at the same time has a backward hook set on it.

Added test cases for layer and neuron use cases.
Pull Request resolved: #390

Reviewed By: edward-io

Differential Revision: D22197030

Pulled By: NarineK

fbshipit-source-id: e6cf712103900190f46c5c1e9051519f3eaa933f
edward-io pushed a commit to edward-io/captum that referenced this issue Jun 30, 2020
…n issue for MaxPool (pytorch#390)

NarineK commented Jul 2, 2020

This was fixed by #390.

@NarineK NarineK closed this as completed Jul 2, 2020
NarineK added a commit to NarineK/captum-1 that referenced this issue Nov 19, 2020
…n issue for MaxPool (pytorch#390)