Allow multiple outputs for agg_mode=True in Feature Ablation (pytorch#425)

Summary:
Pull Request resolved: pytorch#425

## Description

What is aggregation output mode? It can be defined as:

When there is no 1:1 correspondence between `num_examples` (`batch_size`) and the number of outputs your model produces, i.e. the model output does not grow in size as the `batch_size` becomes larger.

This allows an arbitrarily sized tensor to be returned from the `forward_func` for feature ablation.
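
As a concrete illustration, here is a minimal sketch (the toy model, shapes, and names are assumptions for illustration, not code from this PR) of a `forward_func` whose output size is independent of the batch size:

```python
import torch

from captum.attr import FeatureAblation


def agg_forward(inp: torch.Tensor) -> torch.Tensor:
    # Sums over the batch dimension, so the number of outputs is fixed
    # (one per input feature) no matter how large the batch is.
    return inp.sum(dim=0)


inputs = torch.rand(4, 5)  # batch_size = 4, 5 input features
ablator = FeatureAblation(agg_forward)

# perturbations_per_eval defaults to 1 and no feature_mask is passed,
# so this call is treated as aggregation output mode.
attr = ablator.attribute(inputs)
```

Per the title, this PR extends aggregation mode from a single scalar output per batch to multiple outputs, as in the sketch above.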

 ---
## Implementation Details

We assume `aggregation_output_mode` to be the case if `perturbations_per_eval == 1` and the `feature_mask` is either `None` or associated with all examples in the batch (i.e. each mask has a first dimension of length 1). This is not a perfect heuristic, but for feature ablation the underlying logic is the same whether there is a 1:1 correspondence (i.e. the model has `batch_size` outputs) or `agg_output_mode=True`.
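
To illustrate the heuristic, a sketch that mirrors the `_find_output_mode` check added in this diff (the helper name and mask shapes here are hypothetical):

```python
import torch

# A mask whose first dimension is 1 applies to every example in the
# batch, so it is compatible with aggregation output mode.
shared_mask = torch.tensor([[0, 0, 1, 2, 2]])        # shape (1, 5)

# A per-example mask (first dimension == batch_size) implies a 1:1
# correspondence between examples and outputs.
per_example_mask = torch.zeros(4, 5, dtype=torch.long)


def is_agg_output_mode(perturbations_per_eval, feature_mask) -> bool:
    return perturbations_per_eval == 1 and (
        feature_mask is None
        or all(len(m.shape) == 0 or m.shape[0] == 1 for m in feature_mask)
    )


print(is_agg_output_mode(1, None))                   # True
print(is_agg_output_mode(1, (shared_mask,)))         # True
print(is_agg_output_mode(1, (per_example_mask,)))    # False
print(is_agg_output_mode(2, None))                   # False
```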

If `agg_output_mode == True`:
- Feature ablation will output a tensor of shape `1xOxF`, where `O` is the number of model outputs and `F` is the number of input features, under aggregation mode. If the model outputs a tensor with more than 2 dimensions, it is treated internally as a 2D tensor, so the user must reshape the returned attributions accordingly; outputting a 2D tensor is therefore recommended, even though >2D outputs are allowed (see the sketch below).
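
A small sketch of the flattening described above (mirroring the `initial_eval.reshape(1, -1)` call in the diff; the output shape is hypothetical):

```python
import torch

out = torch.rand(2, 3, 5)    # a >2D model output, not tied to batch_size
flat = out.reshape(1, -1)    # treated internally as a 2D tensor
print(flat.shape)            # torch.Size([1, 30]) -> 30 outputs
# Attributions are laid out against these 30 flattened outputs, so the
# caller must reshape them back to the (2, 3, 5) layout if needed.
```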

If we are not in `agg_output_mode`, we must ensure the number of output elements is `n` (`batch_size`); if it is not, we raise an error for the user. We could instead check that the number of elements is at least `n`, but for simplicity I am not doing this.
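
A minimal sketch of that check (mirroring the assertion added in the diff; the tensors are made up):

```python
import torch

num_examples = 4                  # batch_size
initial_eval = torch.rand(4)      # one output per example -> passes
# initial_eval = torch.rand(7)    # would fail the assertion below

initial_eval = initial_eval.reshape(1, -1)
num_outputs = initial_eval.shape[1]

assert num_outputs == num_examples, (
    "expected output of `forward_func` to have "
    "`batch_size` elements for non aggregate mode"
)
```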

## Tests

Added tests to check for the following (a rough sketch of the test pattern appears after this list):

`agg_mode=True`:
- An incorrect feature mask (i.e. one where `fm.shape[0] > 1`)
- Outputting an `Fx1` tensor, where `F` is the number of features in the input
- The above, but with a feature mask that treats the first two features as one feature
- Outputting a constant `2x3x5` tensor (not associated with the examples)
   - internally this will be interpreted as a `1x30` 2D tensor

`agg_mode=False`:
- Check that there are exactly `n` outputs, where `n == batch_size`; if not, check that we throw an exception (assertion error). **This already exists in `test_error_perturbations_per_eval_limit_batch_scalar`**
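
A rough, hypothetical sketch of the `agg_mode=True` test pattern (illustrative only; the names and structure are not the actual tests added in this PR):

```python
import torch

from captum.attr import FeatureAblation


def test_agg_mode_multi_output_sketch() -> None:
    def forward(inp: torch.Tensor) -> torch.Tensor:
        # An Fx1 output for a (batch_size, F) input, independent of batch_size.
        return inp.sum(dim=0).reshape(-1, 1)

    inputs = torch.rand(4, 3)
    ablation = FeatureAblation(forward)

    # perturbations_per_eval=1 and no feature_mask -> aggregation mode.
    attr = ablation.attribute(inputs, perturbations_per_eval=1)
    assert isinstance(attr, torch.Tensor)

    # An "incorrect" per-example mask (fm.shape[0] > 1) should be rejected,
    # since the output has no 1:1 correspondence with the batch.
    bad_mask = torch.zeros(4, 3, dtype=torch.long)
    raised = False
    try:
        ablation.attribute(inputs, feature_mask=bad_mask)
    except AssertionError:
        raised = True
    assert raised, "expected an AssertionError for a per-example feature mask"
```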

## Notes

I created a new function rather than modifying `_find_output_mode_and_verify`, since changing it would break Shapley value sampling. That will have to be fixed in a separate PR.

Differential Revision: D22416476

fbshipit-source-id: 0d08ca990a1e999339e51f0a7fa50be197d2f3b9
miguelmartin75 authored and facebook-github-bot committed Jul 12, 2020
1 parent 67ab6ea commit 070efc8
Showing 3 changed files with 236 additions and 51 deletions.
79 changes: 63 additions & 16 deletions captum/attr/_core/feature_ablation.py
@@ -18,7 +18,7 @@
 )
 from ..._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
 from .._utils.attribution import PerturbationAttribution
-from .._utils.common import _find_output_mode_and_verify, _format_input_baseline
+from .._utils.common import _format_input_baseline


 class FeatureAblation(PerturbationAttribution):
@@ -172,8 +172,11 @@ def attribute(
                         device contain at most
                         (perturbations_per_eval * #examples) / num_devices
                         samples.
-                        If the forward function returns a single scalar per batch,
-                        perturbations_per_eval must be set to 1.
+                        If the forward function's number of outputs does not
+                        change as the batch size grows (e.g. if it outputs a
+                        scalar value), you must set perturbations_per_eval to 1
+                        and use a single feature mask to describe the features
+                        for all examples in the batch.
                         Default: 1
             **kwargs (Any, optional): Any additional arguments used by child
                         classes of FeatureAblation (such as Occlusion) to construct
@@ -250,11 +253,26 @@ def attribute(
             initial_eval = _run_forward(
                 self.forward_func, inputs, target, additional_forward_args
             )
-            agg_output_mode = _find_output_mode_and_verify(
-                initial_eval, num_examples, perturbations_per_eval, feature_mask
+
+            agg_output_mode = FeatureAblation._find_output_mode(
+                perturbations_per_eval, feature_mask
             )
+
+            # get as a 2D tensor (if it is not a scalar)
+            if isinstance(initial_eval, torch.Tensor):
+                initial_eval = initial_eval.reshape(1, -1)
+                num_outputs = initial_eval.shape[1]
+            else:
+                num_outputs = 1
+
             if not agg_output_mode:
-                initial_eval = initial_eval.reshape(1, num_examples)
+                assert (
+                    isinstance(initial_eval, torch.Tensor)
+                    and num_outputs == num_examples
+                ), (
+                    "expected output of `forward_func` to have"
+                    + "`batch_size` elements for non aggregate mode"
+                )

             # Initialize attribution totals and counts
             attrib_type = cast(
@@ -263,17 +281,16 @@ def attribute(
                 if isinstance(initial_eval, Tensor)
                 else type(initial_eval),
             )
+
             total_attrib = [
-                torch.zeros_like(
-                    input[0:1] if agg_output_mode else input, dtype=attrib_type
-                )
+                torch.zeros((num_outputs,) + input.shape[1:], dtype=attrib_type)
                 for input in inputs
             ]

             # Weights are used in cases where ablations may be overlapping.
             if self.use_weights:
                 weights = [
-                    torch.zeros_like(input[0:1] if agg_output_mode else input).float()
+                    torch.zeros((num_outputs,) + input.shape[1:]).float()
                     for input in inputs
                 ]
@@ -305,18 +322,26 @@ def attribute(
                         current_target,
                         current_add_args,
                     )
                     # eval_diff dimensions: (#features in batch, #num_examples, 1,.. 1)
                     # (contains 1 more dimension than inputs). This adds extra
                     # dimensions of 1 to make the tensor broadcastable with the inputs
                     # tensor.
-                    if agg_output_mode:
+                    if not isinstance(modified_eval, torch.Tensor):
                         eval_diff = initial_eval - modified_eval
                     else:
+                        if not agg_output_mode:
+                            real_pertubations_per_eval = (
+                                current_inputs[0].shape[0] // inputs[0].shape[0]
+                            )
+                            assert (
+                                modified_eval.numel()
+                                == real_pertubations_per_eval * num_outputs
+                            ), """expected output of forward_func to grow with
+                            batch_size. If this is not the case for your model
+                            please set perturbations_per_eval = 1"""
+
                         eval_diff = (
-                            initial_eval - modified_eval.reshape(-1, num_examples)
-                        ).reshape(
-                            (-1, num_examples) + (len(inputs[i].shape) - 1) * (1,)
-                        )
+                            initial_eval - modified_eval.reshape((-1, num_outputs))
+                        ).reshape((-1, num_outputs) + (len(inputs[i].shape) - 1) * (1,))
                     if self.use_weights:
                         weights[i] += current_mask.float().sum(dim=0)
                     total_attrib[i] += (eval_diff * current_mask.to(attrib_type)).sum(
@@ -479,3 +504,25 @@ def _get_feature_range_and_mask(self, input, input_mask, **kwargs):
             torch.max(input_mask).item() + 1,
             input_mask,
         )
+
+    @staticmethod
+    def _find_output_mode(
+        perturbations_per_eval: int,
+        feature_mask: Union[None, TensorOrTupleOfTensorsGeneric],
+    ) -> bool:
+        """
+        Returns True if the output mode is "aggregation output mode"
+        Aggregation output mode is defined as: when there is no 1:1 correspondence
+        with the `num_examples` (`batch_size`) and the amount of outputs your model
+        produces, i.e. the model output does not grow in size as the input becomes
+        larger.
+        We assume this is the case if `perturbations_per_eval == 1`
+        and your feature mask is None or is associated to all
+        examples in a batch (fm.shape[0] == 1 for all fm in feature_mask).
+        """
+        return perturbations_per_eval == 1 and (
+            feature_mask is None
+            or all(len(sm.shape) == 0 or sm.shape[0] == 1 for sm in feature_mask)
+        )