Skip to content

Commit

Permalink
Enable perturbations_per_pass in DataloaderAttribution (#1158)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1158

Enable argument `perturbations_per_pass` in `DataloaderAttribution` to support multiple perturbation in a single traverse of the dataloader

Differential Revision: D46965996

fbshipit-source-id: 84da00fa958eef98e38a9b1c20a5b9c5c824187a
  • Loading branch information
aobo-y authored and facebook-github-bot committed Jun 23, 2023
1 parent b1a9830 commit ca4df29
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 71 deletions.
190 changes: 127 additions & 63 deletions captum/attr/_core/dataloader_attr.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
from collections import defaultdict
from copy import copy
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
from captum._utils.common import (
Expand Down Expand Up @@ -32,11 +32,78 @@ def _concat_tensors(accum, cur_output, _):
return cur_output if accum is None else torch.cat([accum, cur_output])


def _create_perturbation_mask(
perturbed_feature_indices: Tensor, # 1D tensor of one-hot feature indices
feature_mask: Tuple[Tensor, ...],
feature_idx_to_mask_idx: Dict[int, List[int]],
) -> Tuple[Union[Tensor, None], ...]:
"""
Create binary mask for inputs based on perturbed one-hot feature indices
Use None if no perturbation is needed for the corresponding input
"""

# a set of input/mask indices that need perturbation
perturbation_mask_indices = set()
for i, v in enumerate(perturbed_feature_indices.tolist()):
# value 0 means the feature has been perturbed
if not v:
perturbation_mask_indices |= set(feature_idx_to_mask_idx[i])

# create binary mask for inputs & set it to None if no perturbation is needed
perturbation_mask = tuple(
perturbed_feature_indices[mask_elem] if i in perturbation_mask_indices else None
for i, mask_elem in enumerate(feature_mask)
)

return perturbation_mask


def _perturb_inputs(
inputs: Iterable[Any],
input_roles: Tuple[int],
baselines: Tuple[Union[int, float, Tensor], ...],
perturbation_mask: Tuple[Union[Tensor, None], ...],
) -> Tuple[Any, ...]:
"""
Perturb inputs based on perturbation mask and baselines
"""

perturbed_inputs = []
attr_inp_count = 0

for inp, role in zip(inputs, input_roles):
if role != InputRole.need_attr:
perturbed_inputs.append(inp)
continue

pert_mask = perturbation_mask[attr_inp_count]

# no perturbation is needed for this input
if pert_mask is None:
perturbed_inputs.append(inp)
else:
baseline = baselines[attr_inp_count]

perturbed_inp = inp * pert_mask + baseline * (1 - pert_mask)
perturbed_inputs.append(perturbed_inp)

attr_inp_count += 1

perturbed_inputs = tuple(perturbed_inputs)

return perturbed_inputs


def _convert_output_shape(
unique_attr: Tensor,
attr_inputs: Tuple[Tensor, ...],
feature_mask: Tuple[Tensor, ...],
) -> Tuple[Tensor, ...]:
"""
Convert the shape of a single tensor of unique feature attributionto
to match the shape of the inputs returned by dataloader
"""

# unique_attr in shape(*output_dims, n_features)
output_dims = unique_attr.shape[:-1]
n_features = unique_attr.shape[-1]
Expand Down Expand Up @@ -107,77 +174,73 @@ def __init__(self, attr_method: Attribution) -> None:

def _forward_with_dataloader(
self,
perturbed_feature_indices,
batched_perturbed_feature_indices: Tensor,
dataloader: torch.utils.data.DataLoader,
input_roles: Tuple[int],
baselines: Tuple[Union[int, float, Tensor], ...],
feature_mask: Tuple[Tensor, ...],
reduce: Callable,
to_metric: Optional[Callable],
perturbation_per_pass: int,
show_progress: bool,
feature_idx_to_mask_idx: Dict[int, List[int]],
):
# a set of input/mask indices that need perturbation
perturbation_mask_indices = set()
for i, v in enumerate(perturbed_feature_indices[0].tolist()):
# value 0 means the feature has been perturbed
if not v:
perturbation_mask_indices |= set(feature_idx_to_mask_idx[i])

# create binary mask for inputs & set it to None if no perturbation is needed
perturbation_mask = tuple(
perturbed_feature_indices[0][mask_elem]
if i in perturbation_mask_indices
else None
for i, mask_elem in enumerate(feature_mask)
)

accum = None
for inputs in dataloader:
perturbed_inputs = []
attr_inp_count = 0

for inp, role in zip(inputs, input_roles):
if role != InputRole.need_attr:
perturbed_inputs.append(inp)
continue

pert_mask = perturbation_mask[attr_inp_count]
"""
Wrapper of the original given forward_func to be used in the attribution method
It iterates over the dataloader with the given forward_func
"""

# no perturbation is needed for this input
if pert_mask is None:
perturbed_inputs.append(inp)
else:
baseline = baselines[attr_inp_count]
# batched_perturbed_feature_indices in shape(n_perturb, n_features)
# n_perturb is not always the same as perturb_per_pass if not enough perturb
perturbation_mask_list: List[Tuple[Union[Tensor, None], ...]] = [
_create_perturbation_mask(
perturbed_feature_indices,
feature_mask,
feature_idx_to_mask_idx,
)
for perturbed_feature_indices in batched_perturbed_feature_indices
]

perturbed_inp = inp * pert_mask + baseline * (1 - pert_mask)
perturbed_inputs.append(perturbed_inp)
# each perturbation needs an accum state
accum_states = [None for _ in range(len(perturbation_mask_list))]

attr_inp_count += 1
# tranverse the dataloader
for inputs in dataloader:
# for each batch read from the dataloader,
# apply every perturbation based on perturbations_per_pass
for i, perturbation_mask in enumerate(perturbation_mask_list):
perturbed_inputs = _perturb_inputs(
inputs, input_roles, baselines, perturbation_mask
)

perturbed_inputs = tuple(perturbed_inputs)
# due to explicitly defined roles
# we can keep inputs in their original order
# regardless of if they need attr
# instead of using additional_forward_inputs
forward_inputs = tuple(
_
for _, role in zip(perturbed_inputs, input_roles)
if role != InputRole.no_forward
)

# due to explicitly defined roles
# we can keep inputs in their original order regardless of if they need attr
# instead of using additional_forward_inputs to always appeend in the end
forward_inputs = tuple(
_
for _, role in zip(perturbed_inputs, input_roles)
if role != InputRole.no_forward
)
output = _run_forward(
self.forward_func,
forward_inputs,
)

output = _run_forward(
self.forward_func,
forward_inputs,
)
accum_states[i] = reduce(accum_states[i], output, perturbed_inputs)

accum = reduce(accum, output, perturbed_inputs)
accum_results = [to_metric(accum) if to_metric else accum for accum in accum_states]

if to_metric is not None:
return to_metric(accum)
assert all(type(r) is Tensor for r in accum_results), (
"Accumulated metrics for attribution must be a Tensor,"
f"received: {next(r for r in accum_results if type(r) is not Tensor)}"
)

return accum
# shape(n_perturb * output_dims[0], *output_dims[1:])
# the underneath attr method needs to support forward_func output's
# 1st dim to grow with perturb_per_eval
batched_accum = torch.stack(accum_results, dim=0)
return batched_accum

def attribute(
self,
Expand All @@ -187,7 +250,7 @@ def attribute(
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
reduce: Optional[Callable] = None,
to_metric: Optional[Callable] = None,
perturbation_per_pass: int = -1,
perturbations_per_pass: int = 1,
show_progress: bool = False,
return_input_shape: bool = True,
) -> Union[Tensor, Tuple[Tensor, ...]]:
Expand Down Expand Up @@ -240,16 +303,17 @@ def attribute(
metric (Tensor): final result to be attributed, must be a Tensor
If None, will directly attribute w.r.t the reduced ``accum``
perturbation_per_pass (int, optional
perturbations_per_pass (int, optional) the number perturbations to execute
concurrently in each traverse of the dataloader. The number of
traverses is ceil(n_perturbations / perturbation_per_pass).
The parameter offers a control of the trade-off between memory
traverses needed is
ceil(n_perturbations / perturbations_per_pass).
This arguement offers control of the trade-off between memory
and efficiency. If the dataloader involves slow operations like
remote request or file I/O, multiple traversals can be
inefficient. Each perturbation needs to store its accumulated
outputs of the reduce function until the end of the data
traverse. If the value is -1, all perturbations are concurrent
in a single traverse.
inefficient. On the other hand, each perturbation needs to
store its accumulated outputs of the reduce
function until the end of the data traverse.
return_input_shape (bool, optional): if True, returns the attribution
following the input shapes given by the dataloader.
Otherwise, returns a single tensor for the attributions of
Expand Down Expand Up @@ -352,14 +416,14 @@ def attribute(
# unique_attr in shape(*output_dims, n_features)
unique_attr = self.attr_method.attribute(
feature_indices,
perturbations_per_eval=perturbations_per_pass,
additional_forward_args=(
dataloader,
input_roles,
baselines,
feature_mask,
reduce,
to_metric,
perturbation_per_pass,
show_progress,
feature_idx_to_mask_idx,
),
Expand Down
56 changes: 48 additions & 8 deletions tests/attr/test_dataloader_attr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#!/usr/bin/env fbpython
import math
from typing import cast
from unittest.mock import Mock, patch

import torch

Expand All @@ -12,6 +14,7 @@
BaseTest,
)
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset


def sum_forward(*inps):
Expand All @@ -29,7 +32,7 @@ def forward(self, *inps):
return self.linear(torch.cat(inps, dim=1))


mock_dataset = torch.utils.data.TensorDataset(
mock_dataset = TensorDataset(
# iD feature
torch.tensor(
[
Expand Down Expand Up @@ -74,7 +77,7 @@ def test_dl_attr(self, forward) -> None:
fa = FeatureAblation(forward)
dl_fa = DataloaderAttribution(fa)

dataloader = torch.utils.data.DataLoader(mock_dataset, batch_size=2)
dataloader = DataLoader(mock_dataset, batch_size=2)

dl_attributions = dl_fa.attribute(dataloader)

Expand Down Expand Up @@ -108,7 +111,7 @@ def test_dl_attr_with_mask(self, forward) -> None:
fa = FeatureAblation(forward)
dl_fa = DataloaderAttribution(fa)

dataloader = torch.utils.data.DataLoader(mock_dataset, batch_size=2)
dataloader = DataLoader(mock_dataset, batch_size=2)

dl_attributions = dl_fa.attribute(dataloader, feature_mask=masks)

Expand Down Expand Up @@ -140,7 +143,7 @@ def test_dl_attr_with_baseline(self, forward) -> None:
fa = FeatureAblation(forward)
dl_fa = DataloaderAttribution(fa)

dataloader = torch.utils.data.DataLoader(mock_dataset, batch_size=2)
dataloader = DataLoader(mock_dataset, batch_size=2)

dl_attributions = dl_fa.attribute(dataloader, baselines=baselines)

Expand Down Expand Up @@ -188,7 +191,7 @@ def to_metric(accum):
dl_fa = DataloaderAttribution(fa)

batch_size = 2
dataloader = torch.utils.data.DataLoader(mock_dataset, batch_size=batch_size)
dataloader = DataLoader(mock_dataset, batch_size=batch_size)

dl_attribution = dl_fa.attribute(
dataloader,
Expand Down Expand Up @@ -243,7 +246,7 @@ def forward(*forward_inputs):
dl_fa = DataloaderAttribution(fa)

batch_size = 2
dataloader = torch.utils.data.DataLoader(mock_dataset, batch_size=batch_size)
dataloader = DataLoader(mock_dataset, batch_size=batch_size)

dl_attributions = dl_fa.attribute(
dataloader,
Expand Down Expand Up @@ -282,7 +285,7 @@ def test_dl_attr_not_return_input_shape(self) -> None:
fa = FeatureAblation(forward)
dl_fa = DataloaderAttribution(fa)

dataloader = torch.utils.data.DataLoader(mock_dataset, batch_size=2)
dataloader = DataLoader(mock_dataset, batch_size=2)

dl_attribution = dl_fa.attribute(dataloader, return_input_shape=False)

Expand Down Expand Up @@ -320,7 +323,7 @@ def test_dl_attr_with_mask_not_return_input_shape(self) -> None:
fa = FeatureAblation(forward)
dl_fa = DataloaderAttribution(fa)

dataloader = torch.utils.data.DataLoader(mock_dataset, batch_size=2)
dataloader = DataLoader(mock_dataset, batch_size=2)

dl_attribution = dl_fa.attribute(
dataloader, feature_mask=masks, return_input_shape=False
Expand All @@ -331,3 +334,40 @@ def test_dl_attr_with_mask_not_return_input_shape(self) -> None:
self.assertEqual(type(dl_attribution), Tensor)
dl_attribution = cast(Tensor, dl_attribution)
self.assertEqual(dl_attribution.shape, expected_attr_shape)

@parameterized.expand([(2,), (3,), (4,)])
def test_dl_attr_with_perturb_per_pass(self, perturb_per_pass) -> None:
forward = sum_forward

fa = FeatureAblation(forward)
dl_fa = DataloaderAttribution(fa)

mock_dl_iter = Mock(wraps=DataLoader.__iter__)

with patch.object(DataLoader, "__iter__", lambda self: mock_dl_iter(self)):
dataloader = DataLoader(mock_dataset, batch_size=2)

dl_attributions = dl_fa.attribute(
dataloader, perturbations_per_pass=perturb_per_pass
)

n_features = 7
# 2 extra iter calls: get one input for format; get unperturbed output
n_iter_overhead = 2

self.assertEqual(
mock_dl_iter.call_count,
math.ceil(n_features / perturb_per_pass) + n_iter_overhead,
)

# default reduce of DataloaderAttribution works the same as concat all batches
attr_list = []
for batch in dataloader:
batch_attr = fa.attribute(tuple(batch))
attr_list.append(batch_attr)

expected_attr = tuple(
torch.cat(feature_attrs, dim=0) for feature_attrs in zip(*attr_list)
)

assertAttributionComparision(self, dl_attributions, expected_attr)

0 comments on commit ca4df29

Please sign in to comment.