convert forward return to tensor in FeatureAblation #1049
@@ -279,7 +279,7 @@ def attribute(
             # Computes initial evaluation with all features, which is compared
             # to each ablated result.
-            initial_eval = _run_forward(
+            initial_eval = self._strict_run_forward(
                 self.forward_func, inputs, target, additional_forward_args
             )

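Switching from the global `_run_forward` to `self._strict_run_forward` (defined near the bottom of this diff) guarantees that `initial_eval` is always a `Tensor`, even when `forward_func` returns a plain python int or float. A minimal usage sketch of the case this accommodates; `scalar_forward` below is a hypothetical forward function, not code from this PR:

```python
import torch
from captum.attr import FeatureAblation

def scalar_forward(x):
    # hypothetical forward_func that returns a python float instead of a Tensor,
    # e.g. an aggregated loss or metric over the whole batch
    return x.sum().item()

ablator = FeatureAblation(scalar_forward)
# with the default perturbations_per_eval=1 this runs in aggregation mode;
# the scalar return is converted to a tensor internally
attributions = ablator.attribute(torch.rand(1, 3))
```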
@@ -291,27 +291,21 @@ def attribute(
             # flatten eval outputs into 1D (n_outputs)
             # add the leading dim for n_feature_perturbed
-            if isinstance(initial_eval, Tensor):
-                initial_eval = initial_eval.reshape(1, -1)
+            initial_eval = initial_eval.reshape(1, -1)

             agg_output_mode = FeatureAblation._find_output_mode(
                 perturbations_per_eval, feature_mask
             )

             if not agg_output_mode:
-                assert isinstance(initial_eval, Tensor) and n_outputs == num_examples, (
+                assert n_outputs == num_examples, (
                     "expected output of `forward_func` to have "
                     + "`batch_size` elements for perturbations_per_eval > 1 "
                     + "and all feature_mask.shape[0] > 1"
                 )

             # Initialize attribution totals and counts
-            attrib_type = cast(
-                dtype,
-                initial_eval.dtype
-                if isinstance(initial_eval, Tensor)
-                else type(initial_eval),
-            )
+            attrib_type = cast(dtype, initial_eval.dtype)

             total_attrib = [
                 # attribute w.r.t each output element

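Since `initial_eval` is now guaranteed to be a tensor, the `isinstance` guard and the conditional branch inside the `cast` can be dropped. The unconditional `reshape(1, -1)` flattens the forward output into shape `(1, n_outputs)`, where the leading 1 is the `n_feature_perturbed` dimension. A shape-only illustration (the tensors here are placeholders, not captum internals):

```python
import torch

# a per-example output of 4 examples x 3 scores flattens to (1, 12)
print(torch.rand(4, 3).reshape(1, -1).shape)   # torch.Size([1, 12])

# a scalar output (e.g. an int/float converted by the new wrapper)
# flattens to (1, 1)
print(torch.tensor(0.5).reshape(1, -1).shape)  # torch.Size([1, 1])
```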
@@ -358,7 +352,7 @@ def attribute(
                 # agg mode: (*initial_eval.shape)
                 # non-agg mode:
                 # (feature_perturbed * batch_size, *initial_eval.shape[1:])
-                modified_eval = _run_forward(
+                modified_eval = self._strict_run_forward(
                     self.forward_func,
                     current_inputs,
                     current_target,

@@ -368,31 +362,29 @@ def attribute(
                 if show_progress:
                     attr_progress.update()

-                if not isinstance(modified_eval, torch.Tensor):
-                    eval_diff = initial_eval - modified_eval
-                else:
-                    if not agg_output_mode:
-                        # current_batch_size is not n_examples
-                        # it may get expanded by n_feature_perturbed
-                        current_batch_size = current_inputs[0].shape[0]
-                        assert (
-                            modified_eval.numel() == current_batch_size
-                        ), """expected output of forward_func to grow with
-                        batch_size. If this is not the case for your model
-                        please set perturbations_per_eval = 1"""
-
-                    # reshape the leading dim for n_feature_perturbed
-                    # flatten each feature's eval outputs into 1D of (n_outputs)
-                    modified_eval = modified_eval.reshape(-1, n_outputs)
-                    # eval_diff in shape (n_feature_perturbed, n_outputs)
-                    eval_diff = initial_eval - modified_eval
-
-                    # append the shape of one input example
-                    # to make it broadcastable to mask
-                    eval_diff = eval_diff.reshape(
-                        eval_diff.shape + (inputs[i].dim() - 1) * (1,)
-                    )
-                    eval_diff = eval_diff.to(total_attrib[i].device)
+                if not agg_output_mode:
+                    # current_batch_size is not n_examples
+                    # it may get expanded by n_feature_perturbed
+                    current_batch_size = current_inputs[0].shape[0]
+                    assert (
+                        modified_eval.numel() == current_batch_size
+                    ), """expected output of forward_func to grow with
+                    batch_size. If this is not the case for your model
+                    please set perturbations_per_eval = 1"""
+
+                # reshape the leading dim for n_feature_perturbed
+                # flatten each feature's eval outputs into 1D of (n_outputs)
+                modified_eval = modified_eval.reshape(-1, n_outputs)
+                # eval_diff in shape (n_feature_perturbed, n_outputs)
+                eval_diff = initial_eval - modified_eval
+
+                # append the shape of one input example
+                # to make it broadcastable to mask
+                eval_diff = eval_diff.reshape(
+                    eval_diff.shape + (inputs[i].dim() - 1) * (1,)
+                )
+                eval_diff = eval_diff.to(total_attrib[i].device)

                 if self.use_weights:
                     weights[i] += current_mask.float().sum(dim=0)

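Dropping the `isinstance(modified_eval, torch.Tensor)` branch simply de-indents the tensor path, which now always applies. The final reshape appends one singleton dimension per remaining input dimension so that `eval_diff` broadcasts against the feature mask. A shape-only illustration with assumed sizes (not captum internals):

```python
import torch

n_feature_perturbed, n_outputs = 2, 4
eval_diff = torch.rand(n_feature_perturbed, n_outputs)

# assume the i-th input is an image batch of shape (N, C, H, W), so dim() == 4
input_dim = 4
eval_diff = eval_diff.reshape(eval_diff.shape + (input_dim - 1) * (1,))
print(eval_diff.shape)  # torch.Size([2, 4, 1, 1, 1])
```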
@@ -601,3 +593,24 @@ def _find_output_mode(
             feature_mask is None
             or all(len(sm.shape) == 0 or sm.shape[0] == 1 for sm in feature_mask)
         )
+
+    def _strict_run_forward(self, *args, **kwargs) -> Tensor:
+        """
+        A temp wrapper for global _run_forward util to force forward output
+        type assertion & conversion.
+        Remove after the strict logic is supported by all attr classes
+        """
+        forward_output = _run_forward(*args, **kwargs)
+        if isinstance(forward_output, Tensor):
+            return forward_output
+
+        output_type = type(forward_output)
+        assert output_type is int or output_type is float, (
+            "the return of forward_func must be a tensor, int, or float,"
+            f" received: {forward_output}"
+        )
+
+        # using python built-in type as torch dtype
+        # int -> torch.int64, float -> torch.float64
+        # ref: https://github.com/pytorch/pytorch/pull/21215
+        return torch.tensor(forward_output, dtype=output_type)
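For reference, the conversion this wrapper performs can be sketched as a standalone function (a hypothetical `to_strict_tensor` helper for illustration only, not part of captum's API): tensors pass through unchanged, python ints and floats become 0-dim tensors, and anything else fails the assertion.

```python
import torch
from torch import Tensor

def to_strict_tensor(forward_output):
    # hypothetical standalone mirror of _strict_run_forward's conversion step
    if isinstance(forward_output, Tensor):
        return forward_output
    output_type = type(forward_output)
    assert output_type is int or output_type is float, (
        "the return of forward_func must be a tensor, int, or float,"
        f" received: {forward_output}"
    )
    # python built-in types act as torch dtypes: int -> int64, float -> float64
    return torch.tensor(forward_output, dtype=output_type)

print(to_strict_tensor(3).dtype)        # torch.int64
print(to_strict_tensor(0.5).dtype)      # torch.float64
print(to_strict_tensor(torch.ones(2)))  # tensor([1., 1.]), passed through
# to_strict_tensor("0.5") would raise an AssertionError
```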
PR author: I inherit our original logic of passing python built-in types as the torch dtype, but this may be machine dependent (see https://docs.python.org/3.10/library/stdtypes.html#typesnumeric). Two other alternatives are:

@vivekmig: This is an interesting point! I looked into it a bit; this functionality seems to have been added in pytorch/pytorch#21215. So if float is set as dtype, this would be passed through the Python / C++ bindings as PyFloat_Type, which should always correspond to ScalarType::Double / torch.float64. The tests in the original PR also verify this mapping.

PR author: Thx for the deep dive @vivekmig!

Makes sense, sounds good!
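A quick check of the mapping discussed above (the behavior added in pytorch/pytorch#21215):

```python
import torch

# python built-in types are accepted as torch dtypes
print(torch.tensor(3, dtype=int).dtype)      # torch.int64
print(torch.tensor(0.5, dtype=float).dtype)  # torch.float64
```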
nit: It seems a bit confusing to see both the instance method and the original util named `_run_forward`; consider renaming this one slightly, but either way is fine.