Batched metrics #351

Open

davor10105 wants to merge 5 commits into main

Conversation

davor10105 (Author):

Description

  • Implements true batch processing for the available metrics.
  • This follows the previously raised issue.
  • The batched processing appears to yield roughly a 12.6x speed-up on CPU (see the figure below):

[Figure: CPU runtime comparison of the batched vs. loop implementations]

  • Further testing needs to be done on GPU (I do not have access to one at the moment :/)

Implemented changes

  • The following changes were made:
    • Removed the evaluate_instance method from the PixelFlipping, Monotonicity, MonotonicityCorrelation, FaithfulnessCorrelation and FaithfulnessEstimate classes and replaced the existing evaluate_batch methods with true batch implementations
    • Added a batched parameter to the correlation_spearman, correlation_pearson and correlation_kendall_tau similarity functions to support batch processing, added a batched parameter to get_baseline_dict to support batched baseline creation, and added the same parameter to calculate_auc

Implementation validity

  • To verify that the batched implementation and the for-loop implementation return the same (or close) results, each metric was invoked on the same sample 30 times (using different seeds) with both the batched and the for-loop (unbatched) implementation. First, a simple np.allclose check was made, and PixelFlipping, Monotonicity and FaithfulnessEstimate were verified as valid. MonotonicityCorrelation and FaithfulnessCorrelation did not pass this test, as they include stochastic elements in their calculations. To verify their validity, a two-sample t-test was run over the 30 runs for each sample of the respective implementations (a sketch of this comparison appears after this list). The resulting p-values can be seen below:
pixel_flipping is VALID (all close)
monotonicity is VALID (all close)
monotonicity_correlation is INVALID (t-test) (p > 0.05 elements: 91.67%)
p-values
 [0.04856062        inf 0.71486889 0.15810997 0.526356   0.74277285
 0.07734966        inf 0.79329449 0.25064761 0.63218201 0.04654986
        inf        inf 0.84467497 0.11548366 0.27559421 0.29898369
 0.65902698        inf 0.30107825 0.6182171  0.9439362  0.26287838]
faithfulness_estimate is VALID (all close)
faithfulness_correlation is INVALID (t-test) (p > 0.05 elements: 83.33%)
p-values
 [0.20491799 0.01031497 0.83503175 0.2578373  0.30944439 0.87939701
 0.47807545 0.05502236 0.7996986  0.02895154 0.02735908 0.90580175
 0.14312554 0.25870476 0.09197358 0.08645692 0.50516479 0.68482126
 0.09873465 0.64258948 0.35872379 0.94694292 0.03611762 0.78976033]
  • For MonotonicityCorrelation and FaithfulnessCorrelation respectively, the mean metric values for 92% and 83% of the batch samples showed no statistically significant difference between the two approaches. The reason for the lack of a 100% match is unclear. To investigate further, I conducted an additional experiment comparing the scores from two runs (each repeated 30 times) of the old loop-style implementation. Despite being identical implementations, the t-test still revealed instances where the means of the two groups differed, specifically in the case of FaithfulnessCorrelation (see results below):
monotonicity_correlation is VALID (t-test)
faithfulness_correlation is INVALID (t-test) (p > 0.05 elements: 91.67%)
p-values
 [0.81952674 0.22507615 0.33738925 0.90864344 0.81490267 0.20901662
 0.81953659 0.17859758 0.62796515 0.04876216 0.13012812 0.58632642
 0.84153836 0.17535798 0.52076956 0.84998221 0.57439399 0.1188701
 0.53963641 0.77368691 0.68040962 0.9938426  0.00461874 0.07818067]
  • So the reason might simply be that this metric is unstable for some examples, but I am not sure. I do not think the batched implementation is wrong, since it produces results that are statistically indistinguishable from the loop implementation in the majority of cases, which would be difficult to achieve by chance.
  • Validation and visualization scripts are attached here - testing_utils.zip. The batched_tests directory contains the batch_implementation_verification.py script, which runs the validation tests using a copy of the repo (in the quantus directory, also contained in the zip file) that includes both the batched and the old implementation versions. Results of the runs mentioned above can be found in results.pickle. This file is used by test_visualization.py to display the box visualization and check the validity of the batch implementation as described above.
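
For clarity, a minimal sketch of the comparison procedure described above (this is not the attached script; function and variable names are assumptions):

import numpy as np
import scipy.stats

def compare_implementations(scores_batched, scores_loop, alpha=0.05):
    """Compare score arrays of shape (n_runs, batch_size), e.g. 30 seeded runs each.

    Returns the np.allclose verdict, per-sample p-values from a two-sided
    two-sample t-test, and the fraction of samples with p > alpha.
    """
    all_close = np.allclose(scores_batched, scores_loop)
    # Two-sample t-test per batch sample, across the repeated runs (axis=0).
    p_values = scipy.stats.ttest_ind(scores_batched, scores_loop, axis=0).pvalue
    return all_close, p_values, float(np.mean(p_values > alpha))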

Minimum acceptance criteria

  • Implementing batch processing for all other metrics and supporting functions
    @annahedstroem

@@ -1015,6 +1034,8 @@ def calculate_auc(values: np.array, dx: int = 1):
    np.ndarray
        Definite integral of values.
    """
    if batched:
        return np.trapz(values, dx=dx, axis=1)
    return np.trapz(np.array(values), dx=dx)

Collaborator:

I believe this could be simplified to something like:

axis = 1 if batched else None
return np.trapz(np.asarray(values), dx=dx, axis=axis)

Author (davor10105):

Simplified in the latest commit.
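
For illustration, a small usage sketch of the two code paths (shapes are made up; np.trapz is the call the helper wraps, per the diff above):

import numpy as np

values = np.random.rand(8, 28)  # hypothetical scores: (batch_size, n_perturbation_steps)
auc_per_sample = np.trapz(values, dx=1, axis=1)  # batched path, shape (8,)
auc_single = np.trapz(values[0], dx=1)           # unbatched path, scalar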

# Indices
indices = params["indices"]

if isinstance(expected, (int, float)):

Collaborator:

The expected value is provided by pytest.mark.parametrize, and its type is known beforehand. Why do we need this check?

@@ -30,6 +30,11 @@ def input_zeros_2d_3ch_flattened():
    return np.zeros(shape=(3, 224, 224)).flatten()


@pytest.fixture

Collaborator:

Is this fixture used only in one place?
If that's the case, please inline it.

x_batch_shape = x_batch.shape
for perturbation_step_index in range(n_perturbations):
    # Perturb input by indices of attributions.
    a_ix = a_indices[

Collaborator:

a_ix is an array with shape (batch_size, n_features*n_perturbations), right?
I'd suggest we create a view with shape (batch_size, n_features, n_perturbations).
Then we can index each step with [...,perturbation_step_index] instead of manually calculating offsets into the array

Author (davor10105):

a_indices is an array with shape (batch_size, n_features), and the resulting a_ix slices have shape (batch_size, self.features_in_step). I believe calculating the offsets manually here is the only option (illustrated below).
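
For illustration, a minimal sketch of the step slicing described above (all shapes and values are hypothetical, not the PR's exact code):

import numpy as np

# Hypothetical shapes for illustration.
batch_size, n_features, features_in_step = 4, 12, 3
a_indices = np.argsort(-np.random.rand(batch_size, n_features), axis=1)

perturbation_step_index = 2
start = perturbation_step_index * features_in_step
stop = start + features_in_step
a_ix = a_indices[:, start:stop]  # -> shape (batch_size, features_in_step)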

@@ -118,6 +118,58 @@ def baseline_replacement_by_indices(
    return arr_perturbed


def batch_baseline_replacement_by_indices(

Collaborator:

import numpy.typing as npt

def batch_baseline_replacement_by_indices(
    arr: np.ndarray,
    indices: np.ndarray,
    perturb_baseline: npt.ArrayLike,
    **kwargs,
) -> np.ndarray:


# Predict on input.
x_input = model.shape_input(
    x_batch, x_batch.shape, channel_first=True, batched=True

Collaborator:

AFAIK channel_first is a model parameter, so we should not hardcode it. @annahedstroem could you please help us with that one 🙃

Author (davor10105):

This was hardcoded in the original implementation as well. Is that a bug?

# Randomly mask by subset size.
a_ix = np.stack(
    [
        np.random.choice(n_features, self.subset_size, replace=False)

Collaborator:

Should we maybe add a fixed PRNG seed for reproducibility?
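
A minimal sketch of what a fixed seed could look like here (illustrative only; the seed value and shapes are assumptions):

import numpy as np

rng = np.random.default_rng(seed=42)  # hypothetical fixed seed for reproducibility
batch_size, n_features, subset_size = 8, 224 * 224, 224

a_ix = np.stack(
    [rng.choice(n_features, subset_size, replace=False) for _ in range(batch_size)]
)  # -> shape (batch_size, subset_size), identical across runs with the same seed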

pred_deltas = np.stack(pred_deltas, axis=1)
att_sums = np.stack(att_sums, axis=1)

similarity = self.similarity_func(a=att_sums, b=pred_deltas, batched=True)

Collaborator:

Isn't batch_baseline_replacement_by_indices always batched? Why do we need the batched=True argument?

@annahedstroem have you ever used a different similarity_func here?

Author (davor10105):

Here, the batched=True argument goes into the similarity function (for example correlation_pearson), not into batch_baseline_replacement_by_indices. Similarity functions can be either batched or unbatched (at the moment at least), so the argument is needed here.
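
For clarity, a minimal sketch of the batched/unbatched dispatch described above; the actual correlation_pearson in quantus may differ:

import numpy as np
import scipy.stats

def correlation_pearson(a: np.ndarray, b: np.ndarray, batched: bool = False, **kwargs):
    if batched:
        # a, b: (batch_size, n_values) -> one coefficient per sample.
        return np.array(
            [scipy.stats.pearsonr(a_i, b_i)[0] for a_i, b_i in zip(a, b)]
        )
    return scipy.stats.pearsonr(a, b)[0]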

return_shape=(
    batch_size,
    n_features,
),  # TODO. Double-check this over using = (1,).

Collaborator:

This TODO would need a bit more detail.

Author (davor10105):

This is a relic of the previous implementation that I accidentally left in the new one as well. I have deleted the TODO in the latest commit.

if batched:
    assert len(a.shape) == 2 and len(b.shape) == 2, "Batched arrays must be 2D"
    # No support for axis currently, so just iterating over the batch
    return np.array([scipy.stats.kendalltau(a_i, b_i)[0] for a_i, b_i in zip(a, b)])

Collaborator:

Maybe we could use np.vectorize (https://numpy.org/doc/stable/reference/generated/numpy.vectorize.html) for this one?

Member:

It could be used, but I also like the simplicity of @davor10105's suggestion!
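
For reference, a sketch of what the np.vectorize variant might look like; note that np.vectorize still loops under the hood, so this is mainly a readability change:

import numpy as np
import scipy.stats

row_kendall = np.vectorize(
    lambda a_i, b_i: scipy.stats.kendalltau(a_i, b_i)[0],
    signature="(n),(n)->()",
)
# a, b: (batch_size, n_values); row_kendall(a, b) -> (batch_size,) array of tau values.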

@codecov-commenter

Codecov Report

Attention: Patch coverage is 95.88235% with 7 lines in your changes missing coverage. Please review.

Project coverage is 91.29%. Comparing base (6857561) to head (4f44510).
Report is 16 commits behind head on main.

Files                                                    Patch %   Lines
quantus/helpers/utils.py                                 72.72%    3 Missing ⚠️
...s/metrics/faithfulness/faithfulness_correlation.py    96.15%    1 Missing ⚠️
...ntus/metrics/faithfulness/faithfulness_estimate.py    96.55%    1 Missing ⚠️
quantus/metrics/faithfulness/monotonicity.py             95.45%    1 Missing ⚠️
quantus/metrics/faithfulness/pixel_flipping.py           96.15%    1 Missing ⚠️

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #351      +/-   ##
==========================================
+ Coverage   91.15%   91.29%   +0.13%     
==========================================
  Files          66       66              
  Lines        3925     4010      +85     
==========================================
+ Hits         3578     3661      +83     
- Misses        347      349       +2     


@@ -14,7 +14,9 @@
import skimage


def correlation_spearman(a: np.array, b: np.array, **kwargs) -> float:
def correlation_spearman(

Member:

super!

Member @annahedstroem left a comment:

Really great work @davor10105, looking forward to our chat.

@@ -139,7 +139,7 @@ def __init__(

        # Save metric-specific attributes.
        if perturb_func is None:
            perturb_func = baseline_replacement_by_indices

Member:

Let's discuss: where in the code should we make it explicit to the user that they can no longer use any perturb function other than batch_baseline_replacement_by_indices?
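
One hypothetical way to make it explicit, as a discussion aid rather than a concrete proposal (the helper name and import path below are assumptions):

import warnings

from quantus.functions.perturb_func import batch_baseline_replacement_by_indices  # import path assumed

def _warn_on_unsupported_perturb_func(perturb_func):
    # Hypothetical helper, e.g. called from the metric's __init__: the batched
    # metrics only support batch_baseline_replacement_by_indices as perturb_func.
    if perturb_func is not None and perturb_func is not batch_baseline_replacement_by_indices:
        warnings.warn(
            "This metric uses true batch processing and currently only supports "
            "batch_baseline_replacement_by_indices as perturb_func.",
            UserWarning,
        )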

@annahedstroem (Member):

Do we know why most of the Python checks are failing? Thanks


@davor10105 (Author):

It seems that the installed versions of SciPy on Python 3.8 and 3.9 do not have the axis parameter in the pearsonr correlation. As for 3.11, I am not sure; the test fails early on fixture loading. I will look into it.
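
If it helps, one possible workaround would be to compute the row-wise Pearson correlation directly in NumPy, avoiding the pearsonr axis parameter entirely; this is just a sketch, not necessarily what should land in the PR:

import numpy as np

def rowwise_pearson(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # a, b: (batch_size, n_values); returns one coefficient per row pair.
    a_c = a - a.mean(axis=1, keepdims=True)
    b_c = b - b.mean(axis=1, keepdims=True)
    num = (a_c * b_c).sum(axis=1)
    den = np.sqrt((a_c ** 2).sum(axis=1) * (b_c ** 2).sum(axis=1))
    return num / den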
