From 1f2a97b42fff63dc5d43b5d2d4563d2661595ce9 Mon Sep 17 00:00:00 2001
From: Vivek Miglani
Date: Mon, 21 Oct 2024 20:30:55 -0700
Subject: [PATCH] Fix pyre errors in GradientShap (#1394)

Summary:
Initial work on fixing Pyre errors in GradientSHAP:

- import `Literal` from `typing` instead of `captum._utils.typing`, so the
  `Literal[True]` / `Literal[False]` overloads of `attribute` type-check
  without suppressions
- annotate `forward_func` as `Callable[..., Tensor]` and
  `additional_forward_args` as `object` rather than `Any`
- bind the results of `_format_callable_baseline` and `_format_input_baseline`
  to new names (`formatted_baselines`, `inputs_tuple`) instead of rebinding
  the parameters, which Pyre flags as a type reassignment
- annotate the `_multiply_by_inputs` attribute and drop the `pyre-fixme`
  comments these changes make unnecessary

Two short sketches of the overload and renaming patterns follow the diff.

Differential Revision: D64677343
---
 captum/attr/_core/gradient_shap.py | 82 +++++++++---------
 1 file changed, 25 insertions(+), 57 deletions(-)

diff --git a/captum/attr/_core/gradient_shap.py b/captum/attr/_core/gradient_shap.py
index c179633ae..e7a67140f 100644
--- a/captum/attr/_core/gradient_shap.py
+++ b/captum/attr/_core/gradient_shap.py
@@ -2,14 +2,13 @@
 # pyre-strict
 
 import typing
-from typing import Any, Callable, Tuple, Union
+from typing import Any, Callable, Literal, Tuple, Union
 
 import numpy as np
 import torch
 from captum._utils.common import _is_tuple
 from captum._utils.typing import (
     BaselineType,
-    Literal,
     TargetType,
     Tensor,
     TensorOrTupleOfTensorsGeneric,
@@ -57,8 +56,9 @@ class GradientShap(GradientAttribution):
     samples and compute the expectation (smoothgrad).
     """
 
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    def __init__(self, forward_func: Callable, multiply_by_inputs: bool = True) -> None:
+    def __init__(
+        self, forward_func: Callable[..., Tensor], multiply_by_inputs: bool = True
+    ) -> None:
         r"""
         Args:
 
@@ -82,8 +82,6 @@ def __init__(self, forward_func: Callable, multiply_by_inputs: bool = True) -> N
         self._multiply_by_inputs = multiply_by_inputs
 
     @typing.overload
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    #  arguments of overload defined on line `84`.
     def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
@@ -93,17 +91,12 @@ def attribute(
         n_samples: int = 5,
         stdevs: Union[float, Tuple[float, ...]] = 0.0,
         target: TargetType = None,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_forward_args: Any = None,
+        additional_forward_args: object = None,
         *,
-        # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[True],
     ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...
 
     @typing.overload
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    #  arguments of overload defined on line `99`.
     def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
@@ -113,10 +106,7 @@ def attribute(
         n_samples: int = 5,
         stdevs: Union[float, Tuple[float, ...]] = 0.0,
         target: TargetType = None,
-        additional_forward_args: Any = None,
-        # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
-        # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
+        additional_forward_args: object = None,
         return_convergence_delta: Literal[False] = False,
     ) -> TensorOrTupleOfTensorsGeneric: ...
 
@@ -132,7 +122,7 @@ def attribute(
         n_samples: int = 5,
         stdevs: Union[float, Tuple[float, ...]] = 0.0,
         target: TargetType = None,
-        additional_forward_args: Any = None,
+        additional_forward_args: object = None,
         return_convergence_delta: bool = False,
     ) -> Union[
         TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]
@@ -265,20 +255,10 @@ def attribute(
         """
         # since `baselines` is a distribution, we can generate it using a function
         # rather than passing it as an input argument
-        # pyre-fixme[9]: baselines has type `Union[typing.Callable[...,
-        #  Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, typing.Tuple[Tensor,
-        #  ...]]]], Variable[TensorOrTupleOfTensorsGeneric <: [Tensor,
-        #  typing.Tuple[Tensor, ...]]]]`; used as `Tuple[Tensor, ...]`.
-        baselines = _format_callable_baseline(baselines, inputs)
-        # pyre-fixme[16]: Item `Callable` of `Union[(...) ->
-        #  TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]` has no
-        #  attribute `__getitem__`.
-        assert isinstance(baselines[0], torch.Tensor), (
+        formatted_baselines = _format_callable_baseline(baselines, inputs)
+        assert isinstance(formatted_baselines[0], torch.Tensor), (
             "Baselines distribution has to be provided in a form "
-            # pyre-fixme[16]: Item `Callable` of `Union[(...) ->
-            #  TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]` has no
-            #  attribute `__getitem__`.
-            "of a torch.Tensor {}.".format(baselines[0])
+            "of a torch.Tensor {}.".format(formatted_baselines[0])
         )
 
         input_min_baseline_x_grad = InputBaselineXGradient(
@@ -296,7 +276,7 @@ def attribute(
             nt_samples=n_samples,
             stdevs=stdevs,
             draw_baseline_from_distrib=True,
-            baselines=baselines,
+            baselines=formatted_baselines,
             target=target,
             additional_forward_args=additional_forward_args,
             return_convergence_delta=return_convergence_delta,
@@ -322,8 +302,11 @@ def multiplies_by_inputs(self) -> bool:
 
 
 class InputBaselineXGradient(GradientAttribution):
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    def __init__(self, forward_func: Callable, multiply_by_inputs: bool = True) -> None:
+    _multiply_by_inputs: bool
+
+    def __init__(
+        self, forward_func: Callable[..., Tensor], multiply_by_inputs: bool = True
+    ) -> None:
         r"""
         Args:
 
@@ -345,37 +328,26 @@ def __init__(self, forward_func: Callable, multiply_by_inputs: bool = True) -> N
 
         """
         GradientAttribution.__init__(self, forward_func)
-        # pyre-fixme[4]: Attribute must be annotated.
         self._multiply_by_inputs = multiply_by_inputs
 
     @typing.overload
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    #  arguments of overload defined on line `318`.
     def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
         baselines: BaselineType = None,
         target: TargetType = None,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_forward_args: Any = None,
+        additional_forward_args: object = None,
         *,
-        # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[True],
     ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...
 
     @typing.overload
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    #  arguments of overload defined on line `329`.
     def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
         baselines: BaselineType = None,
         target: TargetType = None,
-        additional_forward_args: Any = None,
-        # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
-        # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
+        additional_forward_args: object = None,
         return_convergence_delta: Literal[False] = False,
     ) -> TensorOrTupleOfTensorsGeneric: ...
 
@@ -385,29 +357,25 @@ def attribute(  # type: ignore
         inputs: TensorOrTupleOfTensorsGeneric,
         baselines: BaselineType = None,
         target: TargetType = None,
-        additional_forward_args: Any = None,
+        additional_forward_args: object = None,
         return_convergence_delta: bool = False,
     ) -> Union[
         TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]
     ]:
         # Keeps track whether original input is a tuple or not before
         # converting it into a tuple.
-        # pyre-fixme[6]: For 1st argument expected `Tensor` but got
-        #  `TensorOrTupleOfTensorsGeneric`.
         is_inputs_tuple = _is_tuple(inputs)
-        # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
-        #  `Tuple[Tensor, ...]`.
-        inputs, baselines = _format_input_baseline(inputs, baselines)
+        inputs_tuple, baselines = _format_input_baseline(inputs, baselines)
 
         rand_coefficient = torch.tensor(
-            np.random.uniform(0.0, 1.0, inputs[0].shape[0]),
-            device=inputs[0].device,
-            dtype=inputs[0].dtype,
+            np.random.uniform(0.0, 1.0, inputs_tuple[0].shape[0]),
+            device=inputs_tuple[0].device,
+            dtype=inputs_tuple[0].dtype,
         )
 
         input_baseline_scaled = tuple(
             _scale_input(input, baseline, rand_coefficient)
-            for input, baseline in zip(inputs, baselines)
+            for input, baseline in zip(inputs_tuple, baselines)
         )
         grads = self.gradient_func(
             self.forward_func, input_baseline_scaled, target, additional_forward_args
@@ -415,7 +383,7 @@ def attribute(  # type: ignore
 
         if self.multiplies_by_inputs:
             input_baseline_diffs = tuple(
-                input - baseline for input, baseline in zip(inputs, baselines)
+                input - baseline for input, baseline in zip(inputs_tuple, baselines)
             )
             attributions = tuple(
                 input_baseline_diff * grad
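
Note on the overload pattern: the removed pyre-fixme[24]/[31] suppressions existed because
`Literal` was imported from `captum._utils.typing` rather than from `typing`, so Pyre treated
it as a non-generic type. With `typing.Literal`, the `Literal[True]` / `Literal[False]`
overloads let the checker narrow the return type of `attribute` from the value of
`return_convergence_delta`. A minimal, self-contained sketch of the pattern (a hypothetical
`Attributor` class, not Captum's actual API):

import typing
from typing import Literal, Tuple, Union

import torch
from torch import Tensor


class Attributor:
    @typing.overload
    def attribute(
        self, inputs: Tensor, *, return_convergence_delta: Literal[True]
    ) -> Tuple[Tensor, Tensor]: ...

    @typing.overload
    def attribute(
        self, inputs: Tensor, return_convergence_delta: Literal[False] = False
    ) -> Tensor: ...

    def attribute(
        self, inputs: Tensor, return_convergence_delta: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        # The implementation accepts a plain bool; callers and the type
        # checker only see the two overloads above.
        attributions = inputs * 2.0  # stand-in for the real computation
        if return_convergence_delta:
            return attributions, torch.zeros(1)
        return attributions


t = Attributor().attribute(torch.ones(3))  # checker infers: Tensor
t_and_delta = Attributor().attribute(
    torch.ones(3), return_convergence_delta=True
)  # checker infers: Tuple[Tensor, Tensor]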
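
Note on the renames: binding the formatted values to fresh names (`formatted_baselines`,
`inputs_tuple`) instead of reassigning the parameters resolves the pyre-fixme[9] errors,
since rebinding a variable to a value of a narrower type changes its declared type. A minimal
sketch of the issue with hypothetical names (the real code uses the generic
`TensorOrTupleOfTensorsGeneric` rather than a plain `Union`):

from typing import Tuple, Union

import torch
from torch import Tensor


def _format_inputs(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> Tuple[Tensor, ...]:
    # Normalize a lone tensor into a one-element tuple.
    return inputs if isinstance(inputs, tuple) else (inputs,)


def attribute(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> Tuple[Tensor, ...]:
    # Reassigning the parameter would change its declared type and trigger
    # pyre-fixme[9]:
    #     inputs = _format_inputs(inputs)
    # A fresh name keeps both the parameter and the tuple typed exactly:
    inputs_tuple = _format_inputs(inputs)
    return tuple(t * 2.0 for t in inputs_tuple)


print(attribute(torch.ones(2, 3))[0].shape)  # torch.Size([2, 3])

The same reasoning applies to `additional_forward_args`: `object` replaces `Any` because
Pyre's strict mode rejects `Any` in parameter annotations (pyre-fixme[2]), and `object` is
sufficient for an argument that is only passed through, never introspected.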