Enable multi-task attribution for Shapley #1173

Closed · wants to merge 1 commit
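This patch lets Shapley value attribution handle forward functions that return multiple output elements per example (e.g. multi-task models) by attributing every output element, instead of requiring a target that selects a single scalar. A minimal usage sketch of what the change enables; the toy model, shapes, and sample count below are illustrative assumptions that mirror the new test at the bottom of the diff:

    import torch
    from captum.attr import ShapleyValueSampling

    # hypothetical two-task model: returns a (batch_size, 2) score tensor
    def forward_func(x):
        return torch.stack([x.sum(dim=1), x.mean(dim=1)], dim=1)

    inputs = torch.rand(4, 3)
    sv = ShapleyValueSampling(forward_func)

    # With target=None every output element is attributed. After this patch the
    # result is expected to have shape (*output_shape, *input_feature_shape),
    # i.e. (4, 2, 3) here, instead of tripping the single-output assertion.
    attr = sv.attribute(inputs, target=None, n_samples=25)
    print(attr.shape)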
83 changes: 65 additions & 18 deletions captum/attr/_core/shapley_value.py
@@ -300,21 +300,31 @@ def attribute(
)
attr_progress.update(0)

-        initial_eval = _run_forward(
+        initial_eval = self._strict_run_forward(
self.forward_func, baselines, target, additional_forward_args
)

if show_progress:
attr_progress.update()

agg_output_mode = _find_output_mode_and_verify(
-            initial_eval, num_examples, perturbations_per_eval, feature_mask
+            initial_eval,
+            num_examples,
+            perturbations_per_eval,
+            feature_mask,
+            allow_multi_outputs=True,
)

# Initialize attribution totals and counts
+        output_shape = initial_eval.shape
+        n_outputs = initial_eval.numel()
+
+        # attr shape (*output_shape, *input_feature_shape)
total_attrib = [
-            torch.zeros_like(
-                input[0:1] if agg_output_mode else input, dtype=torch.float
+            torch.zeros(
+                (*output_shape, *input.shape[1:]),
+                dtype=torch.float,
+                device=inputs[0].device,
)
for input in inputs
]
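To make the new accumulator shapes concrete: each per-input total now prepends the full model output shape to that input's feature shape. A small self-contained sketch, with made-up sizes:

    import torch

    # assume a batch of 4 examples, 2 outputs per example, and two inputs
    # with per-example feature shapes (3,) and (5, 5)
    initial_eval = torch.rand(4, 2)        # stand-in for the baseline eval
    inputs = (torch.rand(4, 3), torch.rand(4, 5, 5))

    output_shape = initial_eval.shape      # torch.Size([4, 2])
    n_outputs = initial_eval.numel()       # 8

    total_attrib = [
        torch.zeros((*output_shape, *inp.shape[1:]), dtype=torch.float)
        for inp in inputs
    ]
    print([t.shape for t in total_attrib])
    # [torch.Size([4, 2, 3]), torch.Size([4, 2, 5, 5])]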
@@ -349,7 +359,7 @@ def attribute(
)
# modified_eval dimensions: 1D tensor with length
# equal to #num_examples * #features in batch
-                modified_eval = _run_forward(
+                modified_eval = self._strict_run_forward(
self.forward_func,
current_inputs,
current_target,
@@ -362,23 +372,35 @@
eval_diff = modified_eval - prev_results
prev_results = modified_eval
else:
+                    # when perturb_per_eval > 1, every num_examples rows stand
+                    # for one perturbation. Since the perturbations come from a
+                    # consecutive permutation, each perturbation's diff is its
+                    # eval minus the eval of the previous perturbation
all_eval = torch.cat((prev_results, modified_eval), dim=0)
eval_diff = all_eval[num_examples:] - all_eval[:-num_examples]
prev_results = all_eval[-num_examples:]

for j in range(len(total_attrib)):
-                    current_eval_diff = eval_diff
-                    if not agg_output_mode:
-                        # current_eval_diff dimensions:
-                        # (#features in batch, #num_examples, 1,.. 1)
-                        # (contains 1 more dimension than inputs). This adds extra
-                        # dimensions of 1 to make the tensor broadcastable with the
-                        # inputs tensor.
-                        current_eval_diff = current_eval_diff.reshape(
-                            (-1, num_examples) + (len(inputs[j].shape) - 1) * (1,)
-                        )
-                    total_attrib[j] += (
-                        current_eval_diff * current_masks[j].float()
-                    ).sum(dim=0)
+                    # format eval_diff to shape
+                    # (n_perturb, n_outputs, 1,.. 1)
+                    # where n_perturb may not be perturb_per_eval
+                    # Append n_input_feature dim of 1 to make the tensor
+                    # have the same dim as the mask tensor.
+                    formatted_eval_diff = eval_diff.reshape(
+                        (-1, n_outputs) + (len(inputs[j].shape) - 1) * (1,)
+                    )
+
+                    # mask in shape (n_perturb, *mask_shape_broadcastable_to_input)
+                    # aggregate n_perturb
+                    cur_attr = (formatted_eval_diff * current_masks[j].float()).sum(
+                        dim=0
+                    )
+
+                    # (n_outputs, *input_feature_shape) ->
+                    # (*output_shape, *input_feature_shape)
+                    total_attrib[j] += cur_attr.reshape(
+                        (*output_shape, *cur_attr.shape[1:])
+                    )

if show_progress:
attr_progress.close()
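A short sketch of the broadcasting that the aggregation above relies on; the perturbation count, mask shape, and sizes are arbitrary assumptions:

    import torch

    n_perturb, n_outputs, n_feat = 3, 8, 5
    output_shape = (4, 2)                  # 4 examples x 2 outputs, numel == n_outputs

    # one eval diff per (perturbation, output element)
    eval_diff = torch.rand(n_perturb * n_outputs)

    # one mask per perturbation, broadcastable to the (example, feature) input shape
    current_mask = torch.randint(0, 2, (n_perturb, 1, n_feat))

    formatted_eval_diff = eval_diff.reshape((-1, n_outputs, 1))         # (3, 8, 1)
    cur_attr = (formatted_eval_diff * current_mask.float()).sum(dim=0)  # (8, 5)
    total = cur_attr.reshape((*output_shape, *cur_attr.shape[1:]))      # (4, 2, 5)
    print(total.shape)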
@@ -476,6 +498,31 @@ def _get_n_evaluations(self, total_features, n_samples, perturbations_per_eval):
"""return the total number of forward evaluations needed"""
return math.ceil(total_features / perturbations_per_eval) * n_samples

def _strict_run_forward(self, *args, **kwargs) -> Tensor:
"""
        A temporary wrapper around the global _run_forward util that enforces
        type assertion & conversion of the forward output.
        Remove once the strict logic is supported by all attr classes
"""
forward_output = _run_forward(*args, **kwargs)
if isinstance(forward_output, Tensor):
# format scalar to shape (1) so we can always assume non-empty output_shape
if not forward_output.shape:
forward_output = forward_output.reshape(1)

return forward_output

output_type = type(forward_output)
assert output_type is int or output_type is float, (
"the return of forward_func must be a tensor, int, or float,"
f" received: {forward_output}"
)

# using python built-in type as torch dtype
# int -> torch.int64, float -> torch.float64
# ref: https://github.com/pytorch/pytorch/pull/21215
return torch.tensor([forward_output], dtype=output_type)
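For context on the scalar branch above: torch accepts Python built-in types as dtypes, mapping int to torch.int64 and float to torch.float64 (see the referenced PyTorch PR). A quick standalone check of that conversion:

    import torch

    wrapped_int = torch.tensor([3], dtype=int)
    wrapped_float = torch.tensor([0.5], dtype=float)
    print(wrapped_int.dtype, wrapped_float.dtype)   # torch.int64 torch.float64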


class ShapleyValues(ShapleyValueSampling):
"""
8 changes: 5 additions & 3 deletions captum/attr/_utils/common.py
@@ -318,6 +318,7 @@ def _find_output_mode_and_verify(
num_examples: int,
perturbations_per_eval: int,
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric],
allow_multi_outputs: bool = False,
) -> bool:
"""
This method identifies whether the model outputs a single output for a batch
@@ -346,9 +347,10 @@
)
else:
agg_output_mode = False
-        assert (
-            isinstance(initial_eval, torch.Tensor) and initial_eval[0].numel() == 1
-        ), "Target should identify a single element in the model output."
+        if not allow_multi_outputs:
+            assert (
+                isinstance(initial_eval, torch.Tensor) and initial_eval[0].numel() == 1
+            ), "Target should identify a single element in the model output."
return agg_output_mode


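To illustrate the new allow_multi_outputs flag above: with a per-example multi-element output, _find_output_mode_and_verify previously failed its single-element assertion; with the flag it simply reports non-aggregated mode. A hedged sketch that calls the private util directly, with made-up tensors:

    import torch
    from captum.attr._utils.common import _find_output_mode_and_verify

    multi_output_eval = torch.rand(4, 3)   # 4 examples, 3 outputs each

    # returns False (non-aggregated mode) instead of raising an AssertionError
    agg_mode = _find_output_mode_and_verify(
        multi_output_eval,
        num_examples=4,
        perturbations_per_eval=1,
        feature_mask=None,
        allow_multi_outputs=True,
    )
    print(agg_mode)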
37 changes: 37 additions & 0 deletions tests/attr/test_shapley.py
@@ -151,6 +151,43 @@ def test_multi_input_shapley_sampling_with_mask(self) -> None:
perturbations_per_eval=(1, 2, 3),
)

def test_shapley_sampling_multi_task_output(self) -> None:
# return shape (batch size, 2)
net1 = BasicModel_MultiLayer()

# return shape (batch size, 4)
def forward_func(*args, **kwargs):
net_output = net1(*args, **kwargs)
batch_size = net_output.size(0)
constant = torch.ones(batch_size, 2)
output = torch.cat(
[
net_output,
constant,
],
dim=-1,
)
return output

inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True)

self._shapley_test_assert(
forward_func,
inp,
[
[
[76.66666, 196.66666, 116.66666],
[76.66666, 196.66666, 116.66666],
[0, 0, 0],
[0, 0, 0],
]
],
target=None, # no target, multi-task output for all classes
perturbations_per_eval=(1, 2, 3),
n_samples=150,
test_true_shapley=True,
)
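A note on the expected values in this test: the last two rows are all zeros because the concatenated constant outputs never change when features are perturbed, so every eval diff for those outputs is zero. A tiny standalone illustration of that effect (the numbers are arbitrary):

    import torch

    # two "outputs": one depends on the input, one is constant
    def f(x):
        return torch.stack([x.sum(dim=1), torch.ones(x.size(0))], dim=1)

    x = torch.tensor([[20.0, 50.0, 30.0]])
    x_ablated = x.clone()
    x_ablated[0, 0] = 0.0                  # ablate the first feature

    print(f(x) - f(x_ablated))             # tensor([[20., 0.]]) -> constant output contributes 0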

# Remaining tests are for cases where forward function returns a scalar
# per batch, as either a float, integer, 0d tensor or 1d tensor.
def test_single_shapley_batch_scalar_float(self) -> None: