Fix pyre errors in NoiseTunnel (pytorch#1402)
Summary:

Initial work on fixing Pyre errors in NoiseTunnel.

Differential Revision: D64677341
Vivek Miglani authored and facebook-github-bot committed Oct 21, 2024
1 parent 56f466c commit 49697c1
Showing 1 changed file with 19 additions and 30 deletions.
49 changes: 19 additions & 30 deletions captum/attr/_core/noise_tunnel.py
@@ -2,7 +2,7 @@

# pyre-strict
from enum import Enum
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union

import torch
from captum._utils.common import (
@@ -27,8 +27,7 @@ class NoiseTunnelType(Enum):
vargrad = 3


# pyre-fixme[5]: Global expression must be annotated.
SUPPORTED_NOISE_TUNNEL_TYPES = list(NoiseTunnelType.__members__.keys())
SUPPORTED_NOISE_TUNNEL_TYPES: List[str] = list(NoiseTunnelType.__members__.keys())


class NoiseTunnel(Attribution):
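
Note: the annotation above works because `Enum.__members__` maps member names (strings) to members. A minimal standalone sketch of the pattern, using the enum from this file:

    from enum import Enum
    from typing import List

    class NoiseTunnelType(Enum):
        smoothgrad = 1
        smoothgrad_sq = 2
        vargrad = 3

    # list(NoiseTunnelType.__members__.keys()) is a List[str], so the
    # explicit annotation replaces the pyre-fixme[5] suppression on the
    # previously un-annotated global.
    SUPPORTED_NOISE_TUNNEL_TYPES: List[str] = list(NoiseTunnelType.__members__.keys())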
@@ -58,6 +57,10 @@ class NoiseTunnel(Attribution):
It is assumed that the batch size is the first dimension of input tensors.
"""

is_delta_supported: bool
_multiply_by_inputs: bool
is_gradient_method: bool

def __init__(self, attribution_method: Attribution) -> None:
r"""
Args:
@@ -66,19 +69,15 @@ def __init__(self, attribution_method: Attribution) -> None:
Conductance or Saliency.
"""
self.attribution_method = attribution_method
# pyre-fixme[4]: Attribute must be annotated.
self.is_delta_supported = self.attribution_method.has_convergence_delta()
# pyre-fixme[4]: Attribute must be annotated.
self._multiply_by_inputs = self.attribution_method.multiplies_by_inputs
# pyre-fixme[4]: Attribute must be annotated.
self.is_gradient_method = isinstance(
self.attribution_method, GradientAttribution
)
Attribution.__init__(self, self.attribution_method.forward_func)

@property
# pyre-fixme[3]: Return type must be annotated.
def multiplies_by_inputs(self):
def multiplies_by_inputs(self) -> bool:
return self._multiply_by_inputs

@log_usage()
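
The class-level declarations added above are the idiomatic fix for `pyre-fixme[4]` on attributes assigned in `__init__`: declare the attribute types once in the class body and leave the assignments bare. A minimal sketch of the pattern, with an illustrative class and constructor arguments that are not part of this diff:

    class Example:
        # Class-level declarations type the instance attributes, so the
        # assignments in __init__ need no inline annotations.
        is_delta_supported: bool
        _multiply_by_inputs: bool

        def __init__(self, supports_delta: bool, multiplies: bool) -> None:
            self.is_delta_supported = supports_delta
            self._multiply_by_inputs = multiplies

        @property
        def multiplies_by_inputs(self) -> bool:
            # Annotating the property's return type resolves pyre-fixme[3].
            return self._multiply_by_inputs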
@@ -205,9 +204,10 @@ def attribute(
nt_samples_batch_size, kwargs_copy, inputs, draw_baseline_from_distrib
)

sum_attributions: List[Union[None, Tensor]] = []
sum_attributions_sq: List[Union[None, Tensor]] = []
sum_attributions: Sequence[Union[None, Tensor]] = []
sum_attributions_sq: Sequence[Union[None, Tensor]] = []
delta_partial_list: List[Tensor] = []
is_attrib_tuple = is_inputs_tuple

for _ in range(nt_samples_partition):
inputs_with_noise = self._add_noise_to_inputs(
@@ -225,11 +225,7 @@
)

if len(sum_attributions) == 0:
# pyre-fixme[9]: sum_attributions has type
# `List[Optional[Tensor]]`; used as `List[None]`.
sum_attributions = [None] * len(attributions_partial)
# pyre-fixme[9]: sum_attributions_sq has type
# `List[Optional[Tensor]]`; used as `List[None]`.
sum_attributions_sq = [None] * len(attributions_partial)

self._update_partial_attribution_and_delta(
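
The `List` to `Sequence` change above is what lets the `[None] * len(attributions_partial)` reassignments type-check without the two `pyre-fixme[9]` suppressions: `List` is invariant, so a `List[None]` cannot be assigned where a `List[Optional[Tensor]]` is declared, while the read-only `Sequence` is covariant in its element type. A minimal sketch of the distinction (function names are illustrative):

    from typing import List, Optional, Sequence

    from torch import Tensor

    def rejected(n: int) -> None:
        xs: List[Optional[Tensor]] = []
        xs = [None] * n  # Pyre error: RHS is List[None], and List is invariant

    def accepted(n: int) -> None:
        xs: Sequence[Optional[Tensor]] = []
        xs = [None] * n  # fine: a List[None] is a Sequence[Optional[Tensor]]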
@@ -297,7 +293,6 @@ def attribute(

return self._apply_checks_and_return_attributions(
attributions,
# pyre-fixme[61]: `is_attrib_tuple` is undefined, or not always defined.
is_attrib_tuple,
return_convergence_delta,
delta,
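
Dropping the `pyre-fixme[61]` here is enabled by the `is_attrib_tuple = is_inputs_tuple` line added before the sampling loop: the name is now bound on every control-flow path rather than only inside the loop body, so Pyre no longer sees a path on which it is undefined. Schematically (illustrative functions, not from the diff):

    def before(cond: bool) -> bool:
        if cond:
            flag = True
        return flag  # pyre-fixme[61]: `flag` may be undefined

    def after(cond: bool) -> bool:
        flag = False  # bound unconditionally, so every path defines it
        if cond:
            flag = True
        return flag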
@@ -348,9 +343,7 @@ def _add_noise_to_input(
bsz = input.shape[0]

# expand input size by the number of drawn samples
# pyre-fixme[58]: `+` is not supported for operand types `Tuple[int]`
# and `Size`.
input_expanded_size = (bsz * nt_samples_partition,) + input.shape[1:]
input_expanded_size = (bsz * nt_samples_partition,) + tuple(input.shape[1:])

# expand stdev for the shape of the input and number of drawn samples
stdev_expanded = torch.tensor(stdev, device=input.device).repeat(
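
`torch.Size` is a subclass of `tuple`, so the original concatenation already worked at runtime; the `pyre-fixme[58]` was purely a typing complaint about `Tuple[int] + Size`. Wrapping the slice in `tuple(...)` gives the checker a plain `Tuple[int, ...]` operand. A runnable sketch (shapes are illustrative):

    import torch

    t = torch.zeros(8, 3, 32, 32)
    bsz = t.shape[0]

    # tuple(t.shape[1:]) converts the torch.Size slice to a plain tuple,
    # so the `+` type-checks as Tuple[int, ...] + Tuple[int, ...].
    expanded_size = (bsz * 4,) + tuple(t.shape[1:])
    assert expanded_size == (32, 3, 32, 32)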
@@ -375,14 +368,13 @@ def _update_sum_attribution_and_sq(
bsz = attribution.shape[0] // nt_samples_batch_size_inter
attribution_shape = cast(Tuple[int, ...], (bsz, nt_samples_batch_size_inter))
if len(attribution.shape) > 1:
# pyre-fixme[22]: The cast is redundant.
attribution_shape += cast(Tuple[int, ...], tuple(attribution.shape[1:]))
attribution_shape += tuple(attribution.shape[1:])

attribution = attribution.view(attribution_shape)
current_attribution_sum = attribution.sum(dim=1, keepdim=False)
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
current_attribution_sq = torch.sum(attribution**2, dim=1, keepdim=False)
current_attribution_sq = torch.sum(
torch.pow(attribution, 2), dim=1, keepdim=False
)

sum_attribution[i] = (
current_attribution_sum
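
`torch.pow(attribution, 2)` is runtime-equivalent to `attribution**2`; the explicit call is used here presumably because its stub signature is one Pyre accepts, which is why the `pyre-fixme[58]` on `Tensor ** int` can be dropped. A quick check of the equivalence (tensor values are illustrative):

    import torch

    x = torch.arange(6.0).view(2, 3)
    assert torch.equal(torch.pow(x, 2), x**2)

    # The same reduction as in the diff: square, then sum over dim 1.
    summed_sq = torch.sum(torch.pow(x, 2), dim=1, keepdim=False)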
@@ -398,8 +390,7 @@ def _update_sum_attribution_and_sq(
def _compute_partial_attribution(
self,
inputs_with_noise_partition: Tuple[Tensor, ...],
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
kwargs_partition: Any,
kwargs_partition: object,
is_inputs_tuple: bool,
return_convergence_delta: bool,
) -> Tuple[Tuple[Tensor, ...], bool, Union[None, Tensor]]:
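
Replacing `Any` with `object` satisfies `pyre-fixme[2]` (strict mode rejects explicit `Any` parameter annotations) while still accepting any argument; the difference is that an `object` value must be narrowed before type-specific use. A minimal sketch (function and branch are illustrative, not from the diff):

    def describe(value: object) -> str:
        # `object` accepts anything, like `Any`, but unlike `Any` the
        # body must narrow the type before type-specific operations.
        if isinstance(value, dict):
            return f"dict with {len(value)} keys"
        return repr(value)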
@@ -505,14 +496,12 @@ def _apply_checks_and_return_attributions(
) -> Union[
TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]
]:
# pyre-fixme[9]: Unable to unpack `Union[Tensor, typing.Tuple[Tensor,
# ...]]`, expected a tuple.
attributions = _format_output(is_attrib_tuple, attributions)
attributions_tuple = _format_output(is_attrib_tuple, attributions)

ret = (
(attributions, cast(Tensor, delta))
(attributions_tuple, cast(Tensor, delta))
if self.is_delta_supported and return_convergence_delta
else attributions
else attributions_tuple
)
ret = cast(
# pyre-fixme[34]: `Variable[TensorOrTupleOfTensorsGeneric <:

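The rename in the last hunk is the standard fix for the `pyre-fixme[9]` unpack error: `_format_output` returns `Union[Tensor, Tuple[Tensor, ...]]`, which cannot be assigned back into the more precisely typed `attributions` parameter, so the result is bound to a fresh name (`attributions_tuple`) and used consistently afterwards. A schematic sketch of the pattern, with a simplified stand-in for `_format_output`:

    from typing import Tuple, Union

    from torch import Tensor

    def format_output_sketch(
        is_tuple: bool, items: Tuple[Tensor, ...]
    ) -> Union[Tensor, Tuple[Tensor, ...]]:
        # Simplified stand-in for captum's _format_output.
        return items if is_tuple else items[0]

    def finalize(
        is_tuple: bool, attributions: Tuple[Tensor, ...]
    ) -> Union[Tensor, Tuple[Tensor, ...]]:
        # Binding a new name keeps the parameter's precise type intact
        # instead of forcing Pyre to unpack the Union back into it.
        formatted = format_output_sketch(is_tuple, attributions)
        return formatted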