Skip to content

Commit

Permalink
Type Fixes in Robust Metrics (#707)
Browse files Browse the repository at this point in the history
Summary:
Fixes Mypy type checking issues with attack metrics to resolve CircleCI issues

Pull Request resolved: #707

Reviewed By: NarineK

Differential Revision: D29552315

Pulled By: vivekmig

fbshipit-source-id: ba44d7e4121df30d26ac9e0bc796614ac726a9ed
  • Loading branch information
vivekmig authored and facebook-github-bot committed Jul 6, 2021
1 parent 67a3ddc commit 268035f
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 41 deletions.
13 changes: 13 additions & 0 deletions captum/_utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,19 @@ def _format_input(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> Tuple[Tensor, ..
return _format_tensor_into_tuples(inputs)


def _format_float_or_tensor_into_tuples(
inputs: Union[float, Tensor, Tuple[Union[float, Tensor], ...]]
) -> Tuple[Union[float, Tensor], ...]:
if not isinstance(inputs, tuple):
assert isinstance(
inputs, (torch.Tensor, float)
), "`inputs` must have type float or torch.Tensor but {} found: ".format(
type(inputs)
)
inputs = (inputs,)
return inputs


@overload
def _format_additional_forward_args(additional_forward_args: None) -> None:
...
Expand Down
6 changes: 3 additions & 3 deletions captum/attr/_utils/summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _copy_stats(self):

return copy.deepcopy(self._stats)

def update(self, x: Union[Tensor, Tuple[Tensor, ...]]):
def update(self, x: Union[float, Tensor, Tuple[Union[float, Tensor], ...]]):
r"""
Calls `update` on each `Stat` object within the summarizer
Expand All @@ -57,9 +57,9 @@ def update(self, x: Union[Tensor, Tuple[Tensor, ...]]):
# we want input to be consistently a single input or a tuple
assert not (self._is_inputs_tuple ^ isinstance(x, tuple))

from captum._utils.common import _format_tensor_into_tuples
from captum._utils.common import _format_float_or_tensor_into_tuples

x = _format_tensor_into_tuples(x)
x = _format_float_or_tensor_into_tuples(x)

for i, inp in enumerate(x):
if i >= len(self._summarizers):
Expand Down
78 changes: 47 additions & 31 deletions captum/robust/_core/metrics/attack_comparator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
#!/usr/bin/env python3
import warnings
from collections import namedtuple
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union
from typing import (
Any,
Callable,
Dict,
Generic,
List,
NamedTuple,
Optional,
Tuple,
TypeVar,
Union,
cast,
)

from torch import Tensor

Expand All @@ -15,6 +27,10 @@

ORIGINAL_KEY = "Original"

MetricResultType = TypeVar(
"MetricResultType", float, Tensor, Tuple[Union[float, Tensor], ...]
)


class AttackInfo(NamedTuple):
attack_fn: Union[Perturbation, Callable]
Expand All @@ -33,7 +49,7 @@ def agg_metric(inp):
return inp


class AttackComparator:
class AttackComparator(Generic[MetricResultType]):
r"""
Allows measuring model robustness for a given attack or set of attacks. This class
can be used with any metric(s) as well as any set of attacks, either based on
Expand All @@ -44,7 +60,7 @@ class AttackComparator:
def __init__(
self,
forward_func: Callable,
metric: Callable[..., Union[float, Tensor, Tuple[Union[float, Tensor], ...]]],
metric: Callable[..., MetricResultType],
preproc_fn: Callable = None,
) -> None:
r"""
Expand Down Expand Up @@ -74,10 +90,10 @@ def model_metric(model_out: Tensor, **kwargs: Any)
additional_forward_args provided to evaluate.
"""
self.forward_func = forward_func
self.metric = metric
self.metric: Callable = metric
self.preproc_fn = preproc_fn
self.attacks = {}
self.summary_results = {}
self.attacks: Dict[str, AttackInfo] = {}
self.summary_results: Dict[str, Summarizer] = {}
self.metric_aggregator = agg_metric
self.batch_stats = [Mean, Min, Max]
self.aggregate_stats = [Mean]
Expand Down Expand Up @@ -148,7 +164,7 @@ def add_attack(

def _format_summary(
self, summary: Union[Dict, List[Dict]]
) -> Dict[str, Union[float, Tuple[float, ...]]]:
) -> Dict[str, MetricResultType]:
r"""
This method reformats a given summary; particularly for tuples,
the Summarizer's summary format is a list of dictionaries,
Expand All @@ -159,12 +175,12 @@ def _format_summary(
if isinstance(summary, dict):
return summary
else:
summary_dict = {}
summary_dict: Dict[str, Tuple] = {}
for key in summary[0]:
summary_dict[key] = tuple(s[key] for s in summary)
if self.out_format:
summary_dict[key] = self.out_format(*summary_dict[key])
return summary_dict
return summary_dict # type: ignore

def _update_out_format(
self, out_metric: Union[float, Tensor, Tuple[Union[float, Tensor], ...]]
Expand All @@ -174,7 +190,9 @@ def _update_out_format(
and isinstance(out_metric, tuple)
and hasattr(out_metric, "_fields")
):
self.out_format = namedtuple(type(out_metric).__name__, out_metric._fields)
self.out_format = namedtuple( # type: ignore
type(out_metric).__name__, cast(NamedTuple, out_metric)._fields
)

def _evaluate_batch(
self,
Expand Down Expand Up @@ -212,13 +230,10 @@ def _evaluate_batch(
def evaluate(
self,
inputs: Any,
additional_forward_args: Optional[Tuple] = None,
additional_forward_args: Any = None,
perturbations_per_eval: int = 1,
**kwargs,
) -> Dict[
str,
Union[Tensor, Tuple[Tensor, ...], Dict[str, Union[Tensor, Tuple[Tensor, ...]]]],
]:
) -> Dict[str, Union[MetricResultType, Dict[str, MetricResultType]]]:
r"""
Evaluate model and attack performance on provided inputs
Expand Down Expand Up @@ -385,45 +400,44 @@ def _check_and_evaluate(input_list, key_list):

def _parse_and_update_results(
self, batch_summarizers: Dict[str, Summarizer]
) -> Dict[
str, Union[float, Tuple[float, ...], Dict[str, Union[float, Tuple[float, ...]]]]
]:
results = {
ORIGINAL_KEY: self._format_summary(batch_summarizers[ORIGINAL_KEY].summary)[
"mean"
]
) -> Dict[str, Union[MetricResultType, Dict[str, MetricResultType]]]:
results: Dict[str, Union[MetricResultType, Dict[str, MetricResultType]]] = {
ORIGINAL_KEY: self._format_summary(
cast(Union[Dict, List], batch_summarizers[ORIGINAL_KEY].summary)
)["mean"]
}
self.summary_results[ORIGINAL_KEY].update(
self.metric_aggregator(results[ORIGINAL_KEY])
)
for attack_key in self.attacks:
attack = self.attacks[attack_key]
results[attack.name] = self._format_summary(
batch_summarizers[attack.name].summary
attack_results = self._format_summary(
cast(Union[Dict, List], batch_summarizers[attack.name].summary)
)
results[attack.name] = attack_results

if len(results[attack.name]) == 1:
key = next(iter(results[attack.name]))
if len(attack_results) == 1:
key = next(iter(attack_results))
if attack.name not in self.summary_results:
self.summary_results[attack.name] = Summarizer(
[stat() for stat in self.aggregate_stats]
)
self.summary_results[attack.name].update(
self.metric_aggregator(results[attack.name][key])
self.metric_aggregator(attack_results[key])
)
else:
for key in results[attack.name]:
for key in attack_results:
summary_key = f"{attack.name} {key.title()} Attempt"
if summary_key not in self.summary_results:
self.summary_results[summary_key] = Summarizer(
[stat() for stat in self.aggregate_stats]
)
self.summary_results[summary_key].update(
self.metric_aggregator(results[attack.name][key])
self.metric_aggregator(attack_results[key])
)
return results

def summary(self) -> Dict[str, Dict[str, Union[Tensor, Tuple[Tensor, ...]]]]:
def summary(self) -> Dict[str, Dict[str, MetricResultType]]:
r"""
Returns average results over all previous batches evaluated.
Expand All @@ -440,7 +454,9 @@ def summary(self) -> Dict[str, Dict[str, Union[Tensor, Tuple[Tensor, ...]]]]:
per batch.
"""
return {
key: self._format_summary(self.summary_results[key].summary)
key: self._format_summary(
cast(Union[Dict, List], self.summary_results[key].summary)
)
for key in self.summary_results
}

Expand Down
10 changes: 6 additions & 4 deletions captum/robust/_core/metrics/min_param_perturbation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
import math
from enum import Enum
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union, cast

import torch
from torch import Tensor
Expand Down Expand Up @@ -136,7 +136,9 @@ def correct_fn(model_out: Tensor, **kwargs: Any) -> bool
self.num_attempts = num_attempts
self.preproc_fn = preproc_fn
self.apply_before_preproc = apply_before_preproc
self.correct_fn = correct_fn if correct_fn is not None else default_correct_fn
self.correct_fn = cast(
Callable, correct_fn if correct_fn is not None else default_correct_fn
)

assert (
mode.upper() in MinParamPerturbationMode.__members__
Expand All @@ -147,9 +149,9 @@ def _evaluate_batch(
self,
input_list: List,
additional_forward_args: Any,
correct_fn_kwargs: Dict[str, Any],
correct_fn_kwargs: Optional[Dict[str, Any]],
target: TargetType,
) -> None:
) -> Optional[int]:
if additional_forward_args is None:
additional_forward_args = ()

Expand Down
6 changes: 3 additions & 3 deletions tests/robust/test_min_param_perturbation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
from typing import List
from typing import List, cast

import torch
from torch import Tensor
Expand Down Expand Up @@ -55,7 +55,7 @@ def test_minimal_pert_basic_linear(self) -> None:
target_inp, pert = minimal_pert.evaluate(
inp, target=0, attack_kwargs={"ind": 0}
)
self.assertAlmostEqual(pert, 2.0)
self.assertAlmostEqual(cast(float, pert), 2.0)
assertTensorAlmostEqual(
self, target_inp, torch.tensor([[0.0, -9.0, 9.0, 1.0, -3.0]])
)
Expand All @@ -79,7 +79,7 @@ def test_minimal_pert_basic_binary(self) -> None:
attack_kwargs={"ind": 0},
perturbations_per_eval=10,
)
self.assertAlmostEqual(pert, 2.0)
self.assertAlmostEqual(cast(float, pert), 2.0)
assertTensorAlmostEqual(
self, target_inp, torch.tensor([[0.0, -9.0, 9.0, 1.0, -3.0]])
)
Expand Down

0 comments on commit 268035f

Please sign in to comment.