Hook Removal #340

Closed · wants to merge 32 commits

Changes from all commits (32 commits)
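Every hunk in this PR makes the same change: hook registration and the forward/backward passes move inside a try block, and hook removal moves into the matching finally, so handles are released even when attribution raises partway through. A minimal standalone sketch of that pattern, using hypothetical model, hook, and function names rather than Captum code:

# Hypothetical sketch of the cleanup pattern this PR applies (not Captum code):
# every RemovableHandle returned by a hook registration is collected, and the
# handles are removed in a `finally` block so a failing forward pass cannot
# leave stale hooks on the model.
import torch
import torch.nn as nn

def record_output_hook(module, inputs, output):
    # Placeholder hook; a real attribution hook would store activations here.
    return None

def run_with_hooks(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    handles = []
    try:
        # Register a forward hook on every leaf module.
        for module in model.modules():
            if len(list(module.children())) == 0:
                handles.append(module.register_forward_hook(record_output_hook))
        return model(x)
    finally:
        # Runs on success and on any exception raised above.
        for handle in handles:
            handle.remove()

model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
out = run_with_hooks(model, torch.randn(3, 8))
print(out.shape)  # torch.Size([3, 2])

Registration on torch.nn modules returns a RemovableHandle, so the cleanup pass only needs the handles collected at registration time.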
57 changes: 31 additions & 26 deletions captum/attr/_core/deep_lift.py
@@ -305,36 +305,39 @@ def attribute(  # type: ignore
         )

         baselines = _tensorize_baseline(inputs, baselines)
-        main_model_pre_hooks = self._hook_main_model()
-
-        self.model.apply(self._register_hooks)
-
-        additional_forward_args = _format_additional_forward_args(
-            additional_forward_args
-        )
-
-        expanded_target = _expand_target(
-            target, 2, expansion_type=ExpansionTypes.repeat
-        )
-
-        wrapped_forward_func = self._construct_forward_func(
-            self.model, (inputs, baselines), expanded_target, additional_forward_args
-        )
-        gradients = self.gradient_func(wrapped_forward_func, inputs)
-        if custom_attribution_func is None:
-            attributions = tuple(
-                (input - baseline) * gradient
-                for input, baseline, gradient in zip(inputs, baselines, gradients)
-            )
-        else:
-            attributions = _call_custom_attribution_func(
-                custom_attribution_func, gradients, inputs, baselines
-            )
-        # remove hooks from all activations
-        for hook in main_model_pre_hooks:
-            hook.remove()
-
-        self._remove_hooks()
+        main_model_hooks = []
+        try:
+            main_model_hooks = self._hook_main_model()
+
+            self.model.apply(self._register_hooks)
+
+            additional_forward_args = _format_additional_forward_args(
+                additional_forward_args
+            )
+
+            expanded_target = _expand_target(
+                target, 2, expansion_type=ExpansionTypes.repeat
+            )
+
+            wrapped_forward_func = self._construct_forward_func(
+                self.model,
+                (inputs, baselines),
+                expanded_target,
+                additional_forward_args,
+            )
+            gradients = self.gradient_func(wrapped_forward_func, inputs)
+            if custom_attribution_func is None:
+                attributions = tuple(
+                    (input - baseline) * gradient
+                    for input, baseline, gradient in zip(inputs, baselines, gradients)
+                )
+            else:
+                attributions = _call_custom_attribution_func(
+                    custom_attribution_func, gradients, inputs, baselines
+                )
+        finally:
+            # Even if any error is raised, remove all hooks before raising
+            self._remove_hooks(main_model_hooks)

         undo_gradient_requirements(inputs, gradient_mask)
         return _compute_conv_delta_and_format_attrs(
@@ -501,7 +504,9 @@ def _register_hooks(self, module: Module) -> None:
         self.forward_handles.append(pre_forward_handle)
         self.backward_handles.append(backward_handle)

-    def _remove_hooks(self) -> None:
+    def _remove_hooks(self, extra_hooks_to_remove: List[RemovableHandle]) -> None:
+        for handle in extra_hooks_to_remove:
+            handle.remove()
         for forward_handle in self.forward_handles:
             forward_handle.remove()
         for backward_handle in self.backward_handles:
15 changes: 7 additions & 8 deletions captum/attr/_core/guided_backprop_deconvnet.py
@@ -57,15 +57,14 @@ def attribute(
             "Setting backward hooks on ReLU activations."
             "The hooks will be removed after the attribution is finished"
         )
-
-        self.model.apply(self._register_hooks)
-
-        gradients = self.gradient_func(
-            self.forward_func, inputs, target, additional_forward_args
-        )
-
-        # remove set hooks
-        self._remove_hooks()
+        try:
+            self.model.apply(self._register_hooks)
+
+            gradients = self.gradient_func(
+                self.forward_func, inputs, target, additional_forward_args
+            )
+        finally:
+            self._remove_hooks()

         undo_gradient_requirements(inputs, gradient_mask)
         return _format_attributions(is_inputs_tuple, gradients)
84 changes: 44 additions & 40 deletions captum/attr/_core/layer/layer_deep_lift.py
@@ -276,52 +276,56 @@ def attribute(

         baselines = _tensorize_baseline(inputs, baselines)

-        main_model_hooks = self._hook_main_model()
-
-        self.model.apply(self._register_hooks)
-
-        additional_forward_args = _format_additional_forward_args(
-            additional_forward_args
-        )
-        expanded_target = _expand_target(
-            target, 2, expansion_type=ExpansionTypes.repeat
-        )
-        wrapped_forward_func = self._construct_forward_func(
-            self.model, (inputs, baselines), expanded_target, additional_forward_args,
-        )
-
-        def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric,) -> Sequence:
-            if isinstance(out, Tensor):
-                return out.chunk(2)
-            return tuple(out_sub.chunk(2) for out_sub in out)
-
-        (gradients, attrs, is_layer_tuple) = compute_layer_gradients_and_eval(
-            wrapped_forward_func,
-            self.layer,
-            inputs,
-            attribute_to_layer_input=attribute_to_layer_input,
-            output_fn=lambda out: chunk_output_fn(out),
-        )
-
-        attr_inputs = tuple(map(lambda attr: attr[0], attrs))
-        attr_baselines = tuple(map(lambda attr: attr[1], attrs))
-        gradients = tuple(map(lambda grad: grad[0], gradients))
-
-        if custom_attribution_func is None:
-            attributions = tuple(
-                (input - baseline) * gradient
-                for input, baseline, gradient in zip(
-                    attr_inputs, attr_baselines, gradients
-                )
-            )
-        else:
-            attributions = _call_custom_attribution_func(
-                custom_attribution_func, gradients, attr_inputs, attr_baselines
-            )
-        # remove hooks from all activations
-        self._remove_hooks()
-        for hook in main_model_hooks:
-            hook.remove()
+        main_model_hooks = []
+        try:
+            main_model_hooks = self._hook_main_model()
+
+            self.model.apply(self._register_hooks)
+
+            additional_forward_args = _format_additional_forward_args(
+                additional_forward_args
+            )
+            expanded_target = _expand_target(
+                target, 2, expansion_type=ExpansionTypes.repeat
+            )
+            wrapped_forward_func = self._construct_forward_func(
+                self.model,
+                (inputs, baselines),
+                expanded_target,
+                additional_forward_args,
+            )
+
+            def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric,) -> Sequence:
+                if isinstance(out, Tensor):
+                    return out.chunk(2)
+                return tuple(out_sub.chunk(2) for out_sub in out)
+
+            (gradients, attrs, is_layer_tuple) = compute_layer_gradients_and_eval(
+                wrapped_forward_func,
+                self.layer,
+                inputs,
+                attribute_to_layer_input=attribute_to_layer_input,
+                output_fn=lambda out: chunk_output_fn(out),
+            )
+
+            attr_inputs = tuple(map(lambda attr: attr[0], attrs))
+            attr_baselines = tuple(map(lambda attr: attr[1], attrs))
+            gradients = tuple(map(lambda grad: grad[0], gradients))
+
+            if custom_attribution_func is None:
+                attributions = tuple(
+                    (input - baseline) * gradient
+                    for input, baseline, gradient in zip(
+                        attr_inputs, attr_baselines, gradients
+                    )
+                )
+            else:
+                attributions = _call_custom_attribution_func(
+                    custom_attribution_func, gradients, attr_inputs, attr_baselines
+                )
+        finally:
+            # remove hooks from all activations
+            self._remove_hooks(main_model_hooks)

         undo_gradient_requirements(inputs, gradient_mask)
         return _compute_conv_delta_and_format_attrs(
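For context on the chunk_output_fn helper kept inside the try block above: the wrapped forward runs inputs and baselines through the model as one doubled batch (hence the target expanded by a factor of 2), and Tensor.chunk(2) splits each captured layer output back into its input half and baseline half. A tiny illustration with made-up values:

import torch

# Illustration of the Tensor.chunk(2) split that chunk_output_fn relies on:
# the doubled batch stacks the input batch on top of the baseline batch,
# and chunk(2) along dim 0 recovers the two halves.
doubled_output = torch.arange(12.0).reshape(6, 2)  # e.g. 3 inputs + 3 baselines
input_half, baseline_half = doubled_output.chunk(2)
print(input_half.shape, baseline_half.shape)  # torch.Size([3, 2]) torch.Size([3, 2])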
16 changes: 10 additions & 6 deletions captum/attr/_core/layer/layer_feature_ablation.py
@@ -246,12 +246,16 @@ def forward_hook(module, inp, out=None):
                     return all_layer_inputs[device][0]
                 return all_layer_inputs[device]

-            if attribute_to_layer_input:
-                hook = self.layer.register_forward_pre_hook(forward_hook)
-            else:
-                hook = self.layer.register_forward_hook(forward_hook)
-            eval = _run_forward(self.forward_func, original_inputs, target=target)
-            hook.remove()
+            hook = None
+            try:
+                if attribute_to_layer_input:
+                    hook = self.layer.register_forward_pre_hook(forward_hook)
+                else:
+                    hook = self.layer.register_forward_hook(forward_hook)
+                eval = _run_forward(self.forward_func, original_inputs, target=target)
+            finally:
+                if hook is not None:
+                    hook.remove()
             return eval

         with torch.no_grad():
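layer_feature_ablation.py uses the same guard as the other layer modules: a forward pre-hook when the caller asks for the layer's inputs (attribute_to_layer_input=True), a regular forward hook otherwise, and hook = None so the finally block only removes a handle that was actually created. A hedged sketch of the two registration variants on a generic layer (the hook names here are illustrative, not Captum code):

import torch
import torch.nn as nn

def capture_inputs(module, inputs):
    # Forward *pre*-hook: called before module.forward, receives only the inputs tuple.
    print("pre-hook saw input of shape", inputs[0].shape)

def capture_output(module, inputs, output):
    # Regular forward hook: called after module.forward, receives the output too.
    print("forward hook saw output of shape", output.shape)

layer = nn.Linear(4, 3)
attribute_to_layer_input = True  # illustrative flag, mirroring the diff above

hook = None
try:
    if attribute_to_layer_input:
        hook = layer.register_forward_pre_hook(capture_inputs)
    else:
        hook = layer.register_forward_hook(capture_output)
    layer(torch.randn(2, 4))
finally:
    if hook is not None:
        hook.remove()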
21 changes: 13 additions & 8 deletions captum/attr/_core/layer/layer_integrated_gradients.py
@@ -326,15 +326,20 @@ def layer_forward_hook(module, hook_inputs, hook_outputs=None):
                     return scattered_inputs_dict[device]
                 return scattered_inputs_dict[device][0]

-            if attribute_to_layer_input:
-                hook = self.layer.register_forward_pre_hook(layer_forward_hook)
-            else:
-                hook = self.layer.register_forward_hook(layer_forward_hook)
-
-            output = _run_forward(
-                self.forward_func, tuple(), target_ind, additional_forward_args
-            )
-            hook.remove()
+            hook = None
+            try:
+                if attribute_to_layer_input:
+                    hook = self.layer.register_forward_pre_hook(layer_forward_hook)
+                else:
+                    hook = self.layer.register_forward_hook(layer_forward_hook)
+
+                output = _run_forward(
+                    self.forward_func, tuple(), target_ind, additional_forward_args
+                )
+            finally:
+                if hook is not None:
+                    hook.remove()
+
             assert output[0].numel() == 1, (
                 "Target not provided when necessary, cannot"
                 " take gradient with respect to multiple outputs."
26 changes: 15 additions & 11 deletions captum/attr/_utils/gradient.py
@@ -227,17 +227,21 @@ def forward_hook(module, inp, out=None):
                 eval_tsr.clone() for eval_tsr in eval_tsrs
             )

-    if attribute_to_layer_input:
-        hook = layer.register_forward_pre_hook(forward_hook)
-    else:
-        hook = layer.register_forward_hook(forward_hook)
-    output = _run_forward(
-        forward_fn,
-        inputs,
-        target=target_ind,
-        additional_forward_args=additional_forward_args,
-    )
-    hook.remove()
+    hook = None
+    try:
+        if attribute_to_layer_input:
+            hook = layer.register_forward_pre_hook(forward_hook)
+        else:
+            hook = layer.register_forward_hook(forward_hook)
+        output = _run_forward(
+            forward_fn,
+            inputs,
+            target=target_ind,
+            additional_forward_args=additional_forward_args,
+        )
+    finally:
+        if hook is not None:
+            hook.remove()

     if len(saved_layer) == 0:
         raise AssertionError("Forward hook did not obtain any outputs for given layer")
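The gradient.py hunk follows the same shape as the rest of the PR. As a rough illustration of the failure mode the finally blocks guard against, the hypothetical script below shows that a hook left registered by an exception keeps firing on every later forward pass, while removal in finally restores the module:

import torch
import torch.nn as nn

# Hypothetical demonstration (not Captum code) of why cleanup belongs in
# `finally`: a hook that raises would otherwise stay attached and break
# every subsequent forward pass.
def failing_hook(module, inputs, output):
    raise RuntimeError("simulated failure inside a hook")

model = nn.Linear(4, 2)
x = torch.randn(1, 4)

hook = None
try:
    hook = model.register_forward_hook(failing_hook)
    model(x)  # raises RuntimeError from failing_hook
except RuntimeError:
    pass
finally:
    if hook is not None:
        hook.remove()  # guaranteed cleanup, mirroring the pattern in this PR

# With the hook removed, the model is usable again; without the `finally`,
# this call would also raise because the hook would still be registered.
print(model(x).shape)  # torch.Size([1, 2])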