
non-leaf tensor warning in NoiseTunnel usage #421

Closed
arnoldjulian opened this issue Jul 3, 2020 · 3 comments

Comments

@arnoldjulian

I get the following warning when using the saliency attribution method in combination with NoiseTunnel:

C:\Users\julia\Anaconda3\envs\torch\lib\site-packages\torch\tensor.py:746: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.
  warnings.warn("The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad "

A minimal working example (MWE):

import torch
import torch.nn as nn

from captum.attr import Saliency
from captum.attr import NoiseTunnel


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(5, 2)

    def forward(self, x):
        x = self.fc(x)
        return x


def sal(model, X, node_index):
    sal = Saliency(model)
    grads = sal.attribute(X, node_index).squeeze().detach().numpy()
    return grads


def sal_smooth(model, X, node_index):
    sal = Saliency(model)
    nt = NoiseTunnel(sal)
    grads = nt.attribute(X, target=node_index, nt_type='smoothgrad_sq',
                         n_samples=2, stdevs=0.2).squeeze().detach().numpy()
    return grads


net = Net()

X = torch.rand(1, 5)
X.requires_grad = True

attr = sal_smooth(net, X, 0)

When using the saliency attribution method as defined above without NoiseTunnel, I do not get any warning.

Based on the API reference for NoiseTunnel and Saliency, I would assume that requires_grad should be set to True for the input. And indeed, when removing the requires_grad statement, I get the following warning:

UserWarning: Input Tensor 0 did not already require gradients, required_grads has been set automatically.
  warnings.warn(

The same warning also appears when using the saliency attribution method without NoiseTunnel.
When following the same approach as in the MWE but using IntegratedGradients instead of Saliency, no warnings are thrown.

Is there a way to use the saliency attribution method in combination with NoiseTunnel without getting any warnings?


NarineK commented Jul 4, 2020

Thank you for bringing this up, @arnoldjulian!
This happens because we access .grad on a non-leaf tensor when we are trying to zero out the grads.
This reproduces it:

import torch

g = torch.rand(1, 5)
g.requires_grad = True
f = g + torch.tensor(5.0)  # f is a non-leaf tensor (result of an op on g)
f.grad  # reading .grad on a non-leaf tensor triggers the warning

We can have a fix for this.
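As a stopgap before the fix, one way to avoid the warning is to only touch .grad on leaf tensors (or to rely on torch.autograd.grad, which never populates .grad at all). A minimal sketch of that idea; zero_grad_safely is a hypothetical helper, not Captum's actual code:

import torch

def zero_grad_safely(*tensors):
    # Only reset .grad on leaf tensors; reading .grad on a non-leaf tensor
    # is what triggers the UserWarning.
    for t in tensors:
        if t.is_leaf and t.grad is not None:
            t.grad = None

g = torch.rand(1, 5, requires_grad=True)  # leaf tensor
f = g + torch.tensor(5.0)                 # non-leaf tensor
zero_grad_safely(g, f)  # no warning: f is skipped because f.is_leaf is False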

facebook-github-bot pushed a commit that referenced this issue Jul 13, 2020
Summary:
This will fix the warning specifically related to NoiseTunnel in #421.
In addition to that, I moved almost everything under no_grad in the attribute method. This will hopefully also help with runtime performance.
In `_forward_layer_eval` I had to add a `grad_enabled` flag in order to allow enabling the gradients externally, as it is also needed in the `test_neuron_gradient.py` test case.

Pull Request resolved: #426

Reviewed By: vivekmig

Differential Revision: D22500566

Pulled By: NarineK

fbshipit-source-id: d3170e1711012593ff421b964a02e54532a95b13
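The `grad_enabled` flag described in the summary follows a standard PyTorch pattern; a minimal sketch under that assumption (the function name and signature are illustrative, not Captum's actual `_forward_layer_eval`):

import torch

def forward_eval(forward_fn, inputs, grad_enabled=False):
    # Run the forward pass without building an autograd graph unless the
    # caller explicitly needs gradients (e.g. a test that inspects them).
    ctx = torch.enable_grad() if grad_enabled else torch.no_grad()
    with ctx:
        return forward_fn(inputs)

out = forward_eval(torch.nn.Linear(5, 2), torch.rand(1, 5))  # example call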

NarineK commented Jul 14, 2020

Fixed in #426

@NarineK NarineK closed this as completed Jul 14, 2020
NarineK added a commit to NarineK/captum-1 that referenced this issue Nov 19, 2020
…#426)

facebook-github-bot pushed a commit that referenced this issue Jan 27, 2021
Summary:
This removes the resetting of the grad attribute to zero, which is causing the warnings mentioned in #491 and #421. Based on the torch [documentation](https://pytorch.org/docs/stable/autograd.html#torch.autograd.grad), resetting grad is only needed when using torch.autograd.backward, which accumulates results into the grad attribute of leaf nodes. Since we only use torch.autograd.grad (with only_inputs always set to True), the gradients obtained in Captum are never accumulated into grad attributes, so resetting the attribute is not necessary.

This also adds a test to confirm that the grad attribute is not altered when gradients are utilized through Saliency.

Pull Request resolved: #597

Reviewed By: bilalsal

Differential Revision: D26079970

Pulled By: vivekmig

fbshipit-source-id: f7ccee02a17f66ee75e2176f1b328672b057dbfa
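The distinction this commit relies on is easy to verify in isolation; a small standalone sketch (not Captum code):

import torch

x = torch.rand(1, 5, requires_grad=True)
y = (x * 2).sum()

# torch.autograd.grad returns the gradients directly and does not accumulate
# anything into x.grad, so there is nothing to reset afterwards.
(g,) = torch.autograd.grad(y, x)
print(g)       # gradient tensor
print(x.grad)  # None: .grad was never populated

# torch.autograd.backward / .backward() does accumulate into x.grad.
(x * 2).sum().backward()
print(x.grad)  # now populated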

dbl001 commented Jul 27, 2022

I'm still getting this warning:

 % ipython
Python 3.9.12 (main, Jun  1 2022, 06:36:29) 
Type 'copyright', 'credits' or 'license' for more information
IPython 8.3.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import torch
   ...: 
   ...: g = torch.rand(1,5)
   ...: g.requires_grad = True
   ...: f = g + torch.tensor(5.0)
   ...: f.grad
<ipython-input-1-0c04367a67e5>:6: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at /Users/davidlaxer/pytorch/build/aten/src/ATen/core/TensorBody.h:483.)
  f.grad

In [2]: quit()

% python collect_env.py 
Collecting environment information...
PyTorch version: N/A
Is debug build: N/A
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: N/A

OS: macOS 12.5 (x86_64)
GCC version: Could not collect
Clang version: 13.1.6 (clang-1316.0.21.2.5)
CMake version: version 3.23.2
Libc version: N/A

Python version: 3.9.12 (main, Jun  1 2022, 06:36:29)  [Clang 12.0.0 ] (64-bit runtime)
Python platform: macOS-10.16-x86_64-i386-64bit
Is CUDA available: N/A
CUDA runtime version: Could not collect
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: N/A

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.22.3
[pip3] torch==1.13.0a0+gitdb0e121
[pip3] torchvision==0.14.0a0+e75a333
[conda] blas                      1.0                         mkl    anaconda
[conda] mkl                       2021.4.0           hecd8cb5_637    anaconda
[conda] mkl-include               2022.0.0           hecd8cb5_105    anaconda
[conda] mkl-service               2.4.0            py39h9ed2024_0    anaconda
[conda] mkl_fft                   1.3.1            py39h4ab4a9b_0    anaconda
[conda] mkl_random                1.2.2            py39hb2f4e1b_0    anaconda
[conda] numpy                     1.22.3           py39h2e5f0a9_0    anaconda
[conda] numpy-base                1.22.3           py39h3b1a694_0    anaconda
[conda] pytorch                   1.12.0                  py3.9_0    pytorch
[conda] torch                     1.13.0a0+git9506f9e          pypi_0    pypi
[conda] torchvision               0.14.0a0+e75a333          pypi_0    pypi
(AI-Feynman) davidlaxer@x86_64-apple-darwin13 pytorch % 
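Note that this standalone snippet warns independently of the Captum fix: f is a non-leaf tensor, and PyTorch warns whenever .grad is read on a non-leaf tensor that has not opted in via retain_grad(). A minimal sketch of the warning-free variant:

import torch

g = torch.rand(1, 5, requires_grad=True)
f = g + torch.tensor(5.0)

f.retain_grad()     # ask autograd to keep .grad for this non-leaf tensor
f.sum().backward()
print(f.grad)       # populated, and no non-leaf warning is raised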
