Undesirable behavior of LayerActivation in networks with inplace ReLUs #156

Closed
mrsalehi opened this issue Nov 3, 2019 · 2 comments
Labels: bug (Something isn't working), triaged

Comments

@mrsalehi

mrsalehi commented Nov 3, 2019

Hi,
I was trying to use captum.attr._core.layer_activation.LayerActivation to get the activation of the first convolutional layer in a simple model. Here is my code:

import numpy as np
import torch
import torch.nn as nn
from captum.attr._core.layer_activation import LayerActivation

torch.manual_seed(23)
np.random.seed(23)

# Two conv layers, each followed by an in-place ReLU
model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                      nn.ReLU(inplace=True),
                      nn.Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                      nn.ReLU(inplace=True))

layer_act = LayerActivation(model, model[0])
input = torch.randn(1, 3, 5, 5)
mylayer = model[0]

# Compare the first conv layer's output computed directly against the activation Captum returns
print(torch.norm(mylayer(input) - layer_act.attribute(input), p=2))

Here I compute the activation of the first conv layer in two different ways and compare them. I expected the printed value to be close to zero, but this is what I got:

tensor(3.4646, grad_fn=<NormBackward0>)

My hypothesis is that the in-place ReLU following the convolutional layer modifies its output, since the activation computed by Captum (i.e. layer_act.attribute(input)) contained many zeros. Indeed, when I changed the architecture of the network so that the first ReLU is no longer in-place:

model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                      nn.ReLU(),
                      nn.Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                      nn.ReLU(inplace=True))

then the outputs matched.
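
To illustrate the hypothesis outside of Captum, here is a minimal sketch (my own reproduction, not Captum internals) showing that a forward hook which stores a module's output without cloning it will see that tensor modified by a subsequent in-place ReLU:

import torch
import torch.nn as nn

torch.manual_seed(23)
conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
relu = nn.ReLU(inplace=True)

saved = {}

def hook(module, inp, out):
    # Store a reference to the conv output without cloning it
    saved["act"] = out

conv.register_forward_hook(hook)

x = torch.randn(1, 3, 5, 5)
relu(conv(x))  # the in-place ReLU overwrites the hooked tensor

# The saved "activation" has had its negative entries zeroed in place,
# so it no longer matches a fresh forward pass through the conv layer alone.
print((saved["act"] < 0).any())                 # tensor(False)
print(torch.norm(saved["act"] - conv(x), p=2))  # clearly non-zero

If LayerActivation captures the layer output with a hook like this, the later in-place ReLU would explain the zeros I am seeing.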

System information

  • Python 3.7.0
  • torch 1.3.0
  • Captum 0.1.0
@vivekmig vivekmig self-assigned this Nov 3, 2019
@vivekmig
Contributor

vivekmig commented Nov 3, 2019

Hi @mrsalehi, yes, this is a bug, thanks for pointing it out! We will push a fix for this soon.

@vivekmig vivekmig added the bug (Something isn't working) and triaged labels Nov 3, 2019
facebook-github-bot pushed a commit that referenced this issue Nov 11, 2019
Summary:
This PR fixes neuron / layer attributions with in-place operations by keeping appropriate clones of intermediate values to ensure that they are not modified by future operations.

Addresses Issue: #156
Pull Request resolved: #165

Differential Revision: D18435244

Pulled By: vivekmig

fbshipit-source-id: c658baded1f781710f5a363a8b3652fd3333ca20
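
For reference, a minimal sketch of the mechanism the summary describes (assuming hook-based capture; the actual implementation in the linked PR may differ): cloning the intermediate tensor inside the forward hook means later in-place operations can no longer change the saved value.

import torch
import torch.nn as nn

torch.manual_seed(23)
conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
relu = nn.ReLU(inplace=True)

saved = {}

def hook(module, inp, out):
    # Keep a clone so the subsequent in-place ReLU cannot modify the saved activation
    saved["act"] = out.clone()

conv.register_forward_hook(hook)

x = torch.randn(1, 3, 5, 5)
relu(conv(x))

print(torch.norm(saved["act"] - conv(x), p=2))  # ~0: the saved activation is intact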
@vivekmig
Contributor

Fix has been merged here: 5bf06ba
