Batched metrics #351
base: main
Conversation
quantus/helpers/utils.py
Outdated
@@ -1015,6 +1034,8 @@ def calculate_auc(values: np.array, dx: int = 1):
        np.ndarray
            Definite integral of values.
        """
        if batched:
            return np.trapz(values, dx=dx, axis=1)
        return np.trapz(np.array(values), dx=dx)
I believe this could be simplified to something like:

    axis = 1 if batched else -1
    return np.trapz(np.asarray(values), dx=dx, axis=axis)

(np.trapz does not accept axis=None, so -1 serves as the default here; it is equivalent for 1D input.)
Simplified in the latest commit.
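For reference, here is a self-contained sketch of the simplified helper (a hypothetical standalone version, not the exact PR code; note that integrating along the last axis covers both the 1D non-batched case and the 2D batched case, and that np.trapz was renamed np.trapezoid in NumPy 2.0):

```python
import numpy as np

# np.trapz was renamed to np.trapezoid in NumPy 2.0; use whichever exists.
_trapz = getattr(np, "trapezoid", None) or np.trapz


def calculate_auc(values, dx=1, batched=False):
    """Definite integral via the trapezoidal rule.

    For 1D non-batched input and 2D batched input (batch on axis 0),
    axis=-1 and axis=1 respectively both point at the feature axis.
    """
    values = np.asarray(values)
    axis = 1 if batched else -1
    return _trapz(values, dx=dx, axis=axis)


print(calculate_auc([0.0, 1.0, 2.0]))  # 2.0
print(calculate_auc([[0.0, 1.0, 2.0], [1.0, 2.0, 3.0]], batched=True))  # [2. 4.]
```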
    # Indices
    indices = params["indices"]

    if isinstance(expected, (int, float)):
The expected value is provided by pytest.mark.parametrize, and its type is known beforehand. Why do we need this check?
@@ -30,6 +30,11 @@ def input_zeros_2d_3ch_flattened():
    return np.zeros(shape=(3, 224, 224)).flatten()


@pytest.fixture
Is this fixture used only in one place?
If that's the case, please inline it.
    x_batch_shape = x_batch.shape
    for perturbation_step_index in range(n_perturbations):
        # Perturb input by indices of attributions.
        a_ix = a_indices[
a_ix is an array with shape (batch_size, n_features*n_perturbations), right? I'd suggest we create a view with shape (batch_size, n_features, n_perturbations). Then we can index each step with [..., perturbation_step_index] instead of manually calculating offsets into the array.
a_indices is an array with shape (batch_size, n_features), and the resulting a_ix arrays are of shape (batch_size, self.features_in_step). I believe calculating the offsets manually here is the only option.
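To illustrate the view-based indexing idea in isolation (all shapes here are hypothetical; per the reply above, it does not fit this metric, where a_indices is only (batch_size, n_features)):

```python
import numpy as np

# Hypothetical sizes, for illustration only.
batch_size, n_features, n_perturbations = 2, 3, 4

# Suppose the indices for all steps were stored flat per sample ...
flat = np.arange(batch_size * n_features * n_perturbations).reshape(
    batch_size, n_features * n_perturbations
)

# ... then a reshaped view exposes one axis per perturbation step,
# so each step is a plain slice instead of a hand-computed offset range.
steps = flat.reshape(batch_size, n_features, n_perturbations)
a_ix = steps[..., 0]  # indices for step 0, shape (batch_size, n_features)
print(a_ix.shape)  # (2, 3)
```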
@@ -118,6 +118,58 @@ def baseline_replacement_by_indices(
    return arr_perturbed


def batch_baseline_replacement_by_indices(
import numpy.typing as npt

def batch_baseline_replacement_by_indices(
    arr: np.ndarray,
    indices: np.ndarray,
    perturb_baseline: npt.ArrayLike,
    **kwargs,
) -> np.ndarray:
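A minimal sketch of what the body of such a function might look like, assuming arr is (batch, n_features), indices holds per-sample feature indices of shape (batch, k), and the baseline broadcasts; this is an illustration of the signature above, not the PR's actual implementation:

```python
import numpy as np
import numpy.typing as npt


def batch_baseline_replacement_by_indices(
    arr: np.ndarray,
    indices: np.ndarray,
    perturb_baseline: npt.ArrayLike,
    **kwargs,
) -> np.ndarray:
    # Copy so the caller's array is left untouched.
    arr_perturbed = arr.copy()
    # Pair each row of indices with its own sample via broadcasting.
    rows = np.arange(arr.shape[0])[:, None]
    arr_perturbed[rows, indices] = perturb_baseline
    return arr_perturbed


x = np.ones((2, 4))
out = batch_baseline_replacement_by_indices(x, np.array([[0, 1], [2, 3]]), 0.0)
print(out)
```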
        # Predict on input.
        x_input = model.shape_input(
            x_batch, x_batch.shape, channel_first=True, batched=True
afaik channel_first is a model parameter, so we should not hardcode it. @annahedstroem could you please help us on that one 🙃
This was hardcoded in the original implementation as well. Is that a bug?
        # Randomly mask by subset size.
        a_ix = np.stack(
            [
                np.random.choice(n_features, self.subset_size, replace=False)
Should we maybe add a fixed PRNG seed for reproducibility?
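One way to do that would be a seeded NumPy Generator instead of the global np.random state (the seed value and sizes below are hypothetical):

```python
import numpy as np

# Fixed seed: repeated runs now draw the same subsets.
rng = np.random.default_rng(seed=42)

n_features, subset_size, batch_size = 10, 4, 3
a_ix = np.stack(
    [rng.choice(n_features, subset_size, replace=False) for _ in range(batch_size)]
)
print(a_ix.shape)  # (3, 4)
```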
        pred_deltas = np.stack(pred_deltas, axis=1)
        att_sums = np.stack(att_sums, axis=1)

        similarity = self.similarity_func(a=att_sums, b=pred_deltas, batched=True)
Isn't batch_baseline_replacement_by_indices always batched? Why do we need the batched=True argument? @annahedstroem have you ever used a different similarity_func here?
Here, the batched=True argument goes into a similarity function (for example correlation_pearson), not batch_baseline_replacement_by_indices. Similarity functions can be batched or non-batched (at the moment, at least), so this argument is needed here.
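As a sketch of what such a batched similarity function can look like, here is a row-wise Pearson correlation (a hypothetical standalone version, not the PR's exact code):

```python
import numpy as np


def correlation_pearson(a, b, batched=False, **kwargs):
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    if batched:
        # Row-wise Pearson correlation for (batch, n) inputs.
        a_c = a - a.mean(axis=1, keepdims=True)
        b_c = b - b.mean(axis=1, keepdims=True)
        num = (a_c * b_c).sum(axis=1)
        den = np.sqrt((a_c ** 2).sum(axis=1) * (b_c ** 2).sum(axis=1))
        return num / den
    # Non-batched: plain scalar correlation of two 1D arrays.
    return np.corrcoef(a, b)[0, 1]


print(correlation_pearson([[1, 2, 3], [1, 2, 3]],
                          [[1, 2, 3], [3, 2, 1]], batched=True))
# [ 1. -1.]
```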
            return_shape=(
                batch_size,
                n_features,
            ),  # TODO. Double-check this over using = (1,).
This TODO would need a bit more detail.
This is a relic of the past implementation that I accidentally left in the new one as well. I have deleted the TODO in the latest commit.
    if batched:
        assert len(a.shape) == 2 and len(b.shape) == 2, "Batched arrays must be 2D"
        # No support for axis currently, so just iterating over the batch
        return np.array([scipy.stats.kendalltau(a_i, b_i)[0] for a_i, b_i in zip(a, b)])
Maybe we could use np.vectorize (https://numpy.org/doc/stable/reference/generated/numpy.vectorize.html) for this one?
could be used but I also like the simplicity of @davor10105's suggestion!
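For reference, the np.vectorize variant would look roughly like this; note that np.vectorize still loops in Python under the hood, so it mainly trades the comprehension for a gufunc-style signature rather than speeding anything up:

```python
import numpy as np
import scipy.stats

# Vectorize over the batch axis with a generalized-ufunc signature:
# each call consumes one (n,) row from a and b and emits a scalar.
batched_kendall_tau = np.vectorize(
    lambda a_i, b_i: scipy.stats.kendalltau(a_i, b_i)[0],
    signature="(n),(n)->()",
)

a = np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])
b = np.array([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]])
print(batched_kendall_tau(a, b))  # [ 1. -1.]
```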
Codecov Report

Attention: Patch coverage is

@@            Coverage Diff            @@
##             main     #351     +/-  ##
==========================================
+ Coverage   91.15%   91.29%   +0.13%
==========================================
  Files          66       66
  Lines        3925     4010      +85
==========================================
+ Hits         3578     3661      +83
- Misses        347      349       +2
==========================================

View full report in Codecov by Sentry.
@@ -14,7 +14,9 @@
import skimage


-def correlation_spearman(a: np.array, b: np.array, **kwargs) -> float:
+def correlation_spearman(
super!
Really great work @davor10105, looking forward to our chat.
@@ -139,7 +139,7 @@ def __init__(

        # Save metric-specific attributes.
        if perturb_func is None:
            perturb_func = baseline_replacement_by_indices
Let's discuss: where in the code should we make it explicit to the user that they can no longer use any perturb function other than batch_baseline_replacement_by_indices?
Do we know why most of the Python checks are failing? Thanks
It seems that the installed versions of
Description
Implemented changes
- Removed the evaluate_instance method in the PixelFlipping, Monotonicity, MonotonicityCorrelation, FaithfulnessCorrelation and FaithfulnessEstimate classes and replaced the existing evaluate_batch methods with their "true" batch implementation
- Added a batched parameter to the correlation_spearman, correlation_pearson and correlation_kendall_tau similarity functions to support batch processing
- Added a batched parameter to get_baseline_dict to support batched baseline creation, and similarly added the same parameter to calculate_auc
Implementation validity
- An np.allclose check was made, and PixelFlipping, Monotonicity and FaithfulnessEstimate were verified as valid. MonotonicityCorrelation and FaithfulnessCorrelation did not pass this test, as they include stochastic elements in their calculations. To verify their validity, a two-way t-test was utilized over the 30 runs for each sample of the respective implementations. The resulting p-values can be seen below:
- The batched_tests directory contains the batch_implementation_verification.py script, which runs the validation tests utilizing a copy of the repo (in the quantus directory, also contained within the zip file) that has both the batched and the old implementation versions. Results of the runs mentioned above can be found in results.pickle. This file is used by test_visualization.py to display the box visualization and check the validity of the batch implementation as described above.

Minimum acceptance criteria
@annahedstroem