convert forward return to tensor in FeatureAblation #1049

Closed · wants to merge 2 commits
85 changes: 49 additions & 36 deletions captum/attr/_core/feature_ablation.py
@@ -279,7 +279,7 @@ def attribute(

  # Computes initial evaluation with all features, which is compared
  # to each ablated result.
- initial_eval = _run_forward(
+ initial_eval = self._strict_run_forward(
      self.forward_func, inputs, target, additional_forward_args
  )

@@ -291,27 +291,21 @@ def attribute(

  # flatten eval outputs into 1D (n_outputs)
  # add the leading dim for n_feature_perturbed
- if isinstance(initial_eval, Tensor):
-     initial_eval = initial_eval.reshape(1, -1)
+ initial_eval = initial_eval.reshape(1, -1)

  agg_output_mode = FeatureAblation._find_output_mode(
      perturbations_per_eval, feature_mask
  )

  if not agg_output_mode:
-     assert isinstance(initial_eval, Tensor) and n_outputs == num_examples, (
+     assert n_outputs == num_examples, (
          "expected output of `forward_func` to have "
          + "`batch_size` elements for perturbations_per_eval > 1 "
          + "and all feature_mask.shape[0] > 1"
      )

  # Initialize attribution totals and counts
- attrib_type = cast(
-     dtype,
-     initial_eval.dtype
-     if isinstance(initial_eval, Tensor)
-     else type(initial_eval),
- )
+ attrib_type = cast(dtype, initial_eval.dtype)

  total_attrib = [
      # attribute w.r.t each output element
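
For illustration only (not part of the diff): a minimal sketch of the reshape above, assuming a hypothetical forward_func that returns one scalar per example; the shapes and values are made up.

  import torch

  # per-example output for a batch of 4 examples, as in non-agg mode
  initial_eval = torch.rand(4)                     # shape: (4,)

  # flatten into 1D and add the leading dim for n_feature_perturbed
  initial_eval = initial_eval.reshape(1, -1)       # shape: (1, 4)
  n_outputs = initial_eval.shape[1]                # must equal num_examples here

  # an aggregated (scalar) output would instead become shape (1, 1)
  scalar_eval = torch.tensor(0.5).reshape(1, -1)   # shape: (1, 1)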
@@ -358,7 +352,7 @@ def attribute(
  # agg mode: (*initial_eval.shape)
  # non-agg mode:
  # (feature_perturbed * batch_size, *initial_eval.shape[1:])
- modified_eval = _run_forward(
+ modified_eval = self._strict_run_forward(
      self.forward_func,
      current_inputs,
      current_target,
@@ -368,31 +362,29 @@ def attribute(
  if show_progress:
      attr_progress.update()

- if not isinstance(modified_eval, torch.Tensor):
-     eval_diff = initial_eval - modified_eval
- else:
-     if not agg_output_mode:
-         # current_batch_size is not n_examples
-         # it may get expanded by n_feature_perturbed
-         current_batch_size = current_inputs[0].shape[0]
-         assert (
-             modified_eval.numel() == current_batch_size
-         ), """expected output of forward_func to grow with
-         batch_size. If this is not the case for your model
-         please set perturbations_per_eval = 1"""
-
-     # reshape the leading dim for n_feature_perturbed
-     # flatten each feature's eval outputs into 1D of (n_outputs)
-     modified_eval = modified_eval.reshape(-1, n_outputs)
-     # eval_diff in shape (n_feature_perturbed, n_outputs)
-     eval_diff = initial_eval - modified_eval
-
-     # append the shape of one input example
-     # to make it broadcastable to mask
-     eval_diff = eval_diff.reshape(
-         eval_diff.shape + (inputs[i].dim() - 1) * (1,)
-     )
-     eval_diff = eval_diff.to(total_attrib[i].device)
+ if not agg_output_mode:
+     # current_batch_size is not n_examples
+     # it may get expanded by n_feature_perturbed
+     current_batch_size = current_inputs[0].shape[0]
+     assert (
+         modified_eval.numel() == current_batch_size
+     ), """expected output of forward_func to grow with
+     batch_size. If this is not the case for your model
+     please set perturbations_per_eval = 1"""
+
+ # reshape the leading dim for n_feature_perturbed
+ # flatten each feature's eval outputs into 1D of (n_outputs)
+ modified_eval = modified_eval.reshape(-1, n_outputs)
+ # eval_diff in shape (n_feature_perturbed, n_outputs)
+ eval_diff = initial_eval - modified_eval
+
+ # append the shape of one input example
+ # to make it broadcastable to mask
+ eval_diff = eval_diff.reshape(
+     eval_diff.shape + (inputs[i].dim() - 1) * (1,)
+ )
+ eval_diff = eval_diff.to(total_attrib[i].device)

  if self.use_weights:
      weights[i] += current_mask.float().sum(dim=0)
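
Again for illustration only (not from the PR): what the eval_diff reshape does so the difference can broadcast against the ablation mask over an input's non-batch dimensions; all shapes here are made up.

  import torch

  # hypothetical: 3 features perturbed per eval, 4 outputs,
  # and an input tensor with dim() == 4, e.g. (batch, channel, height, width)
  eval_diff = torch.rand(3, 4)               # (n_feature_perturbed, n_outputs)
  input_example = torch.rand(4, 3, 8, 8)     # stands in for inputs[i]

  # append (dim() - 1) singleton dims to make eval_diff broadcastable
  eval_diff = eval_diff.reshape(eval_diff.shape + (input_example.dim() - 1) * (1,))
  print(eval_diff.shape)                     # torch.Size([3, 4, 1, 1, 1])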

@@ -601,3 +593,24 @@ def _find_output_mode(
      feature_mask is None
      or all(len(sm.shape) == 0 or sm.shape[0] == 1 for sm in feature_mask)
  )

+ def _strict_run_forward(self, *args, **kwargs) -> Tensor:
+     """
+     A temp wrapper for global _run_forward util to force forward output
+     type assertion & conversion.
+     Remove after the strict logic is supported by all attr classes
+     """
+     forward_output = _run_forward(*args, **kwargs)

Contributor: nit: It seems a bit confusing when seeing both the instance method and the original method named _run_forward; could consider renaming this one slightly, but either way is fine.

+     if isinstance(forward_output, Tensor):
+         return forward_output
+
+     output_type = type(forward_output)
+     assert output_type is int or output_type is float, (
+         "the return of forward_func must be a tensor, int, or float,"
+         f" received: {forward_output}"
+     )
+
+     # using python built-in type as torch dtype
+     # int -> torch.int64, float -> torch.float64
+     # ref: https://github.com/pytorch/pytorch/pull/21215
+     return torch.tensor(forward_output, dtype=output_type)
Contributor Author:
I inherited our original logic of passing Python built-in types as the torch dtype, e.g. dtype=float. But this is not an officially documented operation. Existing tests assume it is equivalent to dtype=torch.float64: https://github.com/pytorch/captum/blob/5f878af6a7/tests/attr/test_feature_ablation.py#L429

But this may be machine dependent (https://docs.python.org/3.10/library/stdtypes.html#typesnumeric):

"Floating point numbers are usually implemented using double in C; information about the precision and internal representation of floating point numbers for the machine on which your program is running is available in sys.float_info"

Two other alternatives are:

  • explicitly map the Python type to a torch dtype: float -> torch.float64
  • do not set dtype and rely on torch's default dtype (float32)

Contributor:
This is an interesting point! I looked into it a bit; this functionality seems to have been added in this PR: pytorch/pytorch#21215
It looks like the type mapping is done explicitly on the C++ side using the PyObject type, so this shouldn't be affected by the internal representation. This is the logic for the mapping:

  PyObject *obj = args[i];
  if (obj == (PyObject*)&PyFloat_Type) {
    return at::ScalarType::Double;
  }
  if (obj == (PyObject*)&PyBool_Type) {
    return at::ScalarType::Bool;
  }
  if (obj == (PyObject*)&PyLong_Type
#if PY_MAJOR_VERSION == 2
      || obj == (PyObject*)&PyInt_Type
#endif
  ) {
    return at::ScalarType::Long;
  }

So if float is set as dtype, this would be passed through the Python / C++ bindings as PyFloat_Type, which should always correspond to ScalarType::Double / torch.float64. The tests in the original PR also verify this mapping.

Contributor Author:
Thx for the deep dive @vivekmig! Then I will just add a comment referring to the mapping, also as a caveat. After all, it is not documented torch usage and may see breaking changes someday.

Contributor:
Makes sense, sounds good!
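
For reference, a minimal check of the mapping discussed in this thread (an illustrative snippet, not from the PR; it assumes a PyTorch build that includes pytorch/pytorch#21215):

  import torch

  # Python built-in types as dtype: float -> torch.float64, int -> torch.int64
  t_float = torch.tensor(0.25, dtype=float)
  t_int = torch.tensor(7, dtype=int)
  print(t_float.dtype)   # torch.float64
  print(t_int.dtype)     # torch.int64

  # so a forward_func returning a plain Python scalar becomes a 0-dim tensor
  # of the corresponding dtype after the conversion added in this PR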