
Fix layer_gradient_x_activation and add logging for metrics #643

Closed

Conversation

@NarineK (Contributor) commented Mar 30, 2021

  • Adds logging for captum.metrics.
  • layer_gradient_x_activation currently fails for integer inputs. This adds a datatype check before calling apply_gradient_requirements; otherwise, if attribute is called inside a torch.no_grad() context, it fails (see the sketch below).
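A minimal sketch of that check, under assumptions: the wrapper name below is made up, and the import path of captum's internal apply_gradient_requirements helper has moved between captum versions.

# Assumption: import path for captum's internal helper; it differs across versions.
from captum._utils.gradient import apply_gradient_requirements

def _maybe_apply_gradient_requirements(inputs):
    # inputs: tuple of input tensors, as used throughout captum.attr.
    # Integer / long tensors cannot require grad, so skip the helper for them.
    if inputs[0].is_floating_point() or inputs[0].is_complex():
        return apply_gradient_requirements(inputs)
    return None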

@vivekmig (Contributor) left a comment

Thanks for adding this 👍 ! Just one question / comment.

@@ -170,7 +170,10 @@ def attribute(
additional_forward_args = _format_additional_forward_args(
additional_forward_args
)
gradient_mask = apply_gradient_requirements(inputs)

if inputs[0].is_floating_point() or inputs[0].is_complex():
Contributor

One question on this, it seems like apply_gradient_requirements already checks input types and avoids enabling grads for any int / long inputs, is there some case that's not covered by that check?

Also, alternatively, it might make sense to remove the gradient requirement here and instead set requires_grad only on the layer inputs / outputs within the forward hook in _forward_layer_distributed_eval, based on a flag, since we only need gradients starting from the target layer. This would avoid requiring gradients before the target layer, and should work within torch.no_grad() even if the original inputs are integers, as long as the target layer can require gradients. What do you think?
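
A hypothetical, self-contained sketch of this idea (not the actual captum hook; the module and names below are made up): require gradients on the target layer's output inside a forward hook, so autograd starts at that layer even when the model inputs are integer tensors and the parameters are frozen.

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(1000, 8)
        self.lin = nn.Linear(8, 1)

    def forward(self, idx):
        return self.lin(self.embed(idx)).sum()

mod = Toy()
for p in mod.parameters():
    p.requires_grad_(False)  # mimic the failing setup: frozen params, integer inputs

saved = {}

def require_grad_hook(module, inputs, output):
    # With frozen params and integer inputs, output has no grad_fn, so it is
    # effectively a leaf and requires_grad_ can be enabled on it in place.
    output.requires_grad_(True)
    saved["layer_out"] = output

handle = mod.embed.register_forward_hook(require_grad_hook)
idx = torch.tensor([1, 2, 3, 4])

with torch.enable_grad():  # grads re-enabled internally, even if the caller used no_grad
    out = mod(idx)
    grads = torch.autograd.grad(out, saved["layer_out"])[0]

handle.remove()
print(grads.shape)  # torch.Size([4, 8])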

@NarineK (Contributor Author) commented Apr 3, 2021

Ah, thank you. I thought we might have done that check already. apply_gradient_requirements does perform the check, but it occasionally doesn't catch this case. An external user pointed it out to me (the error later went away for him); I was able to reproduce it at the time, but I can't reproduce it right now. It might have been a different PyTorch version. Since apply_gradient_requirements already has the check, there is no need to repeat it here, but we'll keep an eye on this. It will likely reoccur, and we'll investigate it further then.

Contributor

Makes sense, sounds good! One issue I can imagine is a case like the example below, which currently fails since the gradient requirement isn't set on integer inputs. Is this similar to the issue you had in mind?

from captum.attr import LayerGradientXActivation
import torch
import torch.nn as nn

class TestModuleSingleEmbedding(nn.Module):
    def __init__(
        self, vocab_size: int = 1000, emb_dim: int = 128
    ):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim)

    def forward(self, idx):
        return torch.sum(self.embed(idx)).reshape(1,)

mod = TestModuleSingleEmbedding()
for param in mod.parameters():
    param.requires_grad = False

lga = LayerGradientXActivation(mod, mod.embed)
idx = torch.tensor([1, 2, 3, 4])
attr = lga.attribute(idx)

Setting the gradient requirement within the forward hook in _forward_layer_distributed_eval should likely resolve this.

@NarineK (Contributor Author) commented Apr 4, 2021

I think the error I saw was related to the input data type, but I've seen this one too. This one happens because the tensors do not carry gradients, so the resulting output has no grad_fn. I've tried a couple of things for that as well. Where do you recommend setting the gradient requirement? _forward_layer_distributed_eval is called in a gradient-enabled context. If the forward hook returns an output that requires gradients, that leads to a model output with gradients, but torch sees the hook's output as unused in the computation graph. Or we would need to iterate over mod.parameters() and set requires_grad to True if the forward function is a model (a sketch of that option follows below).

I think I saw this issue when using a torch.no_grad() context as well.
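
A minimal sketch of that second option (hypothetical helper name; assumes the forward function is an nn.Module):

import torch.nn as nn

def enable_param_grads(forward_func):
    # If the forward function is a module, re-enable gradients on all of its
    # parameters so the model output gets a grad_fn.
    if isinstance(forward_func, nn.Module):
        for param in forward_func.parameters():
            param.requires_grad_(True)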

Contributor

Makes sense! The fix I had in mind is this #647 . I think this should be sufficient for any layer method that needs grads with respect to a given layer input / output, and will also avoid enabling grads between the input and target layer if unnecessary. I don't think this should have any issues with unused tensors in the computation graph; I think that should only occur if the inputs provided to torch.autograd.grad were not used in the compute graph for the outputs.

Contributor Author

Thank you for the PR. I think the problem I was seeing is that the output has no grad_fn here:

saved_grads = torch.autograd.grad(torch.unbind(output), grad_inputs)

In #647 it sets the gradient requirement on the inputs w.r.t. which the gradients will be computed, right? But I was thinking that if the output has no grad_fn, that can still cause errors.

Contributor

Makes sense, yeah I think the error is generally something like this for the output:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

I think for any tensor operation, if any input requires gradients, then the output requires gradients and has a corresponding grad_fn. In this case, if the inputs are not floating-point, then the inputs never actually require gradients, so if either the context is no_grad or the parameters don't require gradients, the output has no grad function, causing this issue. By requiring grads on the target layer, gradients should be enabled on the output as well, resolving this issue.
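
A small illustration of that rule with plain tensors (nothing captum-specific):

import torch

a = torch.randn(3)                       # requires_grad=False
b = torch.randn(3, requires_grad=True)

out1 = (a * 2).sum()
print(out1.grad_fn)                      # None: no input requires grad
# torch.autograd.grad(out1, a)           # raises: element 0 of tensors does not
#                                        # require grad and does not have a grad_fn

out2 = (a * b).sum()
print(out2.grad_fn)                      # <SumBackward0 ...>: b requires grad
print(torch.autograd.grad(out2, b)[0])   # gradient w.r.t. b, equal to a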

@facebook-github-bot (Contributor)

@NarineK has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor)

@NarineK merged this pull request in e31bf38.
