Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Leaf Warning Fix #597

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions captum/_utils/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,6 @@ def apply_gradient_requirements(inputs: Tuple[Tensor, ...]) -> List[bool]:
"required_grads has been set automatically." % index
)
input.requires_grad_()
if input.grad is not None:
if torch.sum(torch.abs(input.grad)).item() > 1e-7:
warnings.warn(
"Input Tensor %d had a non-zero gradient tensor, "
"which is being reset to 0." % index
)
input.grad.zero_()
return grad_required


Expand All @@ -84,9 +77,6 @@ def undo_gradient_requirements(
), "Input tuple length should match gradient mask."
for index, input in enumerate(inputs):
assert isinstance(input, torch.Tensor), "Given input is not a torch.Tensor"
if input.grad is not None:
input.grad.detach_()
input.grad.zero_()
if not grad_required[index]:
input.requires_grad_(False)

Expand Down
7 changes: 7 additions & 0 deletions tests/attr/test_saliency.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,13 @@ def test_saliency_classification_smoothgrad(self) -> None:
def test_saliency_classification_vargrad(self) -> None:
self._saliency_classification_assert(nt_type="vargrad")

def test_saliency_grad_unchanged(self) -> None:
model, inp, grads, add_args = _get_basic_config()
inp.grad = torch.randn_like(inp)
grad = inp.grad.detach().clone()
self._saliency_base_assert(model, inp, grads, add_args)
assertTensorTuplesAlmostEqual(self, inp.grad, grad, delta=0.0)

def _saliency_base_assert(
self,
model: Module,
Expand Down
17 changes: 7 additions & 10 deletions tests/utils/test_gradient.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
#!/usr/bin/env python3

from typing import List, Tuple, cast
from typing import List, Tuple

import torch
from torch import Tensor

from captum._utils.gradient import (
apply_gradient_requirements,
Expand Down Expand Up @@ -32,10 +31,6 @@ def test_apply_gradient_reqs(self) -> None:
for i in range(len(test_tensor_tuple)):
self.assertTrue(test_tensor_tuple[i].requires_grad)
self.assertEqual(out_mask[i], initial_grads[i])
if test_tensor_tuple[i].grad is not None:
self.assertAlmostEqual(
torch.sum(cast(Tensor, test_tensor_tuple[i].grad)).item(), 0.0
)

def test_undo_gradient_reqs(self) -> None:
initial_grads = [False, True, False]
Expand All @@ -49,22 +44,24 @@ def test_undo_gradient_reqs(self) -> None:
undo_gradient_requirements(test_tensor_tuple, initial_grads)
for i in range(len(test_tensor_tuple)):
self.assertEqual(test_tensor_tuple[i].requires_grad, initial_grads[i])
if test_tensor_tuple[i].grad is not None:
self.assertAlmostEqual(
torch.sum(cast(Tensor, test_tensor_tuple[i].grad)).item(), 0.0
)

def test_gradient_basic(self) -> None:
model = BasicModel()
input = torch.tensor([[5.0]], requires_grad=True)
input.grad = torch.tensor([[9.0]])
grads = compute_gradients(model, input)[0]
assertArraysAlmostEqual(grads.squeeze(0).tolist(), [0.0], delta=0.01)
# Verify grad attribute is not altered
assertArraysAlmostEqual(input.grad.squeeze(0).tolist(), [9.0], delta=0.0)

def test_gradient_basic_2(self) -> None:
model = BasicModel()
input = torch.tensor([[-3.0]], requires_grad=True)
input.grad = torch.tensor([[14.0]])
grads = compute_gradients(model, input)[0]
assertArraysAlmostEqual(grads.squeeze(0).tolist(), [1.0], delta=0.01)
# Verify grad attribute is not altered
assertArraysAlmostEqual(input.grad.squeeze(0).tolist(), [14.0], delta=0.0)

def test_gradient_multiinput(self) -> None:
model = BasicModel6_MultiTensor()
Expand Down