
Unclear error message for RuntimeError in LRP #857

Open
katjahauser opened this issue Feb 10, 2022 · 1 comment
katjahauser commented Feb 10, 2022

Dear developers,

I encountered an error in LRP (and LayerLRP) that is caused by using the same layer (in this case a pooling layer) twice in the model. The error message is not very helpful for debugging, though: "Function ThnnConv2DBackward returned an invalid gradient at index 0 - got [1, 16, 15, 15] but expected shape compatible with [1, 32, 6, 6]".
When using DeepLift on the same model, a more helpful error message is provided: "A Module MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) was detected that does not contain some of the input/output attributes that are required for DeepLift computations. This can occur, for example, if your module is being used more than once in the network.Please, ensure that module is being used only once in the network."

In both cases, the error itself is fixed by introducing a separate pooling layer for each use. I would therefore kindly ask if you could adapt the LRP and LayerLRP error messages accordingly to make debugging easier.

Below is a minimal working example that reproduces both error messages.

Best wishes,
Katja Hauser


import torch.nn as nn
import torch.nn.functional as F
import torch
from captum.attr import LRP, DeepLift


class ImageClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3)
        self.conv2 = nn.Conv2d(16, 32, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.ff = nn.Linear(32*6*6, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)  # first use of self.pool
        x = F.relu(self.conv2(x))
        x = self.pool(x)  # second use of the same module triggers the error
        x = torch.flatten(x, 1)
        x = F.softmax(self.ff(x), dim=1)
        return x


if __name__ == "__main__":
    net = ImageClassifier()
    input = torch.randn(1, 1, 32, 32)
    try:
        lrp = LRP(net)
        attribution = lrp.attribute(input, target=5)
    except RuntimeError as e:
        print("LRP: ", e)
    try:
        dl = DeepLift(net)
        attribution = dl.attribute(input, target=5)
    except RuntimeError as e:
        print("DeepLift: ", e)
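
For reference, the workaround mentioned above can be sketched as follows: give each pooling step its own module instance instead of reusing one. The class name `FixedImageClassifier` and the attribute names `pool1`/`pool2` are illustrative choices, not part of the original report; with separate instances, `LRP(net).attribute(...)` no longer raises the shape error on this model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FixedImageClassifier(nn.Module):
    """Same architecture as above, but each pooling step gets its own module."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3)
        self.conv2 = nn.Conv2d(16, 32, 3)
        self.pool1 = nn.MaxPool2d(2, 2)  # separate instance for the first pooling
        self.pool2 = nn.MaxPool2d(2, 2)  # separate instance for the second pooling
        self.ff = nn.Linear(32 * 6 * 6, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)  # no module reuse here ...
        x = F.relu(self.conv2(x))
        x = self.pool2(x)  # ... or here
        x = torch.flatten(x, 1)
        x = F.softmax(self.ff(x), dim=1)
        return x


net = FixedImageClassifier()
out = net(torch.randn(1, 1, 32, 32))
print(out.shape)  # torch.Size([1, 10])
```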
@vivekmig
Copy link
Contributor

Thanks for reporting this issue @katjahauser! We have added a similar warning to LRP in #911.

facebook-github-bot pushed a commit that referenced this issue Apr 6, 2022
Summary:
This addresses #857 , adding an error message to LRP in the case where module reuse is detected.

Pull Request resolved: #911

Reviewed By: NarineK, aobo-y

Differential Revision: D35120798

Pulled By: vivekmig

fbshipit-source-id: f14eac8e084f2dc1d6ce07436c3779e8a9132a44