Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

validate forward_fun output shape in FeatureAblation #1091

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 42 additions & 43 deletions captum/attr/_core/feature_ablation.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ def __init__(self, forward_func: Callable) -> None:
PerturbationAttribution.__init__(self, forward_func)
self.use_weights = False

# only used when perturbations_per_eval > 1, where the 1st dim of forward_func's
# output must grow with the input batch size. If forward's output is aggregated,
# we cannot expand the input to include more perturbations in one call.
# While it's False, we will force the validation by comparing the outputs of
# the original input and the modified input whose batch size is expanded based on
# perturbations_per_eval. We set the flag to True once the output of the modified
# input grows as expected. Once it turns True, we will assume the model's
# behavior stays consistent and no longer check again.
self._is_output_shape_valid = False

@log_usage()
def attribute(
self,
Expand Down Expand Up @@ -291,21 +301,10 @@ def attribute(

# flatten eval outputs into 1D (n_outputs)
# add the leading dim for n_feature_perturbed
initial_eval = initial_eval.reshape(1, -1)

agg_output_mode = FeatureAblation._find_output_mode(
perturbations_per_eval, feature_mask
)

if not agg_output_mode:
assert n_outputs == num_examples, (
"expected output of `forward_func` to have "
+ "`batch_size` elements for perturbations_per_eval > 1 "
+ "and all feature_mask.shape[0] > 1"
)
flattened_initial_eval = initial_eval.reshape(1, -1)

# Initialize attribution totals and counts
attrib_type = cast(dtype, initial_eval.dtype)
attrib_type = cast(dtype, flattened_initial_eval.dtype)

total_attrib = [
# attribute w.r.t each output element
Expand Down Expand Up @@ -362,21 +361,43 @@ def attribute(
if show_progress:
attr_progress.update()

if not agg_output_mode:
# current_batch_size is not n_examples
# it may get expanded by n_feature_perturbed
# if perturbations_per_eval > 1, the output shape must grow with
# input and not be aggregated
if perturbations_per_eval > 1 and not self._is_output_shape_valid:
current_batch_size = current_inputs[0].shape[0]

# number of perturbations, which is not the same as
# perturbations_per_eval when there are not enough features to perturb
n_perturb = current_batch_size / num_examples

current_output_shape = modified_eval.shape

# use initial_eval as the forward of perturbations_per_eval = 1
initial_output_shape = initial_eval.shape

assert (
modified_eval.numel() == current_batch_size
), """expected output of forward_func to grow with
batch_size. If this is not the case for your model
please set perturbations_per_eval = 1"""
# check if the output is not a scalar
current_output_shape
and initial_output_shape
# check if the output grows in the same ratio, i.e., not aggregated
and current_output_shape[0]
== n_perturb * initial_output_shape[0]
), (
"When perturbations_per_eval > 1, forward_func's output "
"should be a tensor whose 1st dim grow with the input "
f"batch size: when input batch size is {num_examples}, "
Copy link
Contributor Author

@aobo-y aobo-y Dec 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in future we can be more strict here by requiring the 1st dim of output to be the same as input batch size, instead of just require it "grow with the input batch size".

For example, if the input batch is 2, the output's 1st dim can be 2, 4, or 7; currently it will be OK
when we expand the input batch size to 4, the output's 1st dim becomes 4, 8, or 14 respectively.

But the last 2 cases are pretty weird and unlikely to happen. It may be easier to always require the output's 1st dim to be the batch size, i.e., it must be the same as the input batch size in this "non-aggregation mode". So if the input batch is 2, the output's 1st dim must be 2 if perturbations_per_eval > 1

cc @vivekmig @NarineK

f"the output shape is {initial_output_shape}; "
f"when input batch size is {current_batch_size}, "
f"the output shape is {current_output_shape}"
)

self._is_output_shape_valid = True

# reshape the leading dim for n_feature_perturbed
# flatten each feature's eval outputs into 1D of (n_outputs)
modified_eval = modified_eval.reshape(-1, n_outputs)
# eval_diff in shape (n_feature_perturbed, n_outputs)
eval_diff = initial_eval - modified_eval
eval_diff = flattened_initial_eval - modified_eval

# append the shape of one input example
# to make it broadcastable to mask
Expand Down Expand Up @@ -572,28 +593,6 @@ def _get_feature_counts(self, inputs, feature_mask, **kwargs):
for inp, mask in zip(inputs, feature_mask)
)

@staticmethod
def _find_output_mode(
    perturbations_per_eval: int,
    feature_mask: Union[None, TensorOrTupleOfTensorsGeneric],
) -> bool:
    """
    Tell whether the model is assumed to run in "aggregation output mode".

    Aggregation output mode means there is no 1:1 correspondence between
    `num_examples` (`batch_size`) and the amount of outputs the model
    produces, i.e., the model output does not grow in size as the input
    becomes larger.

    We assume this mode whenever `perturbations_per_eval == 1` and the
    feature mask is either absent or shared across all examples in the
    batch (fm.shape[0] == 1 for every fm in feature_mask).
    """
    if perturbations_per_eval != 1:
        return False
    if feature_mask is None:
        return True
    # a 0-dim or leading-dim-1 mask applies to the whole batch at once
    return all(len(fm.shape) == 0 or fm.shape[0] == 1 for fm in feature_mask)

def _strict_run_forward(self, *args, **kwargs) -> Tensor:
"""
A temp wrapper for global _run_forward util to force forward output
Expand Down
11 changes: 0 additions & 11 deletions tests/attr/test_feature_ablation.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,17 +345,6 @@ def forward_func(inp):
with self.assertRaises(AssertionError):
_ = ablation.attribute(inp, perturbations_per_eval=2)

def test_error_agg_mode_incorrect_fm(self) -> None:
def forward_func(inp):
return inp[0].unsqueeze(0)

inp = torch.tensor([[1, 2, 3], [4, 5, 6]])
mask = torch.tensor([[0, 1, 2], [0, 0, 1]])
Copy link
Contributor Author

@aobo-y aobo-y Dec 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test checks when rowwise mask is given, the model must not be in "aggregated mode". But I don't think mask has anything to do with it. Models of both batch-mode and aggregated-mode may need specify mask. So I simply deleted this test.


ablation = FeatureAblation(forward_func)
with self.assertRaises(AssertionError):
_ = ablation.attribute(inp, perturbations_per_eval=1, feature_mask=mask)

def test_empty_sparse_features(self) -> None:
ablation_algo = FeatureAblation(BasicModelWithSparseInputs())
inp1 = torch.tensor([[1.0, -2.0, 3.0], [2.0, -1.0, 3.0]])
Expand Down