
validate forward_fun output shape in FeatureAblation #1091

Closed
wants to merge 3 commits into from

Conversation

aobo-y
Contributor

@aobo-y aobo-y commented Dec 14, 2022

Correctly validate the forward_fun's output shape in FeatureAblation (and FeaturePermutation).

Abandoned the previous flawed assumption of "aggregation mode", which precluded support for multi-output models (ref #1047).

The new logic does not care about the output shape when perturbations_per_eval == 1. Only when perturbations_per_eval > 1 does it require "non-aggregation mode", which is defined as: the 1st dim of the model's output should grow with the input's batch size in the same ratio. This is achieved by actually comparing the output shapes of 2 different inputs, instead of making any assumption based on other user config:

  • The baseline output comes from the initial eval with the original inputs, which we have to run anyway.
  • The expanded output comes from the 1st ablated eval, whose input batch size has been expanded for more feature perturbations.

This way, we do not introduce any extra forward calls.
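The comparison described above can be sketched as follows. This is a minimal, hypothetical illustration of the check, not Captum's actual internals: the helper name and the plain-tuple shapes are assumptions made for clarity.

```python
def check_proportional_growth(baseline_shape, expanded_shape,
                              baseline_bsz, expanded_bsz):
    """Raise if the output's 1st dim does not grow with the input
    batch size in the same ratio ("non-aggregation mode")."""
    # expected 1st dim of the expanded output, scaled by the batch ratio
    expected = baseline_shape[0] * expanded_bsz // baseline_bsz
    if expanded_shape[0] != expected:
        raise AssertionError(
            "When perturbations_per_eval > 1, forward_func's output should "
            "be a tensor whose 1st dim grows with the input batch size: "
            f"expected {expected}, got {expanded_shape[0]}"
        )

# baseline eval: batch 2 -> output 1st dim 2
# expanded eval: batch 6 (3 perturbations per eval) -> 1st dim must be 6
check_proportional_growth((2,), (6,), 2, 6)  # passes silently
```

Because both shapes come from forward calls the attribution already performs, this validation is effectively free.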

), (
"When perturbations_per_eval > 1, forward_func's output "
"should be a tensor whose 1st dim grows with the input "
f"batch size: when input batch size is {num_examples}, "
Contributor Author

@aobo-y aobo-y Dec 14, 2022


I think in the future we can be stricter here by requiring the 1st dim of the output to be the same as the input batch size, instead of just requiring that it "grow with the input batch size".

For example, if the input batch size is 2, the output's 1st dim can be 2, 4, or 7; currently this is OK as long as, when we expand the input batch size to 4, the output's 1st dim becomes 4, 8, or 14 respectively.

But the last 2 cases are pretty weird and unlikely to happen. It may be easier to always require the output's 1st dim to equal the input batch size in this "non-aggregation mode". So if the input batch size is 2, the output's 1st dim must be 2 whenever perturbations_per_eval > 1.

cc @vivekmig @NarineK
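To make the contrast concrete, here is a hypothetical sketch of the two validation rules discussed above. Both helper functions are illustrative assumptions, not code from this PR:

```python
def proportional_ok(base_dim, exp_dim, base_bsz, exp_bsz):
    # current rule: the output's 1st dim grows in the same ratio
    # as the input batch size (cross-multiplied to avoid division)
    return exp_dim * base_bsz == base_dim * exp_bsz

def strict_ok(out_dim, bsz):
    # proposed stricter rule: the 1st dim must equal the batch size
    return out_dim == bsz

# input batch 2 expanded to 4: an output 1st dim of 7 -> 14 passes
# the current proportional rule but would fail the stricter one
assert proportional_ok(7, 14, 2, 4)
assert not strict_ok(7, 2)
# the common case (1st dim == batch size) passes both rules
assert proportional_ok(2, 4, 2, 4) and strict_ok(2, 2)
```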

@facebook-github-bot
Contributor

@aobo-y has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@aobo-y aobo-y changed the title validate forward output in FeatureAblation validate forward output shape in FeatureAblation Dec 14, 2022
return inp[0].unsqueeze(0)

inp = torch.tensor([[1, 2, 3], [4, 5, 6]])
mask = torch.tensor([[0, 1, 2], [0, 0, 1]])
Contributor Author

@aobo-y aobo-y Dec 14, 2022


This test checked that when a rowwise mask is given, the model must not be in "aggregated mode". But I don't think the mask has anything to do with that: models in both batch mode and aggregated mode may need to specify a mask. So I simply deleted this test.

@aobo-y aobo-y requested a review from vivekmig December 14, 2022 18:56
@aobo-y aobo-y requested a review from NarineK December 14, 2022 18:57

@aobo-y aobo-y changed the title validate forward output shape in FeatureAblation validate forward_fun output shape in FeatureAblation Dec 19, 2022

@facebook-github-bot
Contributor

@aobo-y merged this pull request in 761a219.
