Skip to content

Commit

Permalink
Fix pyre errors in Guided Backprop (pytorch#1395)
Browse files Browse the repository at this point in the history
Summary:

Initial work on fixing Pyre errors in Guided Backprop

Reviewed By: jjuncho

Differential Revision: D64677346
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Oct 22, 2024
1 parent 9c23ac0 commit eec1eb6
Showing 1 changed file with 7 additions and 18 deletions.
25 changes: 7 additions & 18 deletions captum/attr/_core/guided_backprop_deconvnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
) -> TensorOrTupleOfTensorsGeneric:
r"""
Computes attribution by overriding relu gradients. Based on constructor
Expand All @@ -58,16 +57,10 @@ def attribute(

# Keeps track whether original input is a tuple or not before
# converting it into a tuple.
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `TensorOrTupleOfTensorsGeneric`.
is_inputs_tuple = _is_tuple(inputs)

# pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
# `Tuple[Tensor, ...]`.
inputs = _format_tensor_into_tuples(inputs)
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
# `TensorOrTupleOfTensorsGeneric`.
gradient_mask = apply_gradient_requirements(inputs)
inputs_tuple = _format_tensor_into_tuples(inputs)
gradient_mask = apply_gradient_requirements(inputs_tuple)

# set hooks for overriding ReLU gradients
warnings.warn(
Expand All @@ -79,14 +72,12 @@ def attribute(
self.model.apply(self._register_hooks)

gradients = self.gradient_func(
self.forward_func, inputs, target, additional_forward_args
self.forward_func, inputs_tuple, target, additional_forward_args
)
finally:
self._remove_hooks()

# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
# `TensorOrTupleOfTensorsGeneric`.
undo_gradient_requirements(inputs, gradient_mask)
undo_gradient_requirements(inputs_tuple, gradient_mask)
# pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got
# `Tuple[Tensor, ...]`.
return _format_output(is_inputs_tuple, gradients)
Expand Down Expand Up @@ -155,8 +146,7 @@ def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
) -> TensorOrTupleOfTensorsGeneric:
r"""
Args:
Expand Down Expand Up @@ -265,8 +255,7 @@ def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
) -> TensorOrTupleOfTensorsGeneric:
r"""
Args:
Expand Down

0 comments on commit eec1eb6

Please sign in to comment.