non-leaf tensor warning in some attribution algorithms #491

Closed
hal-314 opened this issue Oct 13, 2020 · 1 comment

hal-314 commented Oct 13, 2020

🐛 Bug

I get the following warning when using the Saliency or InputXGradient attribution methods, but not with IntegratedGradients or GradientShap:

UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.
  warnings.warn("The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad "

To Reproduce

Execute this:

import warnings
import torch
import torch.nn as nn

from captum.attr import Saliency, IntegratedGradients, InputXGradient
from captum.attr import configure_interpretable_embedding_layer

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.embedding = nn.Embedding(10,3)
        self.fc = nn.Linear(5+3, 2)

    def forward(self, x_cat, x):
        x_cat = self.embedding(x_cat)
        x = torch.cat([x, x_cat], dim=1)
        x = self.fc(x)
        return x

def sal(model, X, node_index):
    sal = Saliency(model)
    #sal = IntegratedGradients(model)
    #sal = InputXGradient(model)
    grads = sal.attribute(X, target=node_index)
    return grads


net = Net()

X_cont = torch.rand(1, 5)
X_cat = torch.randint(0,9, (1,))

X_cont.requires_grad = True

net(X_cat, X_cont)

#with torch.no_grad(): # <- Uncomment to remove warnings. 
X_cat_emb = net.embedding(X_cat)

X_cat_emb.requires_grad_()

# Wrapping the embedding layer emits its own UserWarning, hence the filter.
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    emb_net = configure_interpretable_embedding_layer(net)

# It's the same if we compute embeddings from InterpretableEmbeddingBase like Bert tutorial
#X_cat_emb2 = emb_net.indices_to_embeddings(X_cat)
#X_cat_emb2.requires_grad_()


attr = sal(net, (X_cat_emb, X_cont), 0)
#attr = sal(net, (X_cat_emb2, X_cont), 0)
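
As the commented-out torch.no_grad() line above hints, the warning goes away if the embeddings are computed without tracking history, so that X_cat_emb is a leaf tensor. A minimal sketch of that variant (same names as in the repro; an assumed workaround, not a Captum-documented one):

with torch.no_grad():                 # compute embeddings with history disabled,
    X_cat_emb = net.embedding(X_cat)  # before configure_interpretable_embedding_layer
X_cat_emb.requires_grad_()            # X_cat_emb is now a leaf that requires grad

# rest of the repro unchanged
attr = sal(net, (X_cat_emb, X_cont), 0)  # no non-leaf .grad warning expected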

Expected behavior

All algorithms should behave consistently. In my opinion, Captum should either not raise these warnings, or correct the tutorials that use models with embeddings and add a note to the API docs.

Environment

PyTorch 1.6.0, Captum 0.2, Ubuntu 20.04

Additional context

I tested with the Saliency, InputXGradient, IntegratedGradients, and GradientShap gradient-based methods.
Finally, I think this bug is similar to #421.

facebook-github-bot pushed a commit that referenced this issue Jan 27, 2021
Summary:
This removes the resetting of the grad attribute to zero, which is causing warnings as mentioned in #491 and #421. Based on the torch [documentation](https://pytorch.org/docs/stable/autograd.html#torch.autograd.grad), resetting grad is only needed when using torch.autograd.backward, which accumulates results into the grad attribute for leaf nodes. Since we only utilize torch.autograd.grad (with only_inputs always set to True), the gradients obtained in Captum are never accumulated into grad attributes, so resetting the attribute is not necessary.

This also adds a test to confirm that the grad attribute is not altered when gradients are utilized through Saliency.

Pull Request resolved: #597

Reviewed By: bilalsal

Differential Revision: D26079970

Pulled By: vivekmig

fbshipit-source-id: f7ccee02a17f66ee75e2176f1b328672b057dbfa
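
To make the distinction in the commit summary concrete, here is a minimal sketch (plain PyTorch, no Captum; an assumed illustration, not code from the PR): torch.autograd.grad returns the gradients directly and leaves .grad untouched, whereas torch.autograd.backward accumulates into .grad on leaf tensors.

import torch

w = torch.rand(3, requires_grad=True)  # leaf tensor
loss = (w * 2).sum()

(g,) = torch.autograd.grad(loss, w)    # gradients returned directly
print(g)       # tensor([2., 2., 2.])
print(w.grad)  # still None: autograd.grad did not touch .grad

(w * 2).sum().backward()               # backward() accumulates into .grad instead
print(w.grad)  # tensor([2., 2., 2.])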

NarineK commented Jan 28, 2021

Fixed in the PR: #597

NarineK closed this as completed Jan 28, 2021