support normalize in metric infidelity #639

Closed · wants to merge 5 commits

Conversation

@aobo-y commented Mar 24, 2021

@facebook-github-bot
@aobo-y has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


@vivekmig left a comment

Looks great, thanks for adding this @aobo-y! Just some minor nits on documentation.

@@ -345,6 +346,15 @@ def infidelity(
`input batch size * n_perturb_samples`.

Default: None
normalize (bool, optional): Normalize the dot product of the input
perturbation and the attribution so the infedelity value is invariant
Contributor:
nit: infidelity

perturbation and the attribution so the infedelity value is invariant
to constant scaling of the attribution values. The normalization factor
is defined as the ratio of two mean values across all perturbations:
`mean(dot product * func value diff) / mean(dot product * dot product)`.
Contributor:
nit: Would it be possible to make this line a little more detailed to explain this factor for new users? E.g. it may not be immediately clear that func value diff is the same as the difference of the predictor function's values at the input and the perturbed input described above.

Contributor:
That's a good point. I'd recommend using paper notation with latex.

Contributor Author:
Sure, I can use LaTeX here.

max_examples_per_batch=max_examples_per_batch,
)
return metrics_sum * 1 / n_perturb_samples

attr_times_perturb_sums, perturbed_fwd_diffs = metrics_sum
Contributor:
nit: Can this be included above as the return of _divide_and_aggregate_metrics? metrics_sum isn't really applicable anymore.

@NarineK left a comment

Thank you for working on this PR @aobo-y!

If I understand this implementation correctly, we aggregate all inputs and all perturbations in memory, which can lead to out-of-memory errors very quickly. The idea of using torch.add or torch.mean as an aggregate function was to help scale the implementation and avoid out-of-memory errors as much as possible.

In _next_infidelity_tensors you have access to the normalize input argument (I've just verified it; in Python you are still running in the same context and have access to the normalize argument there). You can keep track of np.mean(pdt_diff * exp_sum) and np.mean(exp_sum * exp_sum) in _next_infidelity_tensors too, and ultimately apply the final beta per example at the end. Let me know what you think.


@aobo-y commented Apr 4, 2021

@NarineK if you mean I keep all the perturbation results in memory, then yes, I am aware of the memory consumption. If the perturbation number is too large, we will surely run out of memory.

In Python, I can access the context in the nested aggregation function and know if we need to normalize. Unfortunately, that is not the reason I have to keep all perturbation results. If I understand correctly, the author's normalization applies the beta to each perturbation pair (pdt_diff, exp_sum), not each example.

What you suggested does help calculate the two means used for calculating beta. I agree we can aggregate them in each forward pass if that's all we need. However, in the end, the author uses the beta to scale each perturbation, then calculates the MSE. In other words, you will only know beta after all perturbations, but then you still need each perturbation pair to calculate beta * pdt_diff - exp_sum.
https://github.com/chihkuanyeh/saliency_evaluation/blob/44a66e2531f30b803be3bf5b0786971b7e7f72a1/infid_sen_utils.py#L298

Mathematically, what we need is mean((beta * pdt_diff - exp_sum) ^ 2). I don't think we can get it with pre-aggregated mean(pdt_diff), mean(exp_sum), or mean((pdt_diff - exp_sum) ^ 2). Let me know if it makes sense.

@vivekmig commented Apr 4, 2021


I agree that maintaining just the means necessary for beta would not be sufficient to avoid maintaining results per perturbation sample. I think the approach to avoid storing the sample-wise results would be to maintain mean(attr_times_perturb_sums^2), mean(attr_times_perturb_sums * perturbed_fwd_diffs), and mean(perturbed_fwd_diffs^2). With these means, it should then be possible to compute both the beta and the final mean[(beta * attr_times_perturb_sums - perturbed_fwd_diffs)^2] as:
beta^2 * mean(attr_times_perturb_sums^2) - 2 * beta * mean(attr_times_perturb_sums * perturbed_fwd_diffs) + mean(perturbed_fwd_diffs^2)

based on the expansion of (a - b)^2.

This approach would avoid the additional memory, but at the cost of a potentially trickier formulation to follow. The additional memory used here should be on the order of batch_size * n_perturb_samples (full input perturbations wouldn't be maintained), so this shouldn't be a large issue / lead to OOMs with typical use cases. But to be on the safer side, if we expect potentially larger values, it might be worth considering the alternative approach. What do you think @NarineK, @aobo-y?

@NarineK commented Apr 4, 2021


I saw that beta is a scalar and it looked to me that we should be able to do that per sample. Sorry for the confusion.

sum(a - b) = sum(a) - sum(b), so if we have sum(a) and sum(b) we can compute sum(beta * a - b) = beta * sum(a) - sum(b).
In terms of squares, as @vivekmig mentioned, sum([beta * a - b]^2) = beta^2 * sum(a^2) - 2 * beta * sum(a * b) + sum(b^2). I think it is important to be memory-aware and we shouldn't sacrifice memory for a normalization factor.
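A minimal sketch (my own illustration, not code from the PR; tensor names follow the discussion above) verifying that beta and the normalized MSE can be recovered from the three aggregated sums alone:

```python
import torch

torch.manual_seed(0)
a = torch.randn(8)  # attr_times_perturb_sums across one example's perturbations
b = torch.randn(8)  # perturbed_fwd_diffs across the same perturbations

# Running aggregates that can be accumulated batch by batch.
sum_a2 = (a * a).sum()
sum_ab = (a * b).sum()
sum_b2 = (b * b).sum()

# beta = mean(a * b) / mean(a^2); the 1/n factors cancel in the ratio.
beta = sum_ab / sum_a2

# Expansion of (beta * a - b)^2 using only the aggregates...
n = a.numel()
expanded = (beta**2 * sum_a2 - 2 * beta * sum_ab + sum_b2) / n
# ...matches the direct per-perturbation computation.
direct = (beta * a - b).pow(2).mean()
assert torch.allclose(expanded, direct)
```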

@aobo-y force-pushed the master branch 2 times, most recently from 1029503 to c6d8e54 on April 5, 2021 20:44
$$ \beta = \frac{
\mathbb{E}_{I \sim \mu_I} [ I^T \Phi(f, x) (f(x) - f(x - I)) ]
}{
\mathbb{E}_{I \sim \mu_I} [ (I^T \Phi(f, x))^2 ]
} $$
Contributor Author:
compiled equation
[screenshot of the rendered formula]

@aobo-y commented Apr 6, 2021

@NarineK @vivekmig The unit test test_classification_infidelity_tpl_target_w_baseline passed in the latest PyTorch version but failed in v1.2 and v1.3, because the precision violates the 0.05 tolerance in those older versions of PyTorch.

I tried to track down where the differences come from and found an inconsistent behavior of IntegratedGradients. The attribution has dtype float64 in the latest PyTorch but float32 in older versions. It is caused by the ambiguous call of torch.tensor(step_sizes) at https://github.com/pytorch/captum/blob/master/captum/attr/_core/integrated_gradients.py#L362

Just mentioning in case this was unknown before.

For this PR, this means that due to float32 precision, the difference between mean(a^2 - 2ab + b^2) and mean((a - b)^2) can be larger than 0.05.
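As a side note (my own illustration, not from the PR), the ambiguity of torch.tensor(step_sizes) comes from dtype inference depending on the element types of the sequence it receives, and this behavior has also shifted across PyTorch versions:

```python
import numpy as np
import torch

# A list of Python floats is inferred as the default dtype, float32.
print(torch.tensor([0.5, 0.5]).dtype)  # torch.float32

# A list of NumPy float64 scalars is inferred as float64 (in recent PyTorch).
print(torch.tensor([np.float64(0.5), np.float64(0.5)]).dtype)  # torch.float64
```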

@aobo-y commented Apr 6, 2021

Here is the detailed example: https://github.com/pytorch/captum/blob/master/tests/metrics/test_infidelity.py#L268
In PyTorch 1.3, the infidelity score is tensor([0.1068, 0.0208, 0.0000, 0.0833]) if we keep the default float attribution, but if I convert it to double with attr = attr.double(), the infidelity becomes tensor([1.0687e-01, 2.0955e-09, 8.3819e-09, 8.3819e-09], dtype=torch.float64) and successfully passes the tests.

But in the latest PyTorch, both float and double attributions lead to the same infidelity: tensor([0.1069, 0.0000, 0.0000, 0.0000]) and tensor([0.1069, 0.0000, 0.0000, 0.0000], dtype=torch.float64).

@facebook-github-bot
@aobo-y has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

)
else:
# returns (a-b)^2 if no need to normalize
return ((attr_times_perturb_sums - perturbed_fwd_diffs).pow(2).sum(-1),)
Contributor Author:

Due to the above issue, I keep two ways of aggregation based on normalize instead of always using a^2 - 2ab + b^2:

  • (a-b)^2 if not normalize
  • a^2, ab, b^2 if normalize

This change allows me to pass the tests.
But it is still worth noting that in older versions of PyTorch, when normalizing, a^2 - 2ab + b^2 will lose some precision compared with the direct (a-b)^2.
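A condensed sketch of the resulting dual-path aggregation (illustrative only; the function names and the epsilon-based safe divide below are simplified stand-ins for the PR's actual _next_infidelity_tensors and safe_div):

```python
import torch

def next_tensors(attr_times_perturb_sums, perturbed_fwd_diffs, normalize):
    a, b = attr_times_perturb_sums, perturbed_fwd_diffs
    if normalize:
        # Keep a^2, ab, b^2 so beta and the MSE can be assembled at the end.
        return (a.pow(2).sum(-1), (a * b).sum(-1), b.pow(2).sum(-1))
    # Without normalization, (a - b)^2 can be summed directly.
    return ((a - b).pow(2).sum(-1),)

def combine(agg_tensors, n_perturb_samples, normalize):
    if normalize:
        sum_a2, sum_ab, sum_b2 = agg_tensors
        denom = torch.where(sum_a2 == 0, torch.full_like(sum_a2, 1e-10), sum_a2)
        beta = sum_ab / denom  # stand-in for the safe divide
        return (beta**2 * sum_a2 - 2 * beta * sum_ab + sum_b2) / n_perturb_samples
    return agg_tensors[0] / n_perturb_samples
```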

Contributor:
Makes sense, sounds good, this fix looks good to me!

@NarineK left a comment

Thank you for working on this, @aobo-y! Looks good to me! I left one small nit comment.

beta_num = agg_tensors[1]
beta_denorm = agg_tensors[0]

beta_denorm[beta_denorm == 0] += 1e-10 # safe divide
Contributor:

It's better to use a common function for safe divide: safe_div
https://github.com/pytorch/captum/blob/master/captum/_utils/common.py#L26

@facebook-github-bot
@aobo-y has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@aobo-y commented Apr 19, 2021

Updated the PR with safe_div as suggested, but opened another issue about safe_div for discussion and future improvement.

@facebook-github-bot
@aobo-y merged this pull request in ed4b9ab.
