Skip to content

Commit

Permalink
Add DataloaderAttribution (pytorch#1155)
Browse files Browse the repository at this point in the history
Summary:
Implement DataloaderAttribution, an attribution wrapper designed to wrap existing perturbation attr methods so that they can take `torch.utils.data.DataLoader` as the inputs. This enables attributions with respect to corpus-level metrics.

`perturbation_per_pass` is not supported in this diff; it will be added separately in the next diff.

Pull Request resolved: pytorch#1155

Reviewed By: vivekmig

Differential Revision: D46322232

fbshipit-source-id: fddb80aa0e2a5231bd7a62e568ed80e9ebc52a7f
  • Loading branch information
aobo-y authored and facebook-github-bot committed Jun 22, 2023
1 parent 3aed726 commit bc36a50
Show file tree
Hide file tree
Showing 2 changed files with 710 additions and 0 deletions.
377 changes: 377 additions & 0 deletions captum/attr/_core/dataloader_attr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,377 @@
#!/usr/bin/env python3
from collections import defaultdict
from copy import copy
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from captum._utils.common import (
_format_baseline,
_format_feature_mask,
_format_output,
_format_tensor_into_tuples,
_get_max_feature_index,
_run_forward,
)
from captum._utils.typing import BaselineType
from captum.attr import FeatureAblation
from captum.attr._utils.attribution import Attribution
from torch import Tensor


class InputRole:
    """Integer codes describing how each element yielded by a dataloader is used.

    The codes are plain ``int`` class attributes so they can be compared
    directly against the integers a caller passes in ``input_roles``.
    """

    # The element is attributed (and is therefore also part of the forward call).
    need_attr = 0
    # The element is passed to the forward function but is not attributed.
    need_forward = 1
    # The element is left out of the forward call entirely.
    no_forward = 2


# Attribution algorithm types that DataloaderAttribution is able to wrap. The
# constructor checks the exact type of the given instance against this set.
SUPPORTED_METHODS = {FeatureAblation}


# Default reducer used when ``reduce`` is None: simply concatenate the outputs
# along the batch (first) dimension.
def _concat_tensors(accum, cur_output, _):
    """Append ``cur_output`` to the running ``accum`` tensor.

    Args:
        accum: the tensor accumulated so far, or None on the first call.
        cur_output: the forward output of the current dataloader iteration.
        _: the current inputs; unused, present only to match the reducer
            signature ``reduce(accum, current_output, current_inputs)``.

    Returns:
        ``cur_output`` itself when nothing has been accumulated yet, otherwise
        ``accum`` and ``cur_output`` concatenated along dimension 0.
    """
    if accum is None:
        return cur_output
    return torch.cat((accum, cur_output))


def _convert_output_shape(
    unique_attr: Tensor,
    attr_inputs: Tuple[Tensor, ...],
    feature_mask: Tuple[Tensor, ...],
) -> Tuple[Tensor, ...]:
    """Spread per-feature attributions back onto the shapes of the inputs.

    Args:
        unique_attr: attributions in shape ``(*output_dims, n_features)``; the
            last dimension is indexed by feature id.
        attr_inputs: the inputs being attributed, each in shape
            ``(batch_size, *inp_feature_dims)``. Only their shapes are read.
        feature_mask: one mask per input holding the feature id of every input
            element; each mask is expanded to
            ``(*output_dims, *inp_feature_dims)``.

    Returns:
        A tuple with one tensor per input, each in shape
        ``(*output_dims, *inp_feature_dims)``, where every element carries the
        attribution of the feature its mask entry points to.
    """
    output_dims = tuple(unique_attr.shape[:-1])
    n_features = unique_attr.shape[-1]

    def _gather_for_input(inp: Tensor, mask: Tensor) -> Tensor:
        # feature ids laid out in the final attribution shape
        feature_ids = mask.expand(output_dims + tuple(inp.shape[1:]))

        source = unique_attr
        if inp.dim() > 2:
            # dims of the input between the batch dim and the last dim
            middle_dims = tuple(inp.shape[1:-1])
            singleton_dims = (1,) * len(middle_dims)

            # insert singleton dims so unique_attr has as many dims as needed:
            # (*output_dims, 1, ..., 1, n_features), then broadcast it to
            # (*output_dims, *inp.shape[1:-1], n_features)
            source = unique_attr.reshape(
                output_dims + singleton_dims + (n_features,)
            ).expand(output_dims + middle_dims + (n_features,))

        # pick, along the feature dimension, the attribution of each element
        return torch.gather(source, -1, feature_ids)

    return tuple(
        _gather_for_input(inp, mask) for inp, mask in zip(attr_inputs, feature_mask)
    )


class DataloaderAttribution(Attribution):
    r"""
    Decorate a perturbation-based attribution algorithm to make it work with
    dataloaders. The decorated instance will calculate attribution in the
    same way as configured in the original attribution instance, but it will provide a
    new "attribute" function which accepts a pytorch "dataloader" instance as the input
    instead of a single batched "tensor" and supports customizing a "reduce" function to
    determine how the forward return of each iteration of the dataloader should be
    aggregated to single metric tensor to attribute. This would
    be specially useful to attribute against some corpus-wise metrics,
    e.g., Precision & Recall.
    """

    def __init__(self, attr_method: Attribution) -> None:
        r"""
        Args:
            attr_method (Attribution): An instance of a perturbation-based
                        attribution algorithm whose exact type is listed in
                        ``SUPPORTED_METHODS`` (currently only
                        ``FeatureAblation``). The instance itself is not
                        modified; a shallow copy is wrapped instead.
        """

        assert (
            type(attr_method) in SUPPORTED_METHODS
        ), f"DataloaderAttribution does not support {type(attr_method)}"

        super().__init__(attr_method.forward_func)

        # shallow copy is enough to avoid modifying original instance
        self.attr_method = copy(attr_method)

        # the wrapped method now "forwards" by traversing the whole dataloader;
        # the user's forward function stays available as self.forward_func
        self.attr_method.forward_func = self._forward_with_dataloader

    def _forward_with_dataloader(
        self,
        perturbed_feature_indices,
        dataloader: torch.utils.data.DataLoader,
        input_roles: Tuple[int, ...],
        baselines: Tuple[Union[int, float, Tensor], ...],
        feature_mask: Tuple[Tensor, ...],
        reduce: Callable,
        to_metric: Optional[Callable],
        perturbation_per_pass: int,
        show_progress: bool,
        feature_idx_to_mask_idx: Dict[int, List[int]],
    ):
        r"""
        Forward function handed to the wrapped attribution method. It traverses
        the dataloader once, perturbs the attributed inputs of every batch
        according to ``perturbed_feature_indices``, runs ``self.forward_func``
        and accumulates the outputs with ``reduce``.

        Args:
            perturbed_feature_indices (Tensor): indicator over the features.
                        Only the first row is read; a value of 0 at index ``i``
                        means feature ``i`` is perturbed (replaced by baseline).
            dataloader (torch.utils.data.DataLoader): the data to traverse.
            input_roles (tuple[int, ...]): role of each element of a batch,
                        see ``InputRole``.
            baselines (tuple): one baseline per attributed input.
            feature_mask (tuple[Tensor, ...]): one mask of feature ids per
                        attributed input.
            reduce (Callable): called as
                        ``reduce(accum, output, perturbed_inputs)`` after every
                        batch; ``accum`` is None for the first batch.
            to_metric (Callable, optional): if not None, applied to the final
                        accumulated state to produce the return value.
            perturbation_per_pass (int): accepted for signature compatibility;
                        not read by this function.
            show_progress (bool): accepted for signature compatibility;
                        not read by this function.
            feature_idx_to_mask_idx (dict[int, list[int]]): for each feature
                        id, the indices of the masks that contain it.

        Returns:
            ``to_metric(accum)`` if ``to_metric`` is given, otherwise the
            accumulated state returned by the last ``reduce`` call (None if the
            dataloader yields nothing).
        """
        # a set of input/mask indices that need perturbation
        perturbation_mask_indices = set()
        for i, v in enumerate(perturbed_feature_indices[0].tolist()):
            # value 0 means the feature has been perturbed
            if not v:
                perturbation_mask_indices |= set(feature_idx_to_mask_idx[i])

        # create binary mask for inputs & set it to None if no perturbation is needed
        perturbation_mask = tuple(
            perturbed_feature_indices[0][mask_elem]
            if i in perturbation_mask_indices
            else None
            for i, mask_elem in enumerate(feature_mask)
        )

        accum = None
        for inputs in dataloader:
            # "attribute" accepts dataloaders that yield a single tensor; wrap
            # such a batch so it is not zipped row by row against input_roles
            if not isinstance(inputs, (list, tuple)):
                inputs = (inputs,)

            perturbed_inputs = []
            attr_inp_count = 0

            for inp, role in zip(inputs, input_roles):
                if role != InputRole.need_attr:
                    perturbed_inputs.append(inp)
                    continue

                pert_mask = perturbation_mask[attr_inp_count]

                # no perturbation is needed for this input
                if pert_mask is None:
                    perturbed_inputs.append(inp)
                else:
                    baseline = baselines[attr_inp_count]

                    # keep the original value where the mask is 1 and take the
                    # baseline where it is 0
                    perturbed_inp = inp * pert_mask + baseline * (1 - pert_mask)
                    perturbed_inputs.append(perturbed_inp)

                attr_inp_count += 1

            perturbed_inputs = tuple(perturbed_inputs)

            # due to explicitly defined roles
            # we can keep inputs in their original order regardless of if they need attr
            # instead of using additional_forward_inputs to always append in the end
            forward_inputs = tuple(
                inp
                for inp, role in zip(perturbed_inputs, input_roles)
                if role != InputRole.no_forward
            )

            output = _run_forward(
                self.forward_func,
                forward_inputs,
            )

            # reduce sees every element of the batch, including the ones that
            # were excluded from the forward call (e.g. labels)
            accum = reduce(accum, output, perturbed_inputs)

        if to_metric is not None:
            return to_metric(accum)

        return accum

    def attribute(
        self,
        dataloader: torch.utils.data.DataLoader,
        input_roles: Optional[Tuple[int, ...]] = None,
        baselines: BaselineType = None,
        feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
        reduce: Optional[Callable] = None,
        to_metric: Optional[Callable] = None,
        perturbation_per_pass: int = -1,
        show_progress: bool = False,
        return_input_shape: bool = True,
    ) -> Union[Tensor, Tuple[Tensor, ...]]:
        r"""
        Args:
            dataloader (torch.Dataloader): the dataloader to attribute, which should
                        return a tuple of consistent size for every iteration
            input_roles (tuple[int, ...], optional): a tuple of integers to define the
                        role of each element returned from the dataloader. It should
                        have the same size as the return of the dataloader.
                        The available roles (see ``InputRole``) are:
                        0: the element is passed to forward_func and needs attribution.
                        It must be a tensor.
                        1: the element is passed to forward_func but does not need
                        attribution. Like additional_forward_args
                        2: the element is excluded for forward_func. A typical example
                        is the label.
                        If None or empty, every element is assumed to need
                        attribution (role = 0).
            baselines (Union[Tensor, tuple[Tensor, ...]], optional): same as the
                        baseline in attribute. The same baseline will be
                        applied to the entire dataloader. The first dimension is
                        assumed to be batch size and it must be 1. Baselines should only
                        be specified for the dataloader's returns that need
                        attribution (role = 0)
            feature_mask (Union[Tensor, tuple[Tensor, ...]], optional): same as the
                        feature_mask in attribute. The same feature_mask will be
                        applied to the entire dataloader. The first dimension is
                        assumed to be batch size and it must be 1. Mask should only
                        be specified for the dataloader's returns that need
                        attribution (role = 0)
            reduce (Callable, optional): a function to accumulate the forward output of
                        each iteration of the dataloader. The function signature is:
                        ``reduce(accum, current_output, current_inputs) -> accum``,
                        where:
                        accum (Any): accumulated states, can be any type
                        current_output (Tensor): current output tensor from forward_func
                        current_inputs (tuple[Any,...]): current inputs from dataloader,
                        with the perturbation already applied to the elements
                        that need attribution
                        If None, the outputs are concatenated along the first
                        dimension.
            to_metric (Callable, optional): an optional function to further convert
                        accumulated results through "reduce" after traversing the whole
                        dataloader to a single tensor of metrics to calculate
                        attribution against. The function signature is:
                        ``to_metric(accum) -> metric``, where:
                        accum (Any): accumulated state from reduce function
                        metric (Tensor): final result to be attributed, must be a Tensor
                        If None, will directly attribute w.r.t the reduced ``accum``
            perturbation_per_pass (int, optional): the number of perturbations executed
                        concurrently in each traverse of the dataloader. The number of
                        traverses is ceil(n_perturbations / perturbation_per_pass).
                        The parameter offers a control of the trade-off between memory
                        and efficiency. If the dataloader involves slow operations like
                        remote request or file I/O, multiple traversals can be
                        inefficient. Each perturbation needs to store its accumulated
                        outputs of the reduce function until the end of the data
                        traverse. If the value is -1, all perturbations are concurrent
                        in a single traverse.
                        NOTE: the value is currently only passed through to the
                        internal forward function, which does not read it.
            show_progress (bool, optional): currently only passed through to the
                        internal forward function, which does not read it.
            return_input_shape (bool, optional): if True, returns the attribution
                        following the input shapes given by the dataloader.
                        Otherwise, returns a single tensor for the attributions of
                        all the features, where the last dimension
                        is the number of features.
        Returns:
            **attributions** :
            - **attributions** (*Tensor* or *tuple[Tensor, ...]*):
                        Attribution with respect to each input feature.
                        if return_input_shape is True, attributions will be
                        the same size as the given dataloader's returns that need
                        attribution (role = 0), with each value
                        providing the attribution of the corresponding input index.
                        If a single tensor is provided as inputs, a single tensor is
                        returned. If a tuple is provided for inputs, a tuple of
                        corresponding sized tensors is returned.
                        If return_input_shape is False, a single tensor is returned
                        where each index of the last dimension represents a feature
        """
        # peek at the first batch to learn the structure of the dataloader returns
        inputs = next(iter(dataloader))
        is_inputs_tuple = True

        if type(inputs) is list:
            # support list as it is a common return type for dataloader in torch
            inputs = tuple(inputs)
        elif type(inputs) is not tuple:
            is_inputs_tuple = False
            inputs = _format_tensor_into_tuples(inputs)

        if input_roles:
            assert len(input_roles) == len(inputs), (
                "input_roles must have the same size as the return of the dataloader, "
                f"length of input_roles is {len(input_roles)} "
                f"whereas the length of dataloader return is {len(inputs)}"
            )

            assert any(role == InputRole.need_attr for role in input_roles), (
                "input_roles must contain at least one element need attribution"
                f"({InputRole.need_attr}), received input_roles: {input_roles}"
            )
        else:
            # by default, assume every element in the dataloader needs attribution
            input_roles = tuple(InputRole.need_attr for _ in inputs)

        attr_inputs = tuple(
            inp for role, inp in zip(input_roles, inputs) if role == InputRole.need_attr
        )

        baselines = _format_baseline(baselines, attr_inputs)

        assert len(attr_inputs) == len(baselines), (
            "Baselines must have the same size as the return of the dataloader "
            "that need attribution, "
            f"length of baseline is {len(baselines)} "
            'whereas the length of dataloader return with role "0" is '
            f"{len(attr_inputs)}"
        )

        for i, baseline in enumerate(baselines):
            if isinstance(baseline, Tensor):
                assert baseline.size(0) == 1, (
                    "If the baseline is a tensor, "
                    "its 1st dim of baseline must be 1 so it can be broadacasted to "
                    "any batch of the dataloader:"
                    f"baselines[{i}].shape = {baseline.shape}"
                )

        feature_mask = _format_feature_mask(feature_mask, attr_inputs)

        assert len(attr_inputs) == len(feature_mask), (
            "Feature mask must have the same size as the return of the dataloader "
            "that need attribution, "
            f"length of feature_mask is {len(feature_mask)} "
            'whereas the length of dataloader return with role "0" is '
            f"{len(attr_inputs)}"
        )

        for i, each_mask in enumerate(feature_mask):
            assert each_mask.size(0) == 1, (
                "The 1st dim of feature_mask must be 1 so it can be broadcasted to "
                "any batch of the dataloader:"
                f"feature_mask[{i}].shape = {each_mask.shape}"
            )

        # map to retrieve masks contain a given feature index
        feature_idx_to_mask_idx = defaultdict(list)
        for i, mask in enumerate(feature_mask):
            unique_feature_indices = torch.unique(mask).tolist()
            for feature_idx in unique_feature_indices:
                feature_idx_to_mask_idx[feature_idx].append(i)

        max_feature_idx = _get_max_feature_index(feature_mask)
        n_features = max_feature_idx + 1

        if reduce is None:
            reduce = _concat_tensors

        # all-ones indicator over the features; the wrapped method attributes
        # with respect to this tensor instead of the real inputs
        feature_indices = torch.ones((1, n_features), device=attr_inputs[0].device)

        # unique_attr in shape(*output_dims, n_features)
        unique_attr = self.attr_method.attribute(
            feature_indices,
            additional_forward_args=(
                dataloader,
                input_roles,
                baselines,
                feature_mask,
                reduce,
                to_metric,
                perturbation_per_pass,
                show_progress,
                feature_idx_to_mask_idx,
            ),
        )

        if not return_input_shape:
            return unique_attr

        attr = _convert_output_shape(
            unique_attr,
            attr_inputs,
            feature_mask,
        )

        return _format_output(is_inputs_tuple, attr)
Loading

0 comments on commit bc36a50

Please sign in to comment.