From e0e090ca8d98cc851bd640bbb289ae2653aff927 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Sat, 12 Aug 2023 22:48:38 +0200 Subject: [PATCH 01/58] WIP --- quantus/helpers/constants.py | 5 +- quantus/metrics/__init__.py | 5 +- quantus/metrics/axiomatic/completeness.py | 68 +-- quantus/metrics/axiomatic/input_invariance.py | 5 +- quantus/metrics/axiomatic/non_sensitivity.py | 112 ++-- quantus/metrics/base.py | 365 ++++++------ quantus/metrics/base_batched.py | 546 +----------------- quantus/metrics/base_perturbed.py | 66 +-- quantus/metrics/complexity/complexity.py | 54 +- .../complexity/effective_complexity.py | 39 +- quantus/metrics/complexity/sparseness.py | 63 +- .../faithfulness/faithfulness_correlation.py | 112 ++-- .../faithfulness/faithfulness_estimate.py | 125 ++-- quantus/metrics/faithfulness/infidelity.py | 191 +++--- quantus/metrics/faithfulness/irof.py | 143 ++--- quantus/metrics/faithfulness/monotonicity.py | 119 ++-- .../faithfulness/monotonicity_correlation.py | 145 +++-- .../metrics/faithfulness/pixel_flipping.py | 117 ++-- .../faithfulness/region_perturbation.py | 17 +- quantus/metrics/faithfulness/road.py | 108 ++-- quantus/metrics/faithfulness/selectivity.py | 20 +- quantus/metrics/faithfulness/sensitivity_n.py | 123 ++-- quantus/metrics/faithfulness/sufficiency.py | 115 +--- .../localisation/attribution_localisation.py | 97 ++-- quantus/metrics/localisation/auc.py | 66 +-- quantus/metrics/localisation/focus.py | 98 ++-- quantus/metrics/localisation/pointing_game.py | 81 +-- .../localisation/relevance_mass_accuracy.py | 75 +-- .../localisation/relevance_rank_accuracy.py | 90 ++- .../localisation/top_k_intersection.py | 86 ++- .../model_parameter_randomisation.py | 80 +-- quantus/metrics/randomisation/random_logit.py | 63 +- quantus/metrics/robustness/avg_sensitivity.py | 9 +- quantus/metrics/robustness/consistency.py | 125 ++-- quantus/metrics/robustness/continuity.py | 25 +- .../robustness/local_lipschitz_estimate.py | 11 +- quantus/metrics/robustness/max_sensitivity.py | 9 +- .../robustness/relative_input_stability.py | 4 +- .../robustness/relative_output_stability.py | 5 +- .../relative_representation_stability.py | 4 +- tests/README.md | 5 + 41 files changed, 1255 insertions(+), 2341 deletions(-) diff --git a/quantus/helpers/constants.py b/quantus/helpers/constants.py index b440cb6d0..4112d0a21 100644 --- a/quantus/helpers/constants.py +++ b/quantus/helpers/constants.py @@ -7,8 +7,7 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from typing import List, Dict - +from typing import List, Dict, Final, Mapping, Type from quantus.functions.loss_func import * from quantus.functions.normalise_func import * from quantus.functions.perturb_func import * @@ -16,7 +15,7 @@ from quantus.metrics import * -AVAILABLE_METRICS = { +AVAILABLE_METRICS: Final[Mapping[str, Mapping[str, Type[Metric]]]] = { "Faithfulness": { "Faithfulness Correlation": FaithfulnessCorrelation, "Faithfulness Estimate": FaithfulnessEstimate, diff --git a/quantus/metrics/__init__.py b/quantus/metrics/__init__.py index 42af965fa..e6fafecfd 100644 --- a/quantus/metrics/__init__.py +++ b/quantus/metrics/__init__.py @@ -4,9 +4,8 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from quantus.metrics.base import * -from quantus.metrics.base_batched import * -from quantus.metrics.base_perturbed import * +from quantus.metrics.base import Metric +from quantus.metrics.base_perturbed import PerturbationMetric from quantus.metrics.axiomatic import * from quantus.metrics.complexity import * from quantus.metrics.faithfulness import * diff --git a/quantus/metrics/axiomatic/completeness.py b/quantus/metrics/axiomatic/completeness.py index d9d50262e..4c43bb3b1 100644 --- a/quantus/metrics/axiomatic/completeness.py +++ b/quantus/metrics/axiomatic/completeness.py @@ -10,7 +10,6 @@ import numpy as np from quantus.helpers import warn -from quantus.helpers import asserts from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices @@ -262,51 +261,32 @@ def __call__( **kwargs, ) - def evaluate_instance( + def evaluate_batch( self, + *, model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - s: np.ndarray, - ) -> bool: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. - - Returns - ------- - : boolean - The evaluation results. - """ - x_baseline = self.perturb_func( - arr=x, - indices=np.arange(0, x.size), - indexed_axes=np.arange(0, x.ndim), - **self.perturb_func_kwargs, - ) + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: np.ndarray, + **_, + ) -> List[bool]: + # TODO: vectorize + retval = [] + for x, y, a in zip(x_batch, y_batch, a_batch): + x_baseline = self.perturb_func( + arr=x, + indices=np.arange(0, x.size), + indexed_axes=np.arange(0, x.ndim), + **self.perturb_func_kwargs, + ) - # Predict on input. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) + # Predict on input. + x_input = model.shape_input(x, x.shape, channel_first=True) + y_pred = float(model.predict(x_input)[:, y]) - # Predict on baseline. - x_input = model.shape_input(x_baseline, x.shape, channel_first=True) - y_pred_baseline = float(model.predict(x_input)[:, y]) + # Predict on baseline. + x_input = model.shape_input(x_baseline, x.shape, channel_first=True) + y_pred_baseline = float(model.predict(x_input)[:, y]) - if np.sum(a) == self.output_func(y_pred - y_pred_baseline): - return True - else: - return False + retval.append(np.sum(a) == self.output_func(y_pred - y_pred_baseline)) + return retval diff --git a/quantus/metrics/axiomatic/input_invariance.py b/quantus/metrics/axiomatic/input_invariance.py index 8e3745569..f0816cf27 100644 --- a/quantus/metrics/axiomatic/input_invariance.py +++ b/quantus/metrics/axiomatic/input_invariance.py @@ -14,7 +14,7 @@ from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_shift, perturb_batch -from quantus.metrics.base_batched import BatchedPerturbationMetric +from quantus.metrics.base_perturbed import PerturbationMetric from quantus.helpers.enums import ( ModelType, DataType, @@ -23,7 +23,7 @@ ) -class InputInvariance(BatchedPerturbationMetric): +class InputInvariance(PerturbationMetric): """ Implementation of Completeness test by Kindermans et al., 2017. @@ -247,6 +247,7 @@ def evaluate_batch( y_batch: np.ndarray, a_batch: np.ndarray, s_batch: np.ndarray, + **kwargs, ) -> np.ndarray: """ Evaluates model and attributes on a single data batch and returns the batched evaluation result. diff --git a/quantus/metrics/axiomatic/non_sensitivity.py b/quantus/metrics/axiomatic/non_sensitivity.py index 2b78e54b7..09b2e70d6 100644 --- a/quantus/metrics/axiomatic/non_sensitivity.py +++ b/quantus/metrics/axiomatic/non_sensitivity.py @@ -6,7 +6,7 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional import numpy as np from quantus.helpers import warn @@ -265,68 +265,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - s: np.ndarray, - ) -> int: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. - - Returns - ------- - integer - The evaluation results. - """ - a = a.flatten() - - non_features = set(list(np.argwhere(a).flatten() < self.eps)) - - vars = [] - for i_ix, a_ix in enumerate(a[:: self.features_in_step]): - - preds = [] - a_ix = a[ - (self.features_in_step * i_ix) : (self.features_in_step * (i_ix + 1)) - ].astype(int) - - for _ in range(self.n_samples): - - # Perturb input by indices of attributions. - x_perturbed = self.perturb_func( - arr=x, - indices=a_ix, - indexed_axes=self.a_axes, - **self.perturb_func_kwargs, - ) - - # Predict on perturbed input x. - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturbed = float(model.predict(x_input)[:, y]) - - preds.append(y_pred_perturbed) - vars.append(np.var(preds)) - - non_features_vars = set(list(np.argwhere(vars).flatten() < self.eps)) - - return len(non_features_vars.symmetric_difference(non_features)) - def custom_preprocess( self, model: ModelInterface, @@ -363,3 +301,51 @@ def custom_preprocess( features_in_step=self.features_in_step, input_shape=x_batch.shape[2:], ) + + def evaluate_batch( + self, + *, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: np.ndarray, + **kwargs, + ) -> List[int]: + retval = [] + + for x, y, a in zip(x_batch, y_batch, a_batch): + a = a.flatten() + + non_features = set(list(np.argwhere(a).flatten() < self.eps)) + + vars = [] + for i_ix, a_ix in enumerate(a[:: self.features_in_step]): + preds = [] + a_ix = a[ + (self.features_in_step * i_ix) : ( + self.features_in_step * (i_ix + 1) + ) + ].astype(int) + + for _ in range(self.n_samples): + # Perturb input by indices of attributions. + x_perturbed = self.perturb_func( + arr=x, + indices=a_ix, + indexed_axes=self.a_axes, + **self.perturb_func_kwargs, + ) + + # Predict on perturbed input x. + x_input = model.shape_input( + x_perturbed, x.shape, channel_first=True + ) + y_pred_perturbed = float(model.predict(x_input)[:, y]) + + preds.append(y_pred_perturbed) + vars.append(np.var(preds)) + + non_features_vars = set(list(np.argwhere(vars).flatten() < self.eps)) + retval.append(len(non_features_vars.symmetric_difference(non_features))) + + return retval diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index cf335653e..7073bcb74 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -5,67 +5,39 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -import inspect -import re +from __future__ import annotations + from abc import abstractmethod -from collections.abc import Sequence -from typing import ( - Any, - Callable, - Dict, - Sequence, - Optional, - Tuple, - Union, - Collection, - List, - Set, -) +from typing import Any, Callable, Dict, Sequence, Optional, ClassVar, Generator, Set + import matplotlib.pyplot as plt import numpy as np +from sklearn.utils import gen_batches from tqdm.auto import tqdm from quantus.helpers import asserts from quantus.helpers import utils from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.helpers.enums import ( ModelType, DataType, ScoreDirection, EvaluationCategory, ) +from quantus.helpers.model.model_interface import ModelInterface class Metric: """ - Implementation of the base Metric class. + Interface defining Metrics' API. """ - @property - @abstractmethod - def name(self) -> str: - raise NotImplementedError - - @property - @abstractmethod - def evaluation_category(self) -> EvaluationCategory: - raise NotImplementedError - - @property - @abstractmethod - def score_direction(self) -> ScoreDirection: - raise NotImplementedError - - @property - @abstractmethod - def model_applicability(self) -> Set[ModelType]: - return {ModelType.TORCH, ModelType.TF} - - @property - @abstractmethod - def data_applicability(self) -> Set[DataType]: - raise NotImplementedError + name: ClassVar[str] + data_applicability: ClassVar[Set[DataType]] + model_applicability: ClassVar[Set[ModelType]] + score_direction: ClassVar[ScoreDirection] + # can one metric fall into multiple categories? + evaluation_category: ClassVar[EvaluationCategory] def __init__( self, @@ -151,17 +123,17 @@ def __call__( s_batch: Optional[np.ndarray], channel_first: Optional[bool], explain_func: Optional[Callable], - explain_func_kwargs: Optional[Dict], + explain_func_kwargs: Optional[Dict[str, Any]], model_predict_kwargs: Optional[Dict], softmax: Optional[bool], device: Optional[str] = None, batch_size: int = 64, custom_batch: Optional[Any] = None, **kwargs, - ) -> Union[int, float, list, dict, Collection[Any], None]: + ) -> Any: """ This implementation represents the main logic of the metric and makes the class object callable. - It completes instance-wise evaluation of explanations (a_batch) with respect to input data (x_batch), + It completes batch-wise evaluation of explanations (a_batch) with respect to input data (x_batch), output labels (y_batch) and a torch or tensorflow model (model). Calls general_preprocess() with all relevant arguments, calls @@ -238,35 +210,38 @@ def __call__( >> metric = Metric(abs=True, normalise=False) >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} """ - # Run deprecation warnings. warn.deprecation_warnings(kwargs) warn.check_kwargs(kwargs) - data = self.general_preprocess( + data: Dict[str, Any] = self.general_preprocess( model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch, s_batch=s_batch, + custom_batch=custom_batch, channel_first=channel_first, explain_func=explain_func, explain_func_kwargs=explain_func_kwargs, model_predict_kwargs=model_predict_kwargs, softmax=softmax, device=device, - custom_batch=custom_batch, ) - self.evaluation_scores = [None for _ in x_batch] + # Create generator for generating batches. + batch_generator = self.generate_batches( + data=data, + batch_size=batch_size, + ) - # Evaluate with instance given the metric. - iterator = self.get_instance_iterator(data=data) - for id_instance, data_instance in iterator: - result = self.evaluate_instance(**data_instance) - self.evaluation_scores[id_instance] = result + self.evaluation_scores = [] + for data_batch in batch_generator: + data_batch = self.batch_preprocess(data_batch) + result = self.evaluate_batch(**data_batch) + self.evaluation_scores.extend(result) - # Call custom post-processing. + # Call post-processing. self.custom_postprocess(**data) if self.return_aggregate: @@ -286,21 +261,24 @@ def __call__( "Specify an 'aggregate_func' (Callable) to aggregate evaluation scores." ) + # Append content of last results to all results. self.all_evaluation_scores.append(self.evaluation_scores) return self.evaluation_scores @abstractmethod - def evaluate_instance( + def evaluate_batch( self, + *, model: ModelInterface, - x: np.ndarray, - y: Optional[np.ndarray], - a: Optional[np.ndarray], - s: Optional[np.ndarray], - ) -> Any: + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: np.ndarray, + s_batch: np.ndarray, + **kwargs, + ): """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. + Evaluates model and attributes on a single data batch and returns the batched evaluation result. This method needs to be implemented to use __call__(). @@ -308,18 +286,19 @@ def evaluate_instance( ---------- model: ModelInterface A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. + x_batch: np.ndarray + The input to be evaluated on a batch-basis. + y_batch: np.ndarray + The output to be evaluated on a batch-basis. + a_batch: np.ndarray + The explanation to be evaluated on a batch-basis. + s_batch: np.ndarray + The segmentation to be evaluated on a batch-basis. Returns ------- - Any + np.ndarray + The batched evaluation results. """ raise NotImplementedError() @@ -395,9 +374,8 @@ def general_preprocess( channel_first = utils.infer_channel_first(x_batch) x_batch = utils.make_channel_first(x_batch, channel_first) - # Wrap the model into an interface. - if model: - + # TODO: can model be None? + if model is not None: # Use attribute value if not passed explicitly. model = utils.get_wrapped_model( model=model, @@ -417,27 +395,27 @@ def general_preprocess( if device is not None and "device" not in self.explain_func_kwargs: self.explain_func_kwargs["device"] = device - if a_batch is None: - - # Asserts. - asserts.assert_explain_func(explain_func=self.explain_func) + if a_batch is not None: + # If no explanations provided, we compute them ob batch-level to avoid OOM. + a_batch = utils.expand_attribution_channel(a_batch, x_batch) + asserts.assert_attributions(x_batch=x_batch, a_batch=a_batch) + self.a_axes = utils.infer_attribution_axes(a_batch, x_batch) + + # Normalise with specified keyword arguments if requested. + if self.normalise: + # TODO: what is this signature? + a_batch = self.normalise_func( + a_batch, + normalise_axes=list(range(np.ndim(a_batch)))[1:], + **self.normalise_func_kwargs, + ) - # Generate explanations. - a_batch = self.explain_func( - model=model.get_model(), - inputs=x_batch, - targets=y_batch, - **self.explain_func_kwargs, - ) + # Take absolute if requested. + if self.abs: + a_batch = np.abs(a_batch) - # Expand attributions to input dimensionality. - a_batch = utils.expand_attribution_channel(a_batch, x_batch) - - # Asserts. - asserts.assert_attributions(x_batch=x_batch, a_batch=a_batch) - - # Infer attribution axes for perturbation function. - self.a_axes = utils.infer_attribution_axes(a_batch, x_batch) + else: + asserts.assert_explain_func(explain_func=self.explain_func) # Initialize data dictionary. data = { @@ -451,28 +429,14 @@ def general_preprocess( # Call custom pre-processing from inheriting class. custom_preprocess_dict = self.custom_preprocess(**data) - # Save data coming from custom preprocess to data dict. - if custom_preprocess_dict: - for key, value in custom_preprocess_dict.items(): - data[key] = value + if custom_preprocess_dict is not None: + data.update(custom_preprocess_dict) # Remove custom_batch if not used. if data["custom_batch"] is None: del data["custom_batch"] - # Normalise with specified keyword arguments if requested. - if self.normalise: - data["a_batch"] = self.normalise_func( - a=data["a_batch"], - normalise_axes=list(range(np.ndim(data["a_batch"])))[1:], - **self.normalise_func_kwargs, - ) - - # Take absolute if requested. - if self.abs: - data["a_batch"] = np.abs(data["a_batch"]) - return data def custom_preprocess( @@ -618,9 +582,48 @@ def custom_preprocess( """ pass - def get_instance_iterator(self, data: Dict[str, Any]): + def custom_postprocess( + self, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: Optional[np.ndarray], + a_batch: Optional[np.ndarray], + s_batch: np.ndarray, + **kwargs, + ) -> Optional[Any]: """ - Creates iterator to iterate over all instances in data dictionary. + Implement this method if you need custom postprocessing of results or + additional attributes. + + Parameters + ---------- + model: torch.nn.Module, tf.keras.Model + A torch or tensorflow model e.g., torchvision.models that is subject to explanation. + x_batch: np.ndarray + A np.ndarray which contains the input data that are explained. + y_batch: np.ndarray + A np.ndarray which contains the output labels that are explained. + a_batch: np.ndarray, optional + A np.ndarray which contains pre-computed attributions i.e., explanations. + s_batch: np.ndarray, optional + A np.ndarray which contains segmentation masks that matches the input. + kwargs: any, optional + Additional data which was created in custom_preprocess(). + + Returns + ------- + any + Can be implemented, optionally by the child class. + """ + pass + + def generate_batches( + self, + data: Dict[str, Any], + batch_size: int, + ) -> Generator[Dict[str, Any], None, None]: + """ + Creates iterator to iterate over all batched instances in data dictionary. Each iterator output element is a keyword argument dictionary with string keys. @@ -629,7 +632,7 @@ def get_instance_iterator(self, data: Dict[str, Any]): will be written to each iterator output dictionary. - If the item value is a sequence and the item key ends with '_batch', a check will be made to make sure length matches number of instances. - The value of each instance in the sequence will be added to the respective + The values of the batch instances in the sequence will be added to the respective iterator output dictionary with the '_batch' suffix removed. - If the item value is a sequence but doesn't end with '_batch', it will be treated as a simple value and the respective item key/value pair will be @@ -639,6 +642,8 @@ def get_instance_iterator(self, data: Dict[str, Any]): ---------- data: dict[str, any] The data input dictionary. + batch_size: int + The batch size to be used. Returns ------- @@ -648,10 +653,13 @@ def get_instance_iterator(self, data: Dict[str, Any]): """ n_instances = len(data["x_batch"]) - for key, value in data.items(): - # If data-value is not a Sequence or a string, create list of repeated values with length of n_instances. + single_value_kwargs: Dict[str, Any] = {} + batched_value_kwargs: Dict[str, Any] = {} + + for key, value in list(data.items()): + # If data-value is not a Sequence or a string, create list of value with length of n_instances. if not isinstance(value, (Sequence, np.ndarray)) or isinstance(value, str): - data[key] = [value for _ in range(n_instances)] + single_value_kwargs[key] = value # If data-value is a sequence and ends with '_batch', only check for correct length. elif key.endswith("_batch"): @@ -660,71 +668,39 @@ def get_instance_iterator(self, data: Dict[str, Any]): raise ValueError( f"'{key}' has incorrect length (expected: {n_instances}, is: {len(value)})" ) + else: + batched_value_kwargs[key] = value # If data-value is a sequence and doesn't end with '_batch', create # list of repeated sequences with length of n_instances. else: - data[key] = [value for _ in range(n_instances)] - - # We create a list of dictionaries where each dictionary holds all data for a single instance. - # We remove the '_batch' suffix if present. - data_instances = [ - { - re.sub("_batch", "", key): value[id_instance] - for key, value in data.items() - } - for id_instance in range(n_instances) - ] + single_value_kwargs[key] = [value for _ in range(n_instances)] + n_batches = np.ceil(n_instances / batch_size) + + # Create iterator for batch index. iterator = tqdm( - enumerate(data_instances), - total=n_instances, - disable=not self.display_progressbar, # Create progress bar if desired. - desc=f"Evaluating {self.__class__.__name__}", + total=n_batches, + disable=not self.display_progressbar, ) - return iterator - - def custom_postprocess( - self, - model: ModelInterface, - x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: np.ndarray, - **kwargs, - ) -> Optional[Any]: - """ - Implement this method if you need custom postprocessing of results or - additional attributes. - - Parameters - ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. - x_batch: np.ndarray - A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - kwargs: any, optional - Additional data which was created in custom_preprocess(). - - Returns - ------- - any - Can be implemented, optionally by the child class. - """ - pass + # Iterate over batch index + for batch_idx in gen_batches(n_instances, batch_size): + # Calculate instance index for start and end of batch. + # Create batch dictionary with all specified batch instance values + batch = { + key: value[batch_idx.start : batch_idx.stop] + for key, value in batched_value_kwargs.items() + } + # Yield batch dictionary including single value keyword arguments. + yield {**batch, **single_value_kwargs} + iterator.update(min(batch_size, batch_idx.stop - batch_idx.start)) def plot( self, plot_func: Optional[Callable] = None, show: bool = True, - path_to_save: Union[str, None] = None, + path_to_save: str | None = None, *args, **kwargs, ) -> None: @@ -738,13 +714,14 @@ def plot( A Callable with the actual plotting logic. Default set to None, which implies default_plot_func is set. show: boolean A boolean to state if the plot shall be shown. - path_to_save (str): + path_to_save: (str) A string that specifies the path to save file. args: optional An optional with additional arguments. kwargs: optional An optional dict with additional arguments. + Returns ------- None @@ -765,17 +742,14 @@ def plot( if path_to_save: plt.savefig(fname=path_to_save, dpi=400) - return None - - @property - def interpret_scores(self) -> None: + def interpret_scores(self): """ Get an interpretation of the scores. """ print(self.__init__.__doc__.split(".")[1].split("References")[0]) @property - def get_params(self) -> dict: + def get_params(self) -> Dict[str, Any]: """ List parameters of metric. @@ -796,13 +770,56 @@ def get_params(self) -> dict: @property def last_results(self): print( - "Warning: 'last_results' has been renamed to 'evaluation_scores'. 'last_results' is removed in current version." + "Warning: 'last_results' has been renamed to 'evaluation_scores'. " + "'last_results' is removed in current version." ) return self.evaluation_scores @property def all_results(self): print( - "Warning: 'all_results' has been renamed to 'all_evaluation_scores'. 'all_results' is removed in current version." + "Warning: 'all_results' has been renamed to 'all_evaluation_scores'. " + "'all_results' is removed in current version." ) return self.all_evaluation_scores + + def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: + x_batch = data_batch["x_batch"] + + a_batch = data_batch.get("a_batch") + + if a_batch is None: + # Generate batch of explanations lazily, so we don't hit OOM + model = data_batch["model"] + y_batch = data_batch["y_batch"] + a_batch = self.explain_batch(model, x_batch, y_batch) + data_batch["a_batch"] = a_batch + + if hasattr(self, "a_axes") and self.a_axes is None: + # TODO: we must not modify global state during evaluation. + self.a_axes = utils.infer_attribution_axes(a_batch, x_batch) + + return data_batch + + def explain_batch( + self, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray + ) -> np.ndarray: + """Compute explanations, normalize and take absolute (if was configured so during metric initialization.)""" + if hasattr(model, "get_model"): + # Sometimes the model is our wrapper, but sometimes raw Keras/Torch model. + model = model.get_model() + + a_batch = self.explain_func( + model=model, inputs=x_batch, targets=y_batch, **self.explain_func_kwargs + ) + a_batch = utils.expand_attribution_channel(a_batch, x_batch) + asserts.assert_attributions(x_batch=x_batch, a_batch=a_batch) + + # Normalise and take absolute values of the attributions, if configured during metric instantiation. + if self.normalise: + a_batch = self.normalise_func(a_batch, **self.normalise_func_kwargs) + + if self.abs: + a_batch = np.abs(a_batch) + + return a_batch diff --git a/quantus/metrics/base_batched.py b/quantus/metrics/base_batched.py index aafe2d6df..a1f2d1885 100644 --- a/quantus/metrics/base_batched.py +++ b/quantus/metrics/base_batched.py @@ -1,550 +1,34 @@ -"""This module implements the base class for creating evaluation metrics.""" - # This file is part of Quantus. # Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. # Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -import inspect -import math -import re -from abc import abstractmethod -from typing import Any, Callable, Dict, Optional, Sequence, Union - -import numpy as np -from tqdm.auto import tqdm +import abc +import warnings from quantus.metrics.base import Metric -from quantus.helpers import asserts -from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface -from quantus.helpers.enums import ( - ModelType, - DataType, - ScoreDirection, - EvaluationCategory, -) - - -class BatchedMetric(Metric): - """ - Implementation base BatchedMetric class. - - Attributes: - - name: The name of the metric. - - data_applicability: The data types that the metric implementation currently supports. - - model_applicability: The model types that this metric can work with. - - score_direction: How to interpret the scores, whether higher/ lower values are considered better. - - evaluation_category: What property/ explanation quality that this metric measures. - """ - - name = "BatchedMetric" - data_applicability = {DataType.IMAGE, DataType.TIMESERIES, DataType.TABULAR} - model_applicability = {ModelType.TORCH, ModelType.TF} - score_direction = ScoreDirection.HIGHER - evaluation_category = EvaluationCategory.NONE - - def __init__( - self, - abs: bool, - normalise: bool, - normalise_func: Optional[Callable], - normalise_func_kwargs: Optional[Dict[str, Any]], - return_aggregate: bool, - aggregate_func: Optional[Callable], - default_plot_func: Optional[Callable], - disable_warnings: bool, - display_progressbar: bool, - **kwargs, - ): - """ - Initialise the BatchedMetric base class. - - Each of the defined metrics in Quantus, inherits from Metric or BatchedMetric base class. - - A child metric can benefit from the following class methods: - - __call__(): Will call general_preprocess(), apply evaluate_instance() on each - instance and finally call custom_preprocess(). - To use this method the child BatchedMetric needs to implement - evaluate_instance(). - - general_preprocess(): Prepares all necessary data structures for evaluation. - Will call custom_preprocess() at the end. - - Parameters - ---------- - abs: boolean - Indicates whether absolute operation is applied on the attribution. - normalise: boolean - Indicates whether normalise operation is applied on the attribution. - normalise_func: callable - Attribution normalisation function applied in case normalise=True. - normalise_func_kwargs: dict - Keyword arguments to be passed to normalise_func on call. - return_aggregate: boolean - Indicates if an aggregated score should be computed over all instances. - aggregate_func: callable - Callable that aggregates the scores given an evaluation call. - default_plot_func: callable - Callable that plots the metrics result. - disable_warnings: boolean - Indicates whether the warnings are printed. - display_progressbar: boolean - Indicates whether a tqdm-progress-bar is printed. - kwargs: optional - Keyword arguments. - """ - - # Initialise super-class with passed parameters. - super().__init__( - abs=abs, - normalise=normalise, - normalise_func=normalise_func, - normalise_func_kwargs=normalise_func_kwargs, - return_aggregate=return_aggregate, - aggregate_func=aggregate_func, - default_plot_func=default_plot_func, - display_progressbar=display_progressbar, - disable_warnings=disable_warnings, - **kwargs, - ) - - def __call__( - self, - model, - x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: Optional[np.ndarray], - channel_first: Optional[bool], - explain_func: Optional[Callable], - explain_func_kwargs: Optional[Dict[str, Any]], - model_predict_kwargs: Optional[Dict], - softmax: Optional[bool], - device: Optional[str] = None, - batch_size: int = 64, - custom_batch: Optional[Any] = None, - **kwargs, - ) -> Union[int, float, list, dict, None]: - """ - This implementation represents the main logic of the metric and makes the class object callable. - It completes batch-wise evaluation of explanations (a_batch) with respect to input data (x_batch), - output labels (y_batch) and a torch or tensorflow model (model). - - Calls general_preprocess() with all relevant arguments, calls - evaluate_instance() on each instance, and saves results to evaluation_scores. - Calls custom_postprocess() afterwards. Finally returns evaluation_scores. - - The content of evaluation_scores will be appended to all_evaluation_scores (list) at the end of - the evaluation call. - - Parameters - ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model that is subject to explanation. - x_batch: np.ndarray - A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - channel_first: boolean, optional - Indicates of the image dimensions are channel first, or channel last. - Inferred from the input shape if None. - explain_func: callable - Callable generating attributions. - explain_func_kwargs: dict, optional - Keyword arguments to be passed to explain_func on call. - model_predict_kwargs: dict, optional - Keyword arguments to be passed to the model's predict method. - softmax: boolean - Indicates whether to use softmax probabilities or logits in model prediction. - This is used for this __call__ only and won't be saved as attribute. If None, self.softmax is used. - device: string - Indicated the device on which a torch.Tensor is or will be allocated: "cpu" or "gpu". - custom_batch: any - Any object that can be passed to the evaluation process. - Gives flexibility to the user to adapt for implementing their own metric. - kwargs: optional - Keyword arguments. - - Returns - ------- - evaluation_scores: list - a list of Any with the evaluation scores of the concerned batch. - - Examples: - -------- - # Minimal imports. - >> import quantus - >> from quantus import LeNet - >> import torch - - # Enable GPU. - >> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - # Load a pre-trained LeNet classification model (architecture at quantus/helpers/models). - >> model = LeNet() - >> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model")) - - # Load MNIST datasets and make loaders. - >> test_set = torchvision.datasets.MNIST(root='./sample_data', download=True) - >> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24) - - # Load a batch of inputs and outputs to use for XAI evaluation. - >> x_batch, y_batch = iter(test_loader).next() - >> x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy() - - # Generate Saliency attributions of the test set batch of the test set. - >> a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1) - >> a_batch_saliency = a_batch_saliency.cpu().numpy() - - # Initialise the metric and evaluate explanations by calling the metric instance. - >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} - """ - # Run deprecation warnings. - warn.deprecation_warnings(kwargs) - warn.check_kwargs(kwargs) - - data = self.general_preprocess( - model=model, - x_batch=x_batch, - y_batch=y_batch, - a_batch=a_batch, - s_batch=s_batch, - custom_batch=custom_batch, - channel_first=channel_first, - explain_func=explain_func, - explain_func_kwargs=explain_func_kwargs, - model_predict_kwargs=model_predict_kwargs, - softmax=softmax, - device=device, - ) - - # Create generator for generating batches. - batch_generator = self.generate_batches( - data=data, - batch_size=batch_size, - ) - - self.evaluation_scores = [] - for data_batch in batch_generator: - result = self.evaluate_batch(**data_batch) - self.evaluation_scores.extend(result) - - # Call post-processing. - self.custom_postprocess(**data) - - if self.return_aggregate: - if self.aggregate_func: - try: - self.evaluation_scores = [ - self.aggregate_func(self.evaluation_scores) - ] - except: - print( - "The aggregation of evaluation scores failed. Check that " - "'aggregate_func' supplied is appropriate for the data " - "in 'evaluation_scores'." - ) - else: - raise KeyError( - "Specify an 'aggregate_func' (Callable) to aggregate evaluation scores." - ) - - # Append content of last results to all results. - self.all_evaluation_scores.append(self.evaluation_scores) +from quantus.metrics.base_perturbed import PerturbationMetric - return self.evaluation_scores - @abstractmethod - def evaluate_batch( - self, - model: ModelInterface, - x_batch: np.ndarray, - y_batch: np.ndarray, - a_batch: np.ndarray, - s_batch: np.ndarray, - ): - """ - Evaluates model and attributes on a single data batch and returns the batched evaluation result. +"""Aliases to smoothen transition to uniform metric API.""" - This method needs to be implemented to use __call__(). - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x_batch: np.ndarray - The input to be evaluated on a batch-basis. - y_batch: np.ndarray - The output to be evaluated on a batch-basis. - a_batch: np.ndarray - The explanation to be evaluated on a batch-basis. - s_batch: np.ndarray - The segmentation to be evaluated on a batch-basis. +class BatchedMetric(Metric, abc.ABC): - Returns - ------- - np.ndarray - The batched evaluation results. - """ - raise NotImplementedError() + """Alias to quantus.Metric, will be removed in next major release.""" - @staticmethod - def get_number_of_batches(n_instances: int, batch_size: int) -> int: - """ - Get the number of batches given number of samples/ instances and a batch size. - - Parameters - ---------- - n_instances: int - The number of instances. - batch_size: int - The batch size. - - Returns - ------- - integer - """ - return math.ceil(n_instances / batch_size) - - def generate_batches( - self, - data: Dict[str, Any], - batch_size: int, - ): - """ - Creates iterator to iterate over all batched instances in data dictionary. - Each iterator output element is a keyword argument dictionary with - string keys. - - Each item key in the input data dictionary has to be of type string. - - If the item value is not a sequence, the respective item key/value pair - will be written to each iterator output dictionary. - - If the item value is a sequence and the item key ends with '_batch', - a check will be made to make sure length matches number of instances. - The values of the batch instances in the sequence will be added to the respective - iterator output dictionary with the '_batch' suffix removed. - - If the item value is a sequence but doesn't end with '_batch', it will be treated - as a simple value and the respective item key/value pair will be - written to each iterator output dictionary. - - Parameters - ---------- - data: dict[str, any] - The data input dictionary. - batch_size: int - The batch size to be used. - - Returns - ------- - iterator - Each iterator output element is a keyword argument dictionary (string keys). - - """ - n_instances = len(data["x_batch"]) - - single_value_kwargs: Dict[str, Any] = {} - batched_value_kwargs: Dict[str, Any] = {} - - for key, value in list(data.items()): - # If data-value is not a Sequence or a string, create list of value with length of n_instances. - if not isinstance(value, (Sequence, np.ndarray)) or isinstance(value, str): - single_value_kwargs[key] = value - - # If data-value is a sequence and ends with '_batch', only check for correct length. - elif key.endswith("_batch"): - if len(value) != n_instances: - # Sequence has to have correct length. - raise ValueError( - f"'{key}' has incorrect length (expected: {n_instances}, is: {len(value)})" - ) - else: - batched_value_kwargs[key] = value - - # If data-value is a sequence and doesn't end with '_batch', create - # list of repeated sequences with length of n_instances. - else: - single_value_kwargs[key] = [value for _ in range(n_instances)] - - n_batches = self.get_number_of_batches( - n_instances=n_instances, batch_size=batch_size - ) - - # Create iterator for batch index. - iterator = tqdm( - range(0, n_batches), - total=n_batches, - disable=not self.display_progressbar, - ) - - # Iterate over batch index - for batch_idx in iterator: - # Calculate instance index for start and end of batch. - batch_start = batch_size * batch_idx - batch_end = min(batch_size * (batch_idx + 1), n_instances) - - # Create batch dictionary with all specified batch instance values - batch = { - key: value[batch_start:batch_end] - for key, value in batched_value_kwargs.items() - } - - # Yield batch dictionary including single value keyword arguments. - yield {**batch, **single_value_kwargs} - - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: Optional[np.ndarray], - a: Optional[np.ndarray], - s: Optional[np.ndarray], - **kwargs, - ) -> Any: - """ - This method from the parent Metric class needs to be defined to implement this abstract class. - However we use evalaute_batch() instead for BatchedMetric. - - Returns - ------- - Any - """ - raise NotImplementedError( - "evaluate_instance() not implemented for BatchedMetric" - ) - - -class BatchedPerturbationMetric(BatchedMetric): - """ - Implementation base BatchedPertubationMetric class. - - This batched metric has additional attributes for perturbations. - """ - - def __init__( - self, - abs: bool, - normalise: bool, - normalise_func: Optional[Callable], - normalise_func_kwargs: Optional[Dict[str, Any]], - perturb_func: Callable, - perturb_func_kwargs: Optional[Dict[str, Any]], - return_aggregate: bool, - aggregate_func: Optional[Callable], - default_plot_func: Optional[Callable], - disable_warnings: bool, - display_progressbar: bool, - **kwargs, - ): - """ - Initialise the PerturbationMetric base class. - - Parameters - ---------- - abs: boolean - Indicates whether absolute operation is applied on the attribution. - normalise: boolean - Indicates whether normalise operation is applied on the attribution. - normalise_func: callable - Attribution normalisation function applied in case normalise=True. - normalise_func_kwargs: dict - Keyword arguments to be passed to normalise_func on call. - perturb_func: callable - Input perturbation function. - perturb_func_kwargs: dict - Keyword arguments to be passed to perturb_func, default={}. - return_aggregate: boolean - Indicates if an aggregated score should be computed over all instances. - aggregate_func: callable - Callable that aggregates the scores given an evaluation call.. - default_plot_func: callable - Callable that plots the metrics result. - disable_warnings: boolean - Indicates whether the warnings are printed. - display_progressbar: boolean - Indicates whether a tqdm-progress-bar is printed. - kwargs: optional - Keyword arguments. - """ - - # Initialise super-class with passed parameters. - super().__init__( - abs=abs, - normalise=normalise, - normalise_func=normalise_func, - normalise_func_kwargs=normalise_func_kwargs, - return_aggregate=return_aggregate, - aggregate_func=aggregate_func, - default_plot_func=default_plot_func, - display_progressbar=display_progressbar, - disable_warnings=disable_warnings, - **kwargs, + def __subclasscheck__(self, subclass): + warnings.warn( + "BatchedMetric was deprecated, since it is just an alias to Metric. Please subclass Metric directly." ) - # Save perturbation metric attributes. - self.perturb_func = perturb_func - - if perturb_func_kwargs is None: - perturb_func_kwargs = {} - self.perturb_func_kwargs = perturb_func_kwargs - - @abstractmethod - def evaluate_batch( - self, - model: ModelInterface, - x_batch: np.ndarray, - y_batch: np.ndarray, - a_batch: np.ndarray, - s_batch: np.ndarray, - ) -> np.ndarray: - """ - Evaluates model and attributes on a single data batch and returns the batched evaluation result. - - This method needs to be implemented to use __call__(). - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x_batch: np.ndarray - The input to be evaluated on a batch-basis. - y_batch: np.ndarray - The output to be evaluated on a batch-basis. - a_batch: np.ndarray - The explanation to be evaluated on a batch-basis. - s_batch: np.ndarray - The segmentation to be evaluated on a batch-basis. - - Returns - ------- - np.ndarray - The batched evaluation results. - """ - raise NotImplementedError() - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: Optional[np.ndarray], - a: Optional[np.ndarray], - s: Optional[np.ndarray], - **kwargs, - ) -> Any: - """ - This method from the parent Metric class needs to be defined to implement this abstract class. - However we use evalaute_batch() instead for BatchedMetric. +class BatchedPerturbationMetric(PerturbationMetric, abc.ABC): + """Alias to quantus.PerturbationMetric, will be removed in next major release.""" - Parameters - ---------- - kwargs: optional - Keyword arguments. - """ - raise NotImplementedError( - "evaluate_instance() not implemented for BatchedPerturbationMetric" + def __subclasscheck__(self, subclass): + warnings.warn( + "BatchedPerturbationMetric was deprecated, " + "since it is just an alias to Metric. Please subclass PerturbationMetric directly." ) diff --git a/quantus/metrics/base_perturbed.py b/quantus/metrics/base_perturbed.py index 479a87dc7..3c161501d 100644 --- a/quantus/metrics/base_perturbed.py +++ b/quantus/metrics/base_perturbed.py @@ -5,39 +5,18 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -import inspect -import re -from abc import abstractmethod -from collections.abc import Sequence +from abc import ABC from typing import ( Any, Callable, Dict, - Sequence, Optional, - Tuple, - Union, - Collection, - List, ) -import matplotlib.pyplot as plt -import numpy as np -from tqdm.auto import tqdm -from quantus.helpers import asserts -from quantus.helpers import utils -from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.metrics.base import Metric -from quantus.helpers.enums import ( - ModelType, - DataType, - ScoreDirection, - EvaluationCategory, -) -class PerturbationMetric(Metric): +class PerturbationMetric(Metric, ABC): """ Implementation base PertubationMetric class. @@ -52,13 +31,6 @@ class PerturbationMetric(Metric): - evaluation_category: What property/ explanation quality that this metric measures. """ - name = "PerturbationMetric" - data_applicability = {DataType.IMAGE, DataType.TIMESERIES, DataType.TABULAR} - model_applicability = {ModelType.TORCH, ModelType.TF} - score_direction = ScoreDirection.HIGHER - evaluation_category = EvaluationCategory.NONE - - def __init__( self, abs: bool, @@ -121,42 +93,10 @@ def __init__( **kwargs, ) + # TODO: do we really need separate 150+ lines long class just to reuse 4 lines of code? # Save perturbation metric attributes. self.perturb_func = perturb_func if perturb_func_kwargs is None: perturb_func_kwargs = {} self.perturb_func_kwargs = perturb_func_kwargs - - @abstractmethod - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: Optional[np.ndarray], - a: Optional[np.ndarray], - s: Optional[np.ndarray], - ) -> Any: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - This method needs to be implemented to use __call__(). - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. - - Returns - ------- - Any - """ - raise NotImplementedError() diff --git a/quantus/metrics/complexity/complexity.py b/quantus/metrics/complexity/complexity.py index 021b2f3fe..1d2e12bb9 100644 --- a/quantus/metrics/complexity/complexity.py +++ b/quantus/metrics/complexity/complexity.py @@ -10,9 +10,7 @@ import numpy as np import scipy -from quantus.helpers import asserts from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.metrics.base import Metric from quantus.helpers.enums import ( @@ -227,40 +225,18 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - s: np.ndarray, - ) -> float: - """ " - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. - - Returns - ------- - float - The evaluation results. - """ - - if len(x.shape) == 1: - newshape = np.prod(x.shape) - else: - newshape = np.prod(x.shape[1:]) - - a = np.array(np.reshape(a, newshape), dtype=np.float64) / np.sum(np.abs(a)) - return scipy.stats.entropy(pk=a) + def evaluate_batch( + self, *, x_batch: np.ndarray, a_batch: np.ndarray, **_ + ) -> List[float]: + # TODO: vectorize + retval = [] + for x, a in zip(x_batch, a_batch): + if len(x.shape) == 1: + newshape = np.prod(x.shape) + else: + newshape = np.prod(x.shape[1:]) + + a = np.array(np.reshape(a, newshape), dtype=np.float64) / np.sum(np.abs(a)) + retval.append(scipy.stats.entropy(pk=a)) + + return retval diff --git a/quantus/metrics/complexity/effective_complexity.py b/quantus/metrics/complexity/effective_complexity.py index 18307d150..6eee89e7a 100644 --- a/quantus/metrics/complexity/effective_complexity.py +++ b/quantus/metrics/complexity/effective_complexity.py @@ -9,9 +9,7 @@ from typing import Any, Callable, Dict, List, Optional import numpy as np -from quantus.helpers import asserts from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.metrics.base import Metric from quantus.helpers.enums import ( @@ -230,35 +228,12 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - s: np.ndarray, - ) -> int: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. + def evaluate_batch(self, *, a_batch: np.ndarray, **_) -> List[int]: + # TODO: vectorize + retval = [] - Returns - ------- - integer - The evaluation results. - """ + for a in a_batch: + a = a.flatten() + retval.append(int(np.sum(a > self.eps))) - a = a.flatten() - return int(np.sum(a > self.eps)) + return retval diff --git a/quantus/metrics/complexity/sparseness.py b/quantus/metrics/complexity/sparseness.py index 564fefca8..3cac9324c 100644 --- a/quantus/metrics/complexity/sparseness.py +++ b/quantus/metrics/complexity/sparseness.py @@ -9,9 +9,7 @@ from typing import Any, Callable, Dict, List, Optional import numpy as np -from quantus.helpers import asserts from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.metrics.base import Metric from quantus.helpers.enums import ( @@ -232,44 +230,23 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - s: np.ndarray, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. - - Returns - ------- - float - The evaluation results. - """ - if len(x.shape) == 1: - newshape = np.prod(x.shape) - else: - newshape = np.prod(x.shape[1:]) - - a = np.array(np.reshape(a, newshape), dtype=np.float64) - a += 0.0000001 - a = np.sort(a) - score = (np.sum((2 * np.arange(1, a.shape[0] + 1) - a.shape[0] - 1) * a)) / ( - a.shape[0] * np.sum(a) - ) - return score + def evaluate_batch( + self, *, x_batch: np.ndarray, a_batch: np.ndarray, **_ + ) -> List[float]: + retval = [] + + for x, a in zip(x_batch, a_batch): + if len(x.shape) == 1: + newshape = np.prod(x.shape) + else: + newshape = np.prod(x.shape[1:]) + + a = np.array(np.reshape(a, newshape), dtype=np.float64) + a += 0.0000001 + a = np.sort(a) + score = ( + np.sum((2 * np.arange(1, a.shape[0] + 1) - a.shape[0] - 1) * a) + ) / (a.shape[0] * np.sum(a)) + retval.append(score) + + return retval diff --git a/quantus/metrics/faithfulness/faithfulness_correlation.py b/quantus/metrics/faithfulness/faithfulness_correlation.py index 28f60409f..79c3a2b7c 100644 --- a/quantus/metrics/faithfulness/faithfulness_correlation.py +++ b/quantus/metrics/faithfulness/faithfulness_correlation.py @@ -6,7 +6,7 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional import numpy as np @@ -276,70 +276,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - s: np.ndarray, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. - - Returns - ------- - float - The evaluation results. - """ - # Flatten the attributions. - a = a.flatten() - - # Predict on input. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) - - pred_deltas = [] - att_sums = [] - - # For each test data point, execute a couple of runs. - for i_ix in range(self.nr_runs): - - # Randomly mask by subset size. - a_ix = np.random.choice(a.shape[0], self.subset_size, replace=False) - x_perturbed = self.perturb_func( - arr=x, - indices=a_ix, - indexed_axes=self.a_axes, - **self.perturb_func_kwargs, - ) - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Predict on perturbed input x. - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - pred_deltas.append(float(y_pred - y_pred_perturb)) - - # Sum attributions of the random subset. - att_sums.append(np.sum(a[a_ix])) - - similarity = self.similarity_func(a=att_sums, b=pred_deltas) - - return similarity - def custom_preprocess( self, model: ModelInterface, @@ -377,3 +313,49 @@ def custom_preprocess( asserts.assert_value_smaller_than_input_size( x=x_batch, value=self.subset_size, value_name="subset_size" ) + + def evaluate_batch( + self, + *, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: np.ndarray, + **_, + ) -> List[float]: + retval = [] + for x, y, a in zip(x_batch, y_batch, a_batch): + a = a.flatten() + + # Predict on input. + x_input = model.shape_input(x, x.shape, channel_first=True) + y_pred = float(model.predict(x_input)[:, y]) + + pred_deltas = [] + att_sums = [] + + # For each test data point, execute a couple of runs. + for i_ix in range(self.nr_runs): + # Randomly mask by subset size. + a_ix = np.random.choice(a.shape[0], self.subset_size, replace=False) + x_perturbed = self.perturb_func( + arr=x, + indices=a_ix, + indexed_axes=self.a_axes, + **self.perturb_func_kwargs, + ) + warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) + + # Predict on perturbed input x. + x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) + y_pred_perturb = float(model.predict(x_input)[:, y]) + pred_deltas.append(float(y_pred - y_pred_perturb)) + + # Sum attributions of the random subset. + att_sums.append(np.sum(a[a_ix])) + + similarity = self.similarity_func(a=att_sums, b=pred_deltas) + + retval.append(similarity) + + return retval diff --git a/quantus/metrics/faithfulness/faithfulness_estimate.py b/quantus/metrics/faithfulness/faithfulness_estimate.py index 5d40b4393..a4c743902 100644 --- a/quantus/metrics/faithfulness/faithfulness_estimate.py +++ b/quantus/metrics/faithfulness/faithfulness_estimate.py @@ -6,7 +6,7 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional import numpy as np @@ -261,75 +261,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - s: np.ndarray, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. - - Returns - ------- - float - The evaluation results. - """ - - # Flatten the attributions. - a = a.flatten() - - # Get indices of sorted attributions (descending). - a_indices = np.argsort(-a) - - # Predict on input. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) - - n_perturbations = len(range(0, len(a_indices), self.features_in_step)) - pred_deltas = [None for _ in range(n_perturbations)] - att_sums = [None for _ in range(n_perturbations)] - - for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): - - # Perturb input by indices of attributions. - a_ix = a_indices[ - (self.features_in_step * i_ix) : (self.features_in_step * (i_ix + 1)) - ] - x_perturbed = self.perturb_func( - arr=x, - indices=a_ix, - indexed_axes=self.a_axes, - **self.perturb_func_kwargs, - ) - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Predict on perturbed input x. - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - pred_deltas[i_ix] = float(y_pred - y_pred_perturb) - - # Sum attributions. - att_sums[i_ix] = np.sum(a[a_ix]) - - similarity = self.similarity_func(a=att_sums, b=pred_deltas) - return similarity - def custom_preprocess( self, model: ModelInterface, @@ -368,3 +299,57 @@ def custom_preprocess( features_in_step=self.features_in_step, input_shape=x_batch.shape[2:], ) + + def evaluate_batch( + self, + *, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: np.ndarray, + **_, + ): + rerval = [] + + for x, y, a in zip(x_batch, y_batch, a_batch): + # Flatten the attributions. + a = a.flatten() + + # Get indices of sorted attributions (descending). + a_indices = np.argsort(-a) + + # Predict on input. + x_input = model.shape_input(x, x.shape, channel_first=True) + y_pred = float(model.predict(x_input)[:, y]) + + n_perturbations = len(range(0, len(a_indices), self.features_in_step)) + pred_deltas = [None for _ in range(n_perturbations)] + att_sums = [None for _ in range(n_perturbations)] + + for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): + # Perturb input by indices of attributions. + a_ix = a_indices[ + (self.features_in_step * i_ix) : ( + self.features_in_step * (i_ix + 1) + ) + ] + x_perturbed = self.perturb_func( + arr=x, + indices=a_ix, + indexed_axes=self.a_axes, + **self.perturb_func_kwargs, + ) + warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) + + # Predict on perturbed input x. + x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) + y_pred_perturb = float(model.predict(x_input)[:, y]) + pred_deltas[i_ix] = float(y_pred - y_pred_perturb) + + # Sum attributions. + att_sums[i_ix] = np.sum(a[a_ix]) + + similarity = self.similarity_func(a=att_sums, b=pred_deltas) + rerval.append(similarity) + + return rerval diff --git a/quantus/metrics/faithfulness/infidelity.py b/quantus/metrics/faithfulness/infidelity.py index ffa3df8bb..43cd5010d 100644 --- a/quantus/metrics/faithfulness/infidelity.py +++ b/quantus/metrics/faithfulness/infidelity.py @@ -9,7 +9,6 @@ from typing import Any, Callable, Dict, List, Optional, Union import numpy as np -from quantus.helpers import asserts from quantus.helpers import utils from quantus.helpers import warn from quantus.functions.loss_func import mse @@ -283,108 +282,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - s: np.ndarray, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. - - Returns - ------- - float - The evaluation results. - """ - - # Predict on input. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) - - results = [] - - for _ in range(self.n_perturb_samples): - - sub_results = [] - - for patch_size in self.perturb_patch_sizes: - - pred_deltas = np.zeros( - (int(a.shape[1] / patch_size), int(a.shape[2] / patch_size)) - ) - a_sums = np.zeros( - (int(a.shape[1] / patch_size), int(a.shape[2] / patch_size)) - ) - x_perturbed = x.copy() - pad_width = patch_size - 1 - - for i_x, top_left_x in enumerate(range(0, x.shape[1], patch_size)): - - for i_y, top_left_y in enumerate(range(0, x.shape[2], patch_size)): - - # Perturb input patch-wise. - x_perturbed_pad = utils._pad_array( - x_perturbed, pad_width, mode="edge", padded_axes=self.a_axes - ) - patch_slice = utils.create_patch_slice( - patch_size=patch_size, - coords=[top_left_x, top_left_y], - ) - - x_perturbed_pad = self.perturb_func( - arr=x_perturbed_pad, - indices=patch_slice, - indexed_axes=self.a_axes, - **self.perturb_func_kwargs, - ) - - # Remove padding. - x_perturbed = utils._unpad_array( - x_perturbed_pad, pad_width, padded_axes=self.a_axes - ) - - # Predict on perturbed input x_perturbed. - x_input = model.shape_input( - x_perturbed, x.shape, channel_first=True - ) - warn.warn_perturbation_caused_no_change( - x=x, x_perturbed=x_input - ) - y_pred_perturb = float(model.predict(x_input)[:, y]) - - x_diff = x - x_perturbed - a_diff = np.dot( - np.repeat(a, repeats=self.nr_channels, axis=0), x_diff - ) - - pred_deltas[i_x][i_y] = y_pred - y_pred_perturb - a_sums[i_x][i_y] = np.sum(a_diff) - - assert callable(self.loss_func) - sub_results.append( - self.loss_func(a=pred_deltas.flatten(), b=a_sums.flatten()) - ) - - results.append(np.mean(sub_results)) - - return np.mean(results) - def custom_preprocess( self, model: ModelInterface, @@ -418,3 +315,91 @@ def custom_preprocess( """ # Infer number of input channels. self.nr_channels = x_batch.shape[1] + + def evaluate_batch( + self, + *, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: np.ndarray, + **kwargs, + ) -> List[List[float]]: + # TODO: can we get rid of any for-loop? + retval = [] + + for x, y, a in zip(x_batch, y_batch, a_batch): + # Predict on input. + x_input = model.shape_input(x, x.shape, channel_first=True) + y_pred = float(model.predict(x_input)[:, y]) + + results = [] + + for _ in range(self.n_perturb_samples): + sub_results = [] + + for patch_size in self.perturb_patch_sizes: + pred_deltas = np.zeros( + (int(a.shape[1] / patch_size), int(a.shape[2] / patch_size)) + ) + a_sums = np.zeros( + (int(a.shape[1] / patch_size), int(a.shape[2] / patch_size)) + ) + x_perturbed = x.copy() + pad_width = patch_size - 1 + + for i_x, top_left_x in enumerate(range(0, x.shape[1], patch_size)): + for i_y, top_left_y in enumerate( + range(0, x.shape[2], patch_size) + ): + # Perturb input patch-wise. + x_perturbed_pad = utils._pad_array( + x_perturbed, + pad_width, + mode="edge", + padded_axes=self.a_axes, + ) + patch_slice = utils.create_patch_slice( + patch_size=patch_size, + coords=[top_left_x, top_left_y], + ) + + x_perturbed_pad = self.perturb_func( + arr=x_perturbed_pad, + indices=patch_slice, + indexed_axes=self.a_axes, + **self.perturb_func_kwargs, + ) + + # Remove padding. + x_perturbed = utils._unpad_array( + x_perturbed_pad, pad_width, padded_axes=self.a_axes + ) + + # Predict on perturbed input x_perturbed. + x_input = model.shape_input( + x_perturbed, x.shape, channel_first=True + ) + warn.warn_perturbation_caused_no_change( + x=x, x_perturbed=x_input + ) + y_pred_perturb = float(model.predict(x_input)[:, y]) + + x_diff = x - x_perturbed + a_diff = np.dot( + np.repeat(a, repeats=self.nr_channels, axis=0), x_diff + ) + + pred_deltas[i_x][i_y] = y_pred - y_pred_perturb + a_sums[i_x][i_y] = np.sum(a_diff) + + assert callable(self.loss_func) + sub_results.append( + self.loss_func(a=pred_deltas.flatten(), b=a_sums.flatten()) + ) + + results.append(np.mean(sub_results)) + + retval.append(results) + + return retval diff --git a/quantus/metrics/faithfulness/irof.py b/quantus/metrics/faithfulness/irof.py index 829c58d59..73916e596 100644 --- a/quantus/metrics/faithfulness/irof.py +++ b/quantus/metrics/faithfulness/irof.py @@ -6,7 +6,7 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional import numpy as np @@ -262,85 +262,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - s: np.ndarray, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. - - Returns - ------- - float - The evaluation results. - """ - # Predict on x. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) - - # Segment image. - segments = utils.get_superpixel_segments( - img=np.moveaxis(x, 0, -1).astype("double"), - segmentation_method=self.segmentation_method, - ) - nr_segments = len(np.unique(segments)) - asserts.assert_nr_segments(nr_segments=nr_segments) - - # Calculate average attribution of each segment. - att_segs = np.zeros(nr_segments) - for i, s in enumerate(range(nr_segments)): - att_segs[i] = np.mean(a[:, segments == s]) - - # Sort segments based on the mean attribution (descending order). - s_indices = np.argsort(-att_segs) - - preds = [] - x_prev_perturbed = x - - for i_ix, s_ix in enumerate(s_indices): - - # Perturb input by indices of attributions. - a_ix = np.nonzero((segments == s_ix).flatten())[0] - - x_perturbed = self.perturb_func( - arr=x_prev_perturbed, - indices=a_ix, - indexed_axes=self.a_axes, - **self.perturb_func_kwargs, - ) - warn.warn_perturbation_caused_no_change( - x=x_prev_perturbed, x_perturbed=x_perturbed - ) - - # Predict on perturbed input x. - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - - # Normalise the scores to be within range [0, 1]. - preds.append(float(y_pred_perturb / y_pred)) - x_prev_perturbed = x_perturbed - - # Calculate the area over the curve (AOC) score. - aoc = len(preds) - utils.calculate_auc(np.array(preds)) - return aoc - def custom_preprocess( self, model: ModelInterface, @@ -379,3 +300,65 @@ def custom_preprocess( def get_aoc_score(self): """Calculate the area over the curve (AOC) score for several test samples.""" return np.mean(self.evaluation_scores) + + def evaluate_batch( + self, + *, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: np.ndarray, + **_, + ) -> List[float]: + retval = [] + for x, y, a in zip(x_batch, y_batch, a_batch): + # Predict on x. + x_input = model.shape_input(x, x.shape, channel_first=True) + y_pred = float(model.predict(x_input)[:, y]) + + # Segment image. + segments = utils.get_superpixel_segments( + img=np.moveaxis(x, 0, -1).astype("double"), + segmentation_method=self.segmentation_method, + ) + nr_segments = len(np.unique(segments)) + asserts.assert_nr_segments(nr_segments=nr_segments) + + # Calculate average attribution of each segment. + att_segs = np.zeros(nr_segments) + for i, s in enumerate(range(nr_segments)): + att_segs[i] = np.mean(a[:, segments == s]) + + # Sort segments based on the mean attribution (descending order). + s_indices = np.argsort(-att_segs) + + preds = [] + x_prev_perturbed = x + + for i_ix, s_ix in enumerate(s_indices): + # Perturb input by indices of attributions. + a_ix = np.nonzero((segments == s_ix).flatten())[0] + + x_perturbed = self.perturb_func( + arr=x_prev_perturbed, + indices=a_ix, + indexed_axes=self.a_axes, + **self.perturb_func_kwargs, + ) + warn.warn_perturbation_caused_no_change( + x=x_prev_perturbed, x_perturbed=x_perturbed + ) + + # Predict on perturbed input x. + x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) + y_pred_perturb = float(model.predict(x_input)[:, y]) + + # Normalise the scores to be within range [0, 1]. + preds.append(float(y_pred_perturb / y_pred)) + x_prev_perturbed = x_perturbed + + # Calculate the area over the curve (AOC) score. + aoc = len(preds) - utils.calculate_auc(np.array(preds)) + retval.append(aoc) + + return retval diff --git a/quantus/metrics/faithfulness/monotonicity.py b/quantus/metrics/faithfulness/monotonicity.py index 51335c5dc..7acfd318a 100644 --- a/quantus/metrics/faithfulness/monotonicity.py +++ b/quantus/metrics/faithfulness/monotonicity.py @@ -6,7 +6,7 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional import numpy as np @@ -258,72 +258,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - s: np.ndarray, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. - - Returns - ------- - float - The evaluation results. - """ - # Prepare shapes. - a = a.flatten() - - # Get indices of sorted attributions (ascending). - a_indices = np.argsort(a) - - n_perturbations = len(range(0, len(a_indices), self.features_in_step)) - preds = [None for _ in range(n_perturbations)] - - # Copy the input x but fill with baseline values. - baseline_value = utils.get_baseline_value( - value=self.perturb_func_kwargs["perturb_baseline"], - arr=x, - return_shape=x.shape, # TODO. Double-check this over using = (1,). - ) - x_baseline = np.full(x.shape, baseline_value) - - for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): - - # Perturb input by indices of attributions. - a_ix = a_indices[ - (self.features_in_step * i_ix) : (self.features_in_step * (i_ix + 1)) - ] - x_baseline = self.perturb_func( - arr=x_baseline, - indices=a_ix, - indexed_axes=self.a_axes, - **self.perturb_func_kwargs, - ) - - # Predict on perturbed input x (that was initially filled with a constant 'perturb_baseline' value). - x_input = model.shape_input(x_baseline, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - preds[i_ix] = y_pred_perturb - - return np.all(np.diff(preds) >= 0) - def custom_preprocess( self, model: ModelInterface, @@ -360,3 +294,54 @@ def custom_preprocess( features_in_step=self.features_in_step, input_shape=x_batch.shape[2:], ) + + def evaluate_batch( + self, + *, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: np.ndarray, + **_, + ) -> List[bool]: + retval = [] + for x, y, a in zip(x_batch, y_batch, a_batch): + # Prepare shapes. + a = a.flatten() + + # Get indices of sorted attributions (ascending). + a_indices = np.argsort(a) + + n_perturbations = len(range(0, len(a_indices), self.features_in_step)) + preds = [None for _ in range(n_perturbations)] + + # Copy the input x but fill with baseline values. + baseline_value = utils.get_baseline_value( + value=self.perturb_func_kwargs["perturb_baseline"], + arr=x, + return_shape=x.shape, # TODO. Double-check this over using = (1,). + ) + x_baseline = np.full(x.shape, baseline_value) + + for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): + # Perturb input by indices of attributions. + a_ix = a_indices[ + (self.features_in_step * i_ix) : ( + self.features_in_step * (i_ix + 1) + ) + ] + x_baseline = self.perturb_func( + arr=x_baseline, + indices=a_ix, + indexed_axes=self.a_axes, + **self.perturb_func_kwargs, + ) + + # Predict on perturbed input x (that was initially filled with a constant 'perturb_baseline' value). + x_input = model.shape_input(x_baseline, x.shape, channel_first=True) + y_pred_perturb = float(model.predict(x_input)[:, y]) + preds[i_ix] = y_pred_perturb + + retval.append(np.all(np.diff(preds) >= 0)) + + return retval diff --git a/quantus/metrics/faithfulness/monotonicity_correlation.py b/quantus/metrics/faithfulness/monotonicity_correlation.py index c2ede00b3..f52896055 100644 --- a/quantus/metrics/faithfulness/monotonicity_correlation.py +++ b/quantus/metrics/faithfulness/monotonicity_correlation.py @@ -6,7 +6,7 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional import numpy as np @@ -275,83 +275,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - s: np.ndarray, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. - - Returns - ------- - float - The evaluation results. - """ - # Predict on input x. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) - - inv_pred = 1.0 if np.abs(y_pred) < self.eps else 1.0 / np.abs(y_pred) - inv_pred = inv_pred ** 2 - - # Reshape attributions. - a = a.flatten() - - # Get indices of sorted attributions (ascending). - a_indices = np.argsort(a) - - n_perturbations = len(range(0, len(a_indices), self.features_in_step)) - atts = [None for _ in range(n_perturbations)] - vars = [None for _ in range(n_perturbations)] - - for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): - - # Perturb input by indices of attributions. - a_ix = a_indices[ - (self.features_in_step * i_ix) : (self.features_in_step * (i_ix + 1)) - ] - - y_pred_perturbs = [] - - for s_ix in range(self.nr_samples): - - x_perturbed = self.perturb_func( - arr=x, - indices=a_ix, - indexed_axes=self.a_axes, - **self.perturb_func_kwargs, - ) - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Predict on perturbed input x. - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - y_pred_perturbs.append(y_pred_perturb) - - vars[i_ix] = float( - np.mean((np.array(y_pred_perturbs) - np.array(y_pred)) ** 2) * inv_pred - ) - atts[i_ix] = float(sum(a[a_ix])) - - return self.similarity_func(a=atts, b=vars) - def custom_preprocess( self, model: ModelInterface, @@ -388,3 +311,69 @@ def custom_preprocess( features_in_step=self.features_in_step, input_shape=x_batch.shape[2:], ) + + def evaluate_batch( + self, + *, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: np.ndarray, + **_, + ) -> List[float]: + retval = [] + for x, y, a in zip(x_batch, y_batch, a_batch): + # Predict on input x. + x_input = model.shape_input(x, x.shape, channel_first=True) + y_pred = float(model.predict(x_input)[:, y]) + + inv_pred = 1.0 if np.abs(y_pred) < self.eps else 1.0 / np.abs(y_pred) + inv_pred = inv_pred**2 + + # Reshape attributions. + a = a.flatten() + + # Get indices of sorted attributions (ascending). + a_indices = np.argsort(a) + + n_perturbations = len(range(0, len(a_indices), self.features_in_step)) + atts = [None for _ in range(n_perturbations)] + vars = [None for _ in range(n_perturbations)] + + for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): + # Perturb input by indices of attributions. + a_ix = a_indices[ + (self.features_in_step * i_ix) : ( + self.features_in_step * (i_ix + 1) + ) + ] + + y_pred_perturbs = [] + + for s_ix in range(self.nr_samples): + x_perturbed = self.perturb_func( + arr=x, + indices=a_ix, + indexed_axes=self.a_axes, + **self.perturb_func_kwargs, + ) + warn.warn_perturbation_caused_no_change( + x=x, x_perturbed=x_perturbed + ) + + # Predict on perturbed input x. + x_input = model.shape_input( + x_perturbed, x.shape, channel_first=True + ) + y_pred_perturb = float(model.predict(x_input)[:, y]) + y_pred_perturbs.append(y_pred_perturb) + + vars[i_ix] = float( + np.mean((np.array(y_pred_perturbs) - np.array(y_pred)) ** 2) + * inv_pred + ) + atts[i_ix] = float(sum(a[a_ix])) + + retval.append(self.similarity_func(a=atts, b=vars)) + + return retval diff --git a/quantus/metrics/faithfulness/pixel_flipping.py b/quantus/metrics/faithfulness/pixel_flipping.py index cc29994a6..4540591b9 100644 --- a/quantus/metrics/faithfulness/pixel_flipping.py +++ b/quantus/metrics/faithfulness/pixel_flipping.py @@ -6,7 +6,8 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from typing import Any, Callable, Dict, List, Optional, Tuple +from __future__ import annotations +from typing import Any, Callable, Dict, List, Optional import numpy as np @@ -259,71 +260,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - s: np.ndarray, - ) -> List[float]: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. - - Returns - ------- - list - The evaluation results. - """ - - # Reshape attributions. - a = a.flatten() - - # Get indices of sorted attributions (descending). - a_indices = np.argsort(-a) - - # Prepare lists. - n_perturbations = len(range(0, len(a_indices), self.features_in_step)) - preds = [None for _ in range(n_perturbations)] - x_perturbed = x.copy() - - for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): - - # Perturb input by indices of attributions. - a_ix = a_indices[ - (self.features_in_step * i_ix) : (self.features_in_step * (i_ix + 1)) - ] - x_perturbed = self.perturb_func( - arr=x_perturbed, - indices=a_ix, - indexed_axes=self.a_axes, - **self.perturb_func_kwargs, - ) - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Predict on perturbed input x. - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - preds[i_ix] = y_pred_perturb - - if self.return_auc_per_sample: - return utils.calculate_auc(preds) - - return preds - def custom_preprocess( self, model: ModelInterface, @@ -367,3 +303,52 @@ def get_auc_score(self): return np.mean( [utils.calculate_auc(np.array(curve)) for curve in self.evaluation_scores] ) + + def evaluate_batch( + self, + *, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: np.ndarray, + **_, + ) -> List[float | np.ndarray]: + retval = [] + for x, y, a in zip(x_batch, y_batch, a_batch): + # Reshape attributions. + a = a.flatten() + + # Get indices of sorted attributions (descending). + a_indices = np.argsort(-a) + + # Prepare lists. + n_perturbations = len(range(0, len(a_indices), self.features_in_step)) + preds = [None for _ in range(n_perturbations)] + x_perturbed = x.copy() + + for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): + # Perturb input by indices of attributions. + a_ix = a_indices[ + (self.features_in_step * i_ix) : ( + self.features_in_step * (i_ix + 1) + ) + ] + x_perturbed = self.perturb_func( + arr=x_perturbed, + indices=a_ix, + indexed_axes=self.a_axes, + **self.perturb_func_kwargs, + ) + warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) + + # Predict on perturbed input x. + x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) + y_pred_perturb = float(model.predict(x_input)[:, y]) + preds[i_ix] = y_pred_perturb + + if self.return_auc_per_sample: + retval.append(utils.calculate_auc(preds)) + else: + retval.append(preds) + + return retval diff --git a/quantus/metrics/faithfulness/region_perturbation.py b/quantus/metrics/faithfulness/region_perturbation.py index 425b5ebb6..412e8a3ae 100644 --- a/quantus/metrics/faithfulness/region_perturbation.py +++ b/quantus/metrics/faithfulness/region_perturbation.py @@ -285,7 +285,6 @@ def evaluate_instance( x: np.ndarray, y: np.ndarray, a: np.ndarray, - s: np.ndarray, ) -> List[float]: """ Evaluate instance gets model and data for a single instance as input and returns the evaluation result. @@ -327,7 +326,6 @@ def evaluate_instance( range(pad_width, x_pad.shape[axis] - pad_width) for axis in self.a_axes ] for top_left_coords in itertools.product(*axis_iterators): - # Create slice for patch. patch_slice = utils.create_patch_slice( patch_size=self.patch_size, @@ -388,7 +386,6 @@ def evaluate_instance( # Increasingly perturb the input and store the decrease in function value. results = [None for _ in range(len(ordered_patches_no_overlap))] for patch_id, patch_slice in enumerate(ordered_patches_no_overlap): - # Pad x_perturbed. The mode should probably depend on the used perturb_func? x_perturbed_pad = utils._pad_array( x_perturbed, pad_width, mode="edge", padded_axes=self.a_axes @@ -423,3 +420,17 @@ def get_auc_score(self): return np.mean( [utils.calculate_auc(np.array(curve)) for curve in self.evaluation_scores] ) + + def evaluate_batch( + self, + *, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: np.ndarray, + **_, + ) -> List[List[float]]: + return [ + self.evaluate_instance(model, x, y, a) + for x, y, a in zip(x_batch, y_batch, a_batch) + ] diff --git a/quantus/metrics/faithfulness/road.py b/quantus/metrics/faithfulness/road.py index 5f4cbc22a..a5bcec557 100644 --- a/quantus/metrics/faithfulness/road.py +++ b/quantus/metrics/faithfulness/road.py @@ -6,11 +6,10 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional import numpy as np -from quantus.helpers import asserts from quantus.helpers import warn from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max @@ -268,7 +267,6 @@ def evaluate_instance( x: np.ndarray, y: np.ndarray, a: np.ndarray, - s: np.ndarray, ) -> List[float]: """ Evaluate instance gets model and data for a single instance as input and returns the evaluation result. @@ -291,65 +289,6 @@ def evaluate_instance( : list The evaluation results. """ - # Order indices. - ordered_indices = np.argsort(a, axis=None)[::-1] - - results_instance = np.array([None for _ in self.percentages]) - - for p_ix, p in enumerate(self.percentages): - top_k_indices = ordered_indices[: int(self.a_size * p / 100)] - - x_perturbed = self.perturb_func( - arr=x, - indices=top_k_indices, - **self.perturb_func_kwargs, - ) - - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Predict on perturbed input x and store the difference from predicting on unperturbed input. - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - class_pred_perturb = np.argmax(model.predict(x_input)) - - # Write a boolean into the percentage results. - results_instance[p_ix] = int(y == class_pred_perturb) - - # Return list of booleans for each percentage. - return results_instance - - def custom_preprocess( - self, - model: ModelInterface, - x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: np.ndarray, - custom_batch: Optional[np.ndarray], - ) -> None: - """ - Implementation of custom_preprocess_batch. - - Parameters - ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. - x_batch: np.ndarray - A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - custom_batch: any - Gives flexibility ot the user to use for evaluation, can hold any variable. - - Returns - ------- - None - """ - # Infer the size of attributions. - self.a_size = a_batch[0, :, :].size def custom_postprocess( self, @@ -386,3 +325,48 @@ def custom_postprocess( percentage: np.mean(np.array(self.evaluation_scores)[:, p_ix]) for p_ix, percentage in enumerate(self.percentages) } + + def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: + data_batch = super().batch_preprocess(data_batch) + # Infer the size of attributions. + self.a_size = data_batch["a_batch"][0, :, :].size + return data_batch + + def evaluate_batch( + self, + *, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: np.ndarray, + **kwargs, + ): + retval = [] + for x, y, a in zip(x_batch, y_batch, a_batch): + # Order indices. + ordered_indices = np.argsort(a, axis=None)[::-1] + + results_instance = np.array([None for _ in self.percentages]) + + for p_ix, p in enumerate(self.percentages): + top_k_indices = ordered_indices[: int(self.a_size * p / 100)] + + x_perturbed = self.perturb_func( + arr=x, + indices=top_k_indices, + **self.perturb_func_kwargs, + ) + + warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) + + # Predict on perturbed input x and store the difference from predicting on unperturbed input. + x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) + class_pred_perturb = np.argmax(model.predict(x_input)) + + # Write a boolean into the percentage results. + results_instance[p_ix] = int(y == class_pred_perturb) + + # Return list of booleans for each percentage. + retval.append(results_instance) + + return retval diff --git a/quantus/metrics/faithfulness/selectivity.py b/quantus/metrics/faithfulness/selectivity.py index 3741d8ea1..d4d4a9ba1 100644 --- a/quantus/metrics/faithfulness/selectivity.py +++ b/quantus/metrics/faithfulness/selectivity.py @@ -11,7 +11,6 @@ import numpy as np -from quantus.helpers import asserts from quantus.helpers import plotting from quantus.helpers import utils from quantus.helpers import warn @@ -274,7 +273,6 @@ def evaluate_instance( x: np.ndarray, y: np.ndarray, a: np.ndarray, - s: np.ndarray, ) -> List[float]: """ Evaluate instance gets model and data for a single instance as input and returns the evaluation result. @@ -289,8 +287,6 @@ def evaluate_instance( The output to be evaluated on an instance-basis. a: np.ndarray The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. Returns ------- @@ -316,7 +312,6 @@ def evaluate_instance( range(pad_width, x_pad.shape[axis] - pad_width) for axis in self.a_axes ] for top_left_coords in itertools.product(*axis_iterators): - # Create slice for patch. patch_slice = utils.create_patch_slice( patch_size=self.patch_size, @@ -348,7 +343,6 @@ def evaluate_instance( # Increasingly perturb the input and store the decrease in function value. results = np.array([None for _ in range(len(ordered_patches_no_overlap))]) for patch_id, patch_slice in enumerate(ordered_patches_no_overlap): - # Pad x_perturbed. The mode should depend on the used perturb_func. x_perturbed_pad = utils._pad_array( x_perturbed, pad_width, mode="edge", padded_axes=self.a_axes @@ -385,3 +379,17 @@ def get_auc_score(self): for i, curve in enumerate(self.evaluation_scores) ] ) + + def evaluate_batch( + self, + *, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: np.ndarray, + **_, + ) -> List[List[float]]: + return [ + self.evaluate_instance(model, x, y, a) + for x, y, a in zip(x_batch, y_batch, a_batch) + ] diff --git a/quantus/metrics/faithfulness/sensitivity_n.py b/quantus/metrics/faithfulness/sensitivity_n.py index 614fe650d..df2685405 100644 --- a/quantus/metrics/faithfulness/sensitivity_n.py +++ b/quantus/metrics/faithfulness/sensitivity_n.py @@ -6,7 +6,7 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional import numpy as np @@ -273,74 +273,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - s: np.ndarray, - ) -> Dict[str, List[float]]: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. - - Returns - ------- - (Dict[str, List[float]]): The evaluation results. - """ - - # Reshape the attributions. - a = a.flatten() - - # Get indices of sorted attributions (descending). - a_indices = np.argsort(-a) - - # Predict on x. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) - - att_sums = [] - pred_deltas = [] - x_perturbed = x.copy() - - for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): - - # Perturb input by indices of attributions. - a_ix = a_indices[ - (self.features_in_step * i_ix) : (self.features_in_step * (i_ix + 1)) - ] - x_perturbed = self.perturb_func( - arr=x_perturbed, - indices=a_ix, - indexed_axes=self.a_axes, - **self.perturb_func_kwargs, - ) - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Sum attributions. - att_sums.append(float(a[a_ix].sum())) - - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - pred_deltas.append(y_pred - y_pred_perturb) - - # Each list-element of self.evaluation_scores will be such a dictionary - # We will unpack that later in custom_postprocess(). - return {"att_sums": att_sums, "pred_deltas": pred_deltas} - def custom_preprocess( self, model: ModelInterface, @@ -436,3 +368,56 @@ def custom_postprocess( ) for k in range(max_features) ] + + def evaluate_batch( + self, + *, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: np.ndarray, + **_, + ) -> List[Dict[str, List[float]]]: + retval = [] + for x, y, a in zip(x_batch, y_batch, a_batch): + # Reshape the attributions. + a = a.flatten() + + # Get indices of sorted attributions (descending). + a_indices = np.argsort(-a) + + # Predict on x. + x_input = model.shape_input(x, x.shape, channel_first=True) + y_pred = float(model.predict(x_input)[:, y]) + + att_sums = [] + pred_deltas = [] + x_perturbed = x.copy() + + for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): + # Perturb input by indices of attributions. + a_ix = a_indices[ + (self.features_in_step * i_ix) : ( + self.features_in_step * (i_ix + 1) + ) + ] + x_perturbed = self.perturb_func( + arr=x_perturbed, + indices=a_ix, + indexed_axes=self.a_axes, + **self.perturb_func_kwargs, + ) + warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) + + # Sum attributions. + att_sums.append(float(a[a_ix].sum())) + + x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) + y_pred_perturb = float(model.predict(x_input)[:, y]) + pred_deltas.append(y_pred - y_pred_perturb) + + # Each list-element of self.evaluation_scores will be such a dictionary + # We will unpack that later in custom_postprocess(). + retval.append({"att_sums": att_sums, "pred_deltas": pred_deltas}) + + return retval diff --git a/quantus/metrics/faithfulness/sufficiency.py b/quantus/metrics/faithfulness/sufficiency.py index 22ec96245..0820ae974 100644 --- a/quantus/metrics/faithfulness/sufficiency.py +++ b/quantus/metrics/faithfulness/sufficiency.py @@ -6,14 +6,12 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, no_type_check import numpy as np from scipy.spatial.distance import cdist -from quantus.helpers import asserts from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.metrics.base import Metric from quantus.helpers.enums import ( @@ -246,88 +244,11 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - s: np.ndarray, - i: int = None, - a_sim_vector: np.ndarray = None, - y_pred_classes: np.ndarray = None, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. - i: int - The index of the current instance. - a_sim_vector: any - The custom input to be evaluated on an instance-basis. - y_pred_classes: np,ndarray - The class predictions of the complete input dataset. - - Returns - ------- - float - The evaluation results. - """ - - # Metric logic. - pred_a = y_pred_classes[i] - low_dist_a = np.argwhere(a_sim_vector == 1.0).flatten() - low_dist_a = low_dist_a[low_dist_a != i] - pred_low_dist_a = y_pred_classes[low_dist_a] - - if len(low_dist_a) == 0: - return 0 - return np.sum(pred_low_dist_a == pred_a) / len(low_dist_a) - - def custom_preprocess( - self, - model: ModelInterface, - x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: np.ndarray, - custom_batch: Optional[np.ndarray], - ) -> Dict[str, Any]: - """ - Implementation of custom_preprocess_batch. - - Parameters - ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. - x_batch: np.ndarray - A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - custom_batch: any - Gives flexibility ot the user to use for evaluation, can hold any variable. - - Returns - ------- - dictionary[str, np.ndarray] - Output dictionary with 'a_sim_vector_batch' as and attributtion similarity matrix as value. - """ - + def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: + data_batch = super().batch_preprocess(data_batch) + model = data_batch["model"] + x_batch = data_batch["x_batch"] + a_batch = data_batch["a_batch"] a_batch_flat = a_batch.reshape(a_batch.shape[0], -1) dist_matrix = cdist(a_batch_flat, a_batch_flat, self.distance_func, V=None) dist_matrix = self.normalise_func(dist_matrix) @@ -340,8 +261,30 @@ def custom_preprocess( ) y_pred_classes = np.argmax(model.predict(x_input), axis=1).flatten() - return { + custom_batch = { "i_batch": np.arange(x_batch.shape[0]), "a_sim_vector_batch": a_sim_matrix, "y_pred_classes": y_pred_classes, } + + data_batch.update(custom_batch) + return data_batch + + @no_type_check + def evaluate_batch( + self, *, i_batch, a_sim_vector_batch, y_pred_classes, **_ + ) -> List[float]: + retval = [] + for i, a_sim_vector in zip(i_batch, a_sim_vector_batch): + # Metric logic. + pred_a = y_pred_classes[i] + low_dist_a = np.argwhere(a_sim_vector == 1.0).flatten() + low_dist_a = low_dist_a[low_dist_a != i] + pred_low_dist_a = y_pred_classes[low_dist_a] + + if len(low_dist_a) == 0: + retval.append(0.0) + else: + retval.append((pred_low_dist_a == pred_a) / len(low_dist_a)) + + return retval diff --git a/quantus/metrics/localisation/attribution_localisation.py b/quantus/metrics/localisation/attribution_localisation.py index a47673661..c1d0c8201 100644 --- a/quantus/metrics/localisation/attribution_localisation.py +++ b/quantus/metrics/localisation/attribution_localisation.py @@ -6,7 +6,7 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional import numpy as np from quantus.helpers import asserts @@ -48,7 +48,6 @@ class AttributionLocalisation(Metric): score_direction = ScoreDirection.HIGHER evaluation_category = EvaluationCategory.LOCALISATION - def __init__( self, weighted: bool = False, @@ -241,64 +240,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - s: np.ndarray, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. - - Returns - ------- - float - The evaluation results. - """ - - if np.sum(s) == 0: - warn.warn_empty_segmentation() - return np.nan - - # Prepare shapes. - a = a.flatten() - s = s.flatten().astype(bool) - - # Compute ratio. - size_bbox = float(np.sum(s)) - size_data = np.prod(x.shape[1:]) - ratio = size_bbox / size_data - - # Compute inside/outside ratio. - inside_attribution = np.sum(a[s]) - total_attribution = np.sum(a) - inside_attribution_ratio = float(inside_attribution / total_attribution) - - if not ratio <= self.max_size: - warn.warn_max_size() - if inside_attribution_ratio > 1.0: - warn.warn_segmentation(inside_attribution, total_attribution) - return np.nan - if not self.weighted: - return inside_attribution_ratio - else: - return float(inside_attribution_ratio * ratio) - def custom_preprocess( self, model: ModelInterface, @@ -332,3 +273,39 @@ def custom_preprocess( """ # Asserts. asserts.assert_segmentations(x_batch=x_batch, s_batch=s_batch) + + def evaluate_batch( + self, *, x_batch: np.ndarray, a_batch: np.ndarray, s_batch: np.ndarray, **_ + ) -> List[float]: + retval = [] + for x, a, s in zip(x_batch, a_batch, s_batch): + if np.sum(s) == 0: + warn.warn_empty_segmentation() + retval.append(np.nan) + continue + + # Prepare shapes. + a = a.flatten() + s = s.flatten().astype(bool) + + # Compute ratio. + size_bbox = float(np.sum(s)) + size_data = np.prod(x.shape[1:]) + ratio = size_bbox / size_data + + # Compute inside/outside ratio. + inside_attribution = np.sum(a[s]) + total_attribution = np.sum(a) + inside_attribution_ratio = float(inside_attribution / total_attribution) + + if not ratio <= self.max_size: + warn.warn_max_size() + if inside_attribution_ratio > 1.0: + warn.warn_segmentation(inside_attribution, total_attribution) + retval.append(np.nan) + elif not self.weighted: + retval.append(inside_attribution_ratio) + else: + retval.append(float(inside_attribution_ratio * ratio)) + + return retval diff --git a/quantus/metrics/localisation/auc.py b/quantus/metrics/localisation/auc.py index 50e1b1b2d..f32b882e3 100644 --- a/quantus/metrics/localisation/auc.py +++ b/quantus/metrics/localisation/auc.py @@ -6,7 +6,7 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional import numpy as np from sklearn.metrics import roc_curve, auc @@ -218,49 +218,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - s: np.ndarray, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: (ModelInteface) - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. - - Returns - ------- - : float - The evaluation results. - """ - # Return np.nan as result if segmentation map is empty. - if np.sum(s) == 0: - warn.warn_empty_segmentation() - return np.nan - - # Prepare shapes. - a = a.flatten() - s = s.flatten().astype(bool) - - fpr, tpr, _ = roc_curve(y_true=s, y_score=a) - score = auc(x=fpr, y=tpr) - - return score - def custom_preprocess( self, model: ModelInterface, @@ -294,3 +251,24 @@ def custom_preprocess( """ # Asserts. asserts.assert_segmentations(x_batch=x_batch, s_batch=s_batch) + + def evaluate_batch( + self, *, a_batch: np.ndarray, s_batch: np.ndarray, **_ + ) -> List[float]: + retval = [] + # TODO: vectorize + for a, s in zip(a_batch, s_batch): + if np.sum(s) == 0: + warn.warn_empty_segmentation() + retval.append(np.nan) + continue + + # Prepare shapes. + a = a.flatten() + s = s.flatten().astype(bool) + + fpr, tpr, _ = roc_curve(y_true=s, y_score=a) + score = auc(x=fpr, y=tpr) + retval.append(score) + + return retval diff --git a/quantus/metrics/localisation/focus.py b/quantus/metrics/localisation/focus.py index 795199517..5f71f5ef0 100644 --- a/quantus/metrics/localisation/focus.py +++ b/quantus/metrics/localisation/focus.py @@ -6,10 +6,9 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, no_type_check import numpy as np -from quantus.helpers import asserts from quantus.helpers import plotting from quantus.helpers import warn from quantus.helpers.model.model_interface import ModelInterface @@ -263,64 +262,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - s: np.ndarray, - c: np.ndarray = None, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. - c: any - The custom input to be evaluated on an instance-basis. - - Returns - ------- - float - The evaluation results. - """ - - # Prepare shapes for mosaics. - self.mosaic_shape = a.shape - - total_positive_relevance = np.sum(a[a > 0], dtype=np.float64) - target_positive_relevance = 0 - - quadrant_functions_list = [ - self.quadrant_top_left, - self.quadrant_top_right, - self.quadrant_bottom_left, - self.quadrant_bottom_right, - ] - - for quadrant_p, quadrant_func in zip(c, quadrant_functions_list): - if not bool(quadrant_p): - continue - quadrant_relevance = quadrant_func(a=a) - target_positive_relevance += np.sum( - quadrant_relevance[quadrant_relevance > 0] - ) - - focus_score = target_positive_relevance / total_positive_relevance - - return focus_score - def custom_preprocess( self, model: ModelInterface, @@ -394,3 +335,40 @@ def quadrant_bottom_right(self, a: np.ndarray) -> np.ndarray: :, int(self.mosaic_shape[1] / 2) :, int(self.mosaic_shape[2] / 2) : ] return quandrant_a + + @no_type_check + def evaluate_batch( + self, + *, + a_batch: np.ndarray, + c_batch: np.ndarray, + **_, + ): + retval = [] + for a, c in zip(a_batch, c_batch): + # Prepare shapes for mosaics. + self.mosaic_shape = a.shape + + total_positive_relevance = np.sum(a[a > 0], dtype=np.float64) + target_positive_relevance = 0 + + quadrant_functions_list = [ + self.quadrant_top_left, + self.quadrant_top_right, + self.quadrant_bottom_left, + self.quadrant_bottom_right, + ] + + for quadrant_p, quadrant_func in zip(c, quadrant_functions_list): + if not bool(quadrant_p): + continue + quadrant_relevance = quadrant_func(a=a) + target_positive_relevance += np.sum( + quadrant_relevance[quadrant_relevance > 0] + ) + + focus_score = target_positive_relevance / total_positive_relevance + + retval.append(focus_score) + + return retval diff --git a/quantus/metrics/localisation/pointing_game.py b/quantus/metrics/localisation/pointing_game.py index ff6ec0fae..1deb3f696 100644 --- a/quantus/metrics/localisation/pointing_game.py +++ b/quantus/metrics/localisation/pointing_game.py @@ -6,7 +6,7 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional import numpy as np from quantus.helpers import asserts @@ -230,56 +230,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - s: np.ndarray, - ) -> bool: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. - - Returns - ------- - boolean - The evaluation results. - """ - - # Return np.nan as result if segmentation map is empty. - if np.sum(s) == 0: - warn.warn_empty_segmentation() - return np.nan - - # Prepare shapes. - a = a.flatten() - s = s.flatten().astype(bool) - - # Find indices with max value. - max_index = np.argwhere(a == np.max(a)) - - # Check if maximum of explanation is on target object class. - hit = np.any(s[max_index]) - - if self.weighted and hit: - hit = 1 - (np.sum(s) / float(np.prod(s.shape))) - - return hit - def custom_preprocess( self, model: ModelInterface, @@ -314,3 +264,32 @@ def custom_preprocess( # Asserts. asserts.assert_segmentations(x_batch=x_batch, s_batch=s_batch) + + def evaluate_batch( + self, *, a_batch: np.ndarray, s_batch: np.ndarray, **_ + ) -> List[float]: + + + retval = [] + for a, s in zip(a_batch, s_batch): + if np.sum(s) == 0: + warn.warn_empty_segmentation() + retval.append(np.nan) + continue + + # Prepare shapes. + a = a.flatten() + s = s.flatten().astype(bool) + + # Find indices with max value. + max_index = np.argwhere(a == np.max(a)) + + # Check if maximum of explanation is on target object class. + hit = np.any(s[max_index]) + + if self.weighted and hit: + hit = 1 - (np.sum(s) / float(np.prod(s.shape))) + + retval.append(hit) + + return retval \ No newline at end of file diff --git a/quantus/metrics/localisation/relevance_mass_accuracy.py b/quantus/metrics/localisation/relevance_mass_accuracy.py index 7219e98d3..88cec2914 100644 --- a/quantus/metrics/localisation/relevance_mass_accuracy.py +++ b/quantus/metrics/localisation/relevance_mass_accuracy.py @@ -6,7 +6,7 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional import numpy as np from quantus.helpers import asserts @@ -224,53 +224,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - s: np.ndarray, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. - - Returns - ------- - float - The evaluation results. - """ - # Return np.nan as result if segmentation map is empty. - if np.sum(s) == 0: - warn.warn_empty_segmentation() - return np.nan - - # Prepare shapes. - a = a.flatten() - s = s.flatten().astype(bool) - - # Compute inside/outside ratio. - r_within = np.sum(a[s]) - r_total = np.sum(a) - - # Calculate mass accuracy. - mass_accuracy = r_within / r_total - - return mass_accuracy - def custom_preprocess( self, model: ModelInterface, @@ -304,3 +257,29 @@ def custom_preprocess( """ # Asserts. asserts.assert_segmentations(x_batch=x_batch, s_batch=s_batch) + + def evaluate_batch( + self, *, a_batch: np.ndarray, s_batch: np.ndarray, **_ + ) -> List[float]: + retval = [] + for a, s in zip(a_batch, s_batch): + # Return np.nan as result if segmentation map is empty. + if np.sum(s) == 0: + warn.warn_empty_segmentation() + retval.append(np.nan) + continue + + # Prepare shapes. + a = a.flatten() + s = s.flatten().astype(bool) + + # Compute inside/outside ratio. + r_within = np.sum(a[s]) + r_total = np.sum(a) + + # Calculate mass accuracy. + mass_accuracy = r_within / r_total + + retval.append(mass_accuracy) + + return retval diff --git a/quantus/metrics/localisation/relevance_rank_accuracy.py b/quantus/metrics/localisation/relevance_rank_accuracy.py index a17f4508a..9d87b604f 100644 --- a/quantus/metrics/localisation/relevance_rank_accuracy.py +++ b/quantus/metrics/localisation/relevance_rank_accuracy.py @@ -6,7 +6,7 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional import numpy as np from quantus.helpers import asserts @@ -226,60 +226,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - s: np.ndarray, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. - - Returns - ------- - float - The evaluation results. - """ - # Return np.nan as result if segmentation map is empty. - if np.sum(s) == 0: - warn.warn_empty_segmentation() - return np.nan - - # Prepare shapes. - a = a.flatten() - s = np.where(s.flatten().astype(bool))[0] - - # Size of the ground truth mask. - k = len(s) - - # Sort in descending order. - a_sorted = np.argsort(a)[-int(k) :] - - # Calculate hits. - hits = len(np.intersect1d(s, a_sorted)) - - if hits != 0: - rank_accuracy = hits / float(k) - else: - rank_accuracy = 0.0 - - return rank_accuracy - def custom_preprocess( self, model: ModelInterface, @@ -313,3 +259,37 @@ def custom_preprocess( """ # Asserts. asserts.assert_segmentations(x_batch=x_batch, s_batch=s_batch) + + def evaluate_batch( + self, *, a_batch: np.ndarray, s_batch: np.ndarray, **_ + ) -> List[float]: + retval = [] + + for a, s in zip(a_batch, s_batch): + # Return np.nan as result if segmentation map is empty. + if np.sum(s) == 0: + warn.warn_empty_segmentation() + retval.append(np.nan) + continue + + # Prepare shapes. + a = a.flatten() + s = np.where(s.flatten().astype(bool))[0] + + # Size of the ground truth mask. + k = len(s) + + # Sort in descending order. + a_sorted = np.argsort(a)[-int(k) :] + + # Calculate hits. + hits = len(np.intersect1d(s, a_sorted)) + + if hits != 0: + rank_accuracy = hits / float(k) + else: + rank_accuracy = 0.0 + + retval.append(rank_accuracy) + + return retval diff --git a/quantus/metrics/localisation/top_k_intersection.py b/quantus/metrics/localisation/top_k_intersection.py index 2b5a16b7b..91340f55b 100644 --- a/quantus/metrics/localisation/top_k_intersection.py +++ b/quantus/metrics/localisation/top_k_intersection.py @@ -6,7 +6,7 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional import numpy as np from quantus.helpers import asserts @@ -235,58 +235,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - s: np.ndarray, - ): - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. - - Returns - ------- - float - The evaluation results. - """ - - if np.sum(s) == 0: - warn.warn_empty_segmentation() - return np.nan - - # Prepare shapes. - s = s.astype(bool) - top_k_binary_mask = np.zeros(a.shape) - - # Sort and create masks. - sorted_indices = np.argsort(a, axis=None) - np.put_along_axis(top_k_binary_mask, sorted_indices[-self.k :], 1, axis=None) - top_k_binary_mask = top_k_binary_mask.astype(bool) - - # Top-k intersection. - tki = 1.0 / self.k * np.sum(np.logical_and(s, top_k_binary_mask)) - - # Concept influence (with size of object normalised tki score). - if self.concept_influence: - tki = np.prod(s.shape) / np.sum(s) * tki - - return tki - def custom_preprocess( self, model: ModelInterface, @@ -323,3 +271,35 @@ def custom_preprocess( asserts.assert_value_smaller_than_input_size( x=x_batch, value=self.k, value_name="k" ) + + def evaluate_batch( + self, *, a_batch: np.ndarray, s_batch: np.ndarray, **_ + ) -> List[float]: + retval = [] + for a, s in zip(a_batch, s_batch): + if np.sum(s) == 0: + warn.warn_empty_segmentation() + retval.append(np.nan) + continue + + # Prepare shapes. + s = s.astype(bool) + top_k_binary_mask = np.zeros(a.shape) + + # Sort and create masks. + sorted_indices = np.argsort(a, axis=None) + np.put_along_axis( + top_k_binary_mask, sorted_indices[-self.k :], 1, axis=None + ) + top_k_binary_mask = top_k_binary_mask.astype(bool) + + # Top-k intersection. + tki = 1.0 / self.k * np.sum(np.logical_and(s, top_k_binary_mask)) + + # Concept influence (with size of object normalised tki score). + if self.concept_influence: + tki = np.prod(s.shape) / np.sum(s) * tki + + retval.append(tki) + + return retval diff --git a/quantus/metrics/randomisation/model_parameter_randomisation.py b/quantus/metrics/randomisation/model_parameter_randomisation.py index 2ca48aad6..4991bf821 100644 --- a/quantus/metrics/randomisation/model_parameter_randomisation.py +++ b/quantus/metrics/randomisation/model_parameter_randomisation.py @@ -12,26 +12,25 @@ Dict, List, Optional, - Tuple, Union, Collection, - Iterable, ) + import numpy as np from tqdm.auto import tqdm -from quantus.helpers import asserts -from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.similarity_func import correlation_spearman -from quantus.metrics.base import Metric +from quantus.helpers import asserts +from quantus.helpers import warn from quantus.helpers.enums import ( ModelType, DataType, ScoreDirection, EvaluationCategory, ) +from quantus.helpers.model.model_interface import ModelInterface +from quantus.metrics.base import Metric class ModelParameterRandomisation(Metric): @@ -276,6 +275,9 @@ def __call__( y_batch = data["y_batch"] a_batch = data["a_batch"] + if a_batch is None: + a_batch = self.explain_batch(model, x_batch, y_batch) + # Results are returned/saved as a dictionary not as a list as in the super-class. self.evaluation_scores = {} @@ -289,26 +291,19 @@ def __call__( ) for layer_name, random_layer_model in model_iterator: - similarity_scores = [None for _ in x_batch] # Generate an explanation with perturbed model. - a_batch_perturbed = self.explain_func( - model=random_layer_model, - inputs=x_batch, - targets=y_batch, - **self.explain_func_kwargs, + a_batch_perturbed = self.explain_batch( + random_layer_model, + x_batch, + y_batch, ) batch_iterator = enumerate(zip(a_batch, a_batch_perturbed)) for instance_id, (a_instance, a_instance_perturbed) in batch_iterator: - result = self.evaluate_instance( - model=random_layer_model, - x=None, - y=None, - s=None, - a=a_instance, - a_perturbed=a_instance_perturbed, + result = self.similarity_func( + a_instance_perturbed.flatten(), a_instance.flatten() ) similarity_scores[instance_id] = result @@ -338,49 +333,6 @@ def __call__( return self.evaluation_scores - def evaluate_instance( - self, - model: ModelInterface, - x: Optional[np.ndarray], - y: Optional[np.ndarray], - a: Optional[np.ndarray], - s: Optional[np.ndarray], - a_perturbed: Optional[np.ndarray] = None, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - i: integer - The evaluation instance. - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. - a_perturbed: np.ndarray - The perturbed attributions. - - Returns - ------- - float - The evaluation results. - """ - if self.normalise: - a_perturbed = self.normalise_func(a_perturbed, **self.normalise_func_kwargs) - - if self.abs: - a_perturbed = np.abs(a_perturbed) - - # Compute distance measure. - return self.similarity_func(a_perturbed.flatten(), a.flatten()) - def custom_preprocess( self, model: ModelInterface, @@ -419,7 +371,6 @@ def custom_preprocess( def compute_correlation_per_sample( self, ) -> Union[List[List[Any]], Dict[int, List[Any]]]: - assert isinstance(self.evaluation_scores, dict), ( "To compute the average correlation coefficient per sample for " "Model Parameter Randomisation Test, 'last_result' " @@ -438,3 +389,6 @@ def compute_correlation_per_sample( corr_coeffs = list(results.values()) return corr_coeffs + + def evaluate_batch(self, *args, **kwargs): + raise RuntimeError("This is unexpected.") diff --git a/quantus/metrics/randomisation/random_logit.py b/quantus/metrics/randomisation/random_logit.py index d7183cd3d..adf02ec36 100644 --- a/quantus/metrics/randomisation/random_logit.py +++ b/quantus/metrics/randomisation/random_logit.py @@ -6,7 +6,7 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional import numpy as np from quantus.helpers import asserts @@ -48,7 +48,6 @@ class RandomLogit(Metric): score_direction = ScoreDirection.LOWER evaluation_category = EvaluationCategory.RANDOMISATION - def __init__( self, similarity_func: Callable = None, @@ -243,7 +242,6 @@ def evaluate_instance( x: np.ndarray, y: np.ndarray, a: np.ndarray, - s: np.ndarray, ) -> float: """ Evaluate instance gets model and data for a single instance as input and returns the evaluation result. @@ -258,40 +256,12 @@ def evaluate_instance( The output to be evaluated on an instance-basis. a: np.ndarray The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. Returns ------- float The evaluation results. """ - # Randomly select off-class labels. - np.random.seed(self.seed) - y_off = np.array( - [ - np.random.choice( - [y_ for y_ in list(np.arange(0, self.num_classes)) if y_ != y] - ) - ] - ) - - # Explain against a random class. - a_perturbed = self.explain_func( - model=model.get_model(), - inputs=np.expand_dims(x, axis=0), - targets=y_off, - **self.explain_func_kwargs, - ) - - # Normalise and take absolute values of the attributions, if True. - if self.normalise: - a_perturbed = self.normalise_func(a_perturbed, **self.normalise_func_kwargs) - - if self.abs: - a_perturbed = np.abs(a_perturbed) - - return self.similarity_func(a.flatten(), a_perturbed.flatten()) def custom_preprocess( self, @@ -327,3 +297,34 @@ def custom_preprocess( # Additional explain_func assert, as the one in general_preprocess() # won't be executed when a_batch != None. asserts.assert_explain_func(explain_func=self.explain_func) + + def evaluate_batch( + self, + *, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: np.ndarray, + **_, + ) -> List[float]: + # TODO: vectorize + retval = [] + for x, y, a in zip(x_batch, y_batch, a_batch): + # Randomly select off-class labels. + np.random.seed(self.seed) + y_off = np.array( + [ + np.random.choice( + [y_ for y_ in list(np.arange(0, self.num_classes)) if y_ != y] + ) + ] + ) + # Explain against a random class. + a_perturbed = self.explain_batch( + model.get_model(), + np.expand_dims(x, axis=0), + y_off, + ) + retval.append(self.similarity_func(a.flatten(), a_perturbed.flatten())) + + return retval diff --git a/quantus/metrics/robustness/avg_sensitivity.py b/quantus/metrics/robustness/avg_sensitivity.py index 5acc258e0..bae6c813b 100644 --- a/quantus/metrics/robustness/avg_sensitivity.py +++ b/quantus/metrics/robustness/avg_sensitivity.py @@ -16,7 +16,7 @@ from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import uniform_noise, perturb_batch from quantus.functions.similarity_func import difference -from quantus.metrics.base_batched import BatchedPerturbationMetric +from quantus.metrics.base_perturbed import PerturbationMetric from quantus.helpers.enums import ( ModelType, DataType, @@ -25,7 +25,7 @@ ) -class AvgSensitivity(BatchedPerturbationMetric): +class AvgSensitivity(PerturbationMetric): """ Implementation of Avg-Sensitivity by Yeh at el., 2019. @@ -52,7 +52,6 @@ class AvgSensitivity(BatchedPerturbationMetric): score_direction = ScoreDirection.LOWER evaluation_category = EvaluationCategory.ROBUSTNESS - def __init__( self, similarity_func: Optional[Callable] = None, @@ -293,7 +292,7 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - s_batch: np.ndarray, + **kwargs, ) -> np.ndarray: """ Evaluates model and attributes on a single data batch and returns the batched evaluation result. @@ -320,7 +319,6 @@ def evaluate_batch( similarities = np.zeros((batch_size, self.nr_samples)) * np.nan for step_id in range(self.nr_samples): - # Perturb input. x_perturbed = perturb_batch( perturb_func=self.perturb_func, @@ -371,7 +369,6 @@ def evaluate_batch( # Measure similarity for each instance separately. for instance_id in range(batch_size): - if ( self.return_nan_when_prediction_changes and instance_id in changed_prediction_indices diff --git a/quantus/metrics/robustness/consistency.py b/quantus/metrics/robustness/consistency.py index da5b4c845..f834a672a 100644 --- a/quantus/metrics/robustness/consistency.py +++ b/quantus/metrics/robustness/consistency.py @@ -6,13 +6,11 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, no_type_check import numpy as np -from quantus.helpers import asserts from quantus.helpers import warn from quantus.functions.discretise_func import top_n_sign -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.metrics.base import Metric from quantus.helpers.enums import ( @@ -234,97 +232,52 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - s: np.ndarray, - i: int = None, - a_label: np.ndarray = None, - y_pred_classes: np.ndarray = None, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. - i: int - The index of the current instance. - a_label: np.ndarray - The discretised attribution labels. - y_pred_classes: np,ndarray - The class predictions of the complete input dataset. - - Returns - ------- - float - The evaluation results. - """ - # Metric logic. - pred_a = y_pred_classes[i] - same_a = np.argwhere(a == a_label).flatten() - diff_a = same_a[same_a != i] - pred_same_a = y_pred_classes[diff_a] - - if len(same_a) == 0: - return 0 - return np.sum(pred_same_a == pred_a) / len(diff_a) + def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: + data_batch = super().batch_preprocess(data_batch) - def custom_preprocess( - self, - model: ModelInterface, - x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: np.ndarray, - custom_batch: Optional[np.ndarray], - ) -> Dict[str, Any]: - """ - Implementation of custom_preprocess_batch. - - Parameters - ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. - x_batch: np.ndarray - A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - custom_batch: any - Gives flexibility ot the user to use for evaluation, can hold any variable. + model = data_batch["model"] + x_batch = data_batch["x_batch"] + x_input = model.shape_input( + x_batch, x_batch[0].shape, channel_first=True, batched=True + ) - Returns - ------- - dictionary[str, np.ndarray] - Output dictionary with 'a_label_batch' as key and discretised attributtion labels as value. - """ - # Preprocessing. + a_batch = data_batch["a_batch"] a_batch_flat = a_batch.reshape(a_batch.shape[0], -1) + a_labels = np.array(list(map(self.discretise_func, a_batch_flat))) - x_input = model.shape_input( - x_batch, x_batch[0].shape, channel_first=True, batched=True - ) y_pred_classes = np.argmax(model.predict(x_input), axis=1).flatten() - return { + custom_batch = { "i_batch": np.arange(x_batch.shape[0]), "a_label_batch": a_labels, "y_pred_classes": y_pred_classes, } + + data_batch.update(custom_batch) + return data_batch + + @no_type_check + def evaluate_batch( + self, + *, + a_batch: np.ndarray, + i_batch: np.ndarray, + a_label_batch: np.ndarray, + y_pred_classes, + **_, + ) -> List[float]: + # TODO: vectorize + retval = [] + for a, i, a_label in zip(a_batch, i_batch, a_label_batch): + pred_a = y_pred_classes[i] + same_a = np.argwhere(a == a_label).flatten() + diff_a = same_a[same_a != i] + pred_same_a = y_pred_classes[diff_a] + + if len(same_a) == 0: + retval.append(0.0) + else: + retval.append(np.sum(pred_same_a == pred_a) / len(diff_a)) + + return retval diff --git a/quantus/metrics/robustness/continuity.py b/quantus/metrics/robustness/continuity.py index eb73af6c4..a531baa89 100644 --- a/quantus/metrics/robustness/continuity.py +++ b/quantus/metrics/robustness/continuity.py @@ -5,7 +5,7 @@ # Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . - +from __future__ import annotations import itertools from typing import Any, Callable, Dict, List, Optional import numpy as np @@ -283,13 +283,8 @@ def __call__( ) def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - s: np.ndarray, - ) -> Dict: + self, model: ModelInterface, x: np.ndarray, y: np.ndarray + ) -> Dict[str, int | float]: """ Evaluate instance gets model and data for a single instance as input and returns the evaluation result. @@ -301,10 +296,6 @@ def evaluate_instance( The input to be evaluated on an instance-basis. y: np.ndarray The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. Returns ------- @@ -463,3 +454,13 @@ def aggregated_score(self): for sample in self.evaluation_scores.keys() ] ) + + def evaluate_batch( + self, + *, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, + **_, + ) -> List[Dict[str, int]]: + return [self.evaluate_instance(model, x, y) for x, y in zip(x_batch, y_batch)] diff --git a/quantus/metrics/robustness/local_lipschitz_estimate.py b/quantus/metrics/robustness/local_lipschitz_estimate.py index ad4e602ca..2d0dbc9c9 100644 --- a/quantus/metrics/robustness/local_lipschitz_estimate.py +++ b/quantus/metrics/robustness/local_lipschitz_estimate.py @@ -6,7 +6,7 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional import numpy as np from quantus.helpers import asserts @@ -15,7 +15,7 @@ from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import gaussian_noise, perturb_batch from quantus.functions.similarity_func import lipschitz_constant, distance_euclidean -from quantus.metrics.base_batched import BatchedPerturbationMetric +from quantus.metrics.base_perturbed import PerturbationMetric from quantus.helpers.enums import ( ModelType, DataType, @@ -24,7 +24,7 @@ ) -class LocalLipschitzEstimate(BatchedPerturbationMetric): +class LocalLipschitzEstimate(PerturbationMetric): """ Implementation of the Local Lipschitz Estimate (or Stability) test by Alvarez-Melis et al., 2018a, 2018b. @@ -299,7 +299,7 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - s_batch: np.ndarray, + **_, ) -> np.ndarray: """ Evaluates model and attributes on a single data batch and returns the batched evaluation result. @@ -314,8 +314,6 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - s_batch: np.ndarray - The segmentation to be evaluated on a batch-basis. Returns ------- @@ -327,7 +325,6 @@ def evaluate_batch( similarities = np.zeros((batch_size, self.nr_samples)) * np.nan for step_id in range(self.nr_samples): - # Perturb input. x_perturbed = perturb_batch( perturb_func=self.perturb_func, diff --git a/quantus/metrics/robustness/max_sensitivity.py b/quantus/metrics/robustness/max_sensitivity.py index 5403de799..1ff4ffba0 100644 --- a/quantus/metrics/robustness/max_sensitivity.py +++ b/quantus/metrics/robustness/max_sensitivity.py @@ -16,7 +16,7 @@ from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import uniform_noise, perturb_batch from quantus.functions.similarity_func import difference -from quantus.metrics.base_batched import BatchedPerturbationMetric +from quantus.metrics.base_perturbed import PerturbationMetric from quantus.helpers.enums import ( ModelType, DataType, @@ -25,7 +25,7 @@ ) -class MaxSensitivity(BatchedPerturbationMetric): +class MaxSensitivity(PerturbationMetric): """ Implementation of Max-Sensitivity by Yeh at el., 2019. @@ -292,7 +292,7 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - s_batch: np.ndarray, + **_, ) -> np.ndarray: """ Evaluates model and attributes on a single data batch and returns the batched evaluation result. @@ -307,8 +307,6 @@ def evaluate_batch( The output to be evaluated on an instance-basis. a_batch: np.ndarray The explanation to be evaluated on an instance-basis. - s_batch: np.ndarray - The segmentation to be evaluated on an instance-basis. Returns ------- @@ -319,7 +317,6 @@ def evaluate_batch( similarities = np.zeros((batch_size, self.nr_samples)) * np.nan for step_id in range(self.nr_samples): - # Perturb input. x_perturbed = perturb_batch( perturb_func=self.perturb_func, diff --git a/quantus/metrics/robustness/relative_input_stability.py b/quantus/metrics/robustness/relative_input_stability.py index 7e16c3042..a43a5bb85 100644 --- a/quantus/metrics/robustness/relative_input_stability.py +++ b/quantus/metrics/robustness/relative_input_stability.py @@ -16,7 +16,7 @@ from quantus.helpers.model.model_interface import ModelInterface -from quantus.metrics.base_batched import BatchedPerturbationMetric +from quantus.metrics.base_perturbed import PerturbationMetric from quantus.helpers.warn import warn_parameterisation from quantus.functions.normalise_func import normalise_by_average_second_moment_estimate from quantus.functions.perturb_func import uniform_noise, perturb_batch @@ -29,7 +29,7 @@ ) -class RelativeInputStability(BatchedPerturbationMetric): +class RelativeInputStability(PerturbationMetric): """ Relative Input Stability leverages the stability of an explanation with respect to the change in the input data. diff --git a/quantus/metrics/robustness/relative_output_stability.py b/quantus/metrics/robustness/relative_output_stability.py index 6ed012de0..3f387e165 100644 --- a/quantus/metrics/robustness/relative_output_stability.py +++ b/quantus/metrics/robustness/relative_output_stability.py @@ -15,7 +15,7 @@ import torch from quantus.helpers.model.model_interface import ModelInterface -from quantus.metrics.base_batched import BatchedPerturbationMetric +from quantus.metrics.base_perturbed import PerturbationMetric from quantus.helpers.warn import warn_parameterisation from quantus.functions.normalise_func import normalise_by_average_second_moment_estimate from quantus.functions.perturb_func import uniform_noise, perturb_batch @@ -28,7 +28,7 @@ ) -class RelativeOutputStability(BatchedPerturbationMetric): +class RelativeOutputStability(PerturbationMetric): """ Relative Output Stability leverages the stability of an explanation with respect to the change in the output logits. @@ -337,7 +337,6 @@ def evaluate_batch( ros_batch = np.zeros(shape=[self._nr_samples, x_batch.shape[0]]) for index in range(self._nr_samples): - # Perturb input. x_perturbed = perturb_batch( perturb_func=self.perturb_func, diff --git a/quantus/metrics/robustness/relative_representation_stability.py b/quantus/metrics/robustness/relative_representation_stability.py index 9643b350c..204cabe0c 100644 --- a/quantus/metrics/robustness/relative_representation_stability.py +++ b/quantus/metrics/robustness/relative_representation_stability.py @@ -16,7 +16,7 @@ from quantus.helpers.model.model_interface import ModelInterface -from quantus.metrics.base_batched import BatchedPerturbationMetric +from quantus.metrics.base_perturbed import PerturbationMetric from quantus.helpers.warn import warn_parameterisation from quantus.functions.normalise_func import normalise_by_average_second_moment_estimate from quantus.functions.perturb_func import uniform_noise, perturb_batch @@ -29,7 +29,7 @@ ) -class RelativeRepresentationStability(BatchedPerturbationMetric): +class RelativeRepresentationStability(PerturbationMetric): """ Relative Representation Stability leverages the stability of an explanation with respect to the change in the output logits. diff --git a/tests/README.md b/tests/README.md index ed60d71f7..c48426a0e 100644 --- a/tests/README.md +++ b/tests/README.md @@ -47,4 +47,9 @@ Run type checking using [mypy](https://github.com/python/mypy) ```shell python3 -m tox run -e type +``` + +You can run all testing environments in parallel using multiprocessing by running: +```shell +python3 -m tox run-parallel ``` \ No newline at end of file From 8576b4d2437b2eeb5542e802b8fe4deff135fab2 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Sat, 12 Aug 2023 22:59:37 +0200 Subject: [PATCH 02/58] WIP --- quantus/helpers/constants.py | 8 +++++++- quantus/metrics/robustness/continuity.py | 5 +++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/quantus/helpers/constants.py b/quantus/helpers/constants.py index 4112d0a21..46fe0edf2 100644 --- a/quantus/helpers/constants.py +++ b/quantus/helpers/constants.py @@ -7,13 +7,19 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from typing import List, Dict, Final, Mapping, Type +import sys +from typing import List, Dict, Mapping, Type from quantus.functions.loss_func import * from quantus.functions.normalise_func import * from quantus.functions.perturb_func import * from quantus.functions.similarity_func import * from quantus.metrics import * +if sys.version_info >= (3, 8): + from typing import Final +else: + from typing_extensions import Final + AVAILABLE_METRICS: Final[Mapping[str, Mapping[str, Type[Metric]]]] = { "Faithfulness": { diff --git a/quantus/metrics/robustness/continuity.py b/quantus/metrics/robustness/continuity.py index a531baa89..343608cf1 100644 --- a/quantus/metrics/robustness/continuity.py +++ b/quantus/metrics/robustness/continuity.py @@ -9,6 +9,7 @@ import itertools from typing import Any, Callable, Dict, List, Optional import numpy as np +from collections import defaultdict from quantus.helpers import asserts from quantus.helpers import utils @@ -302,7 +303,7 @@ def evaluate_instance( dict The evaluation results. """ - results: Dict[int, list] = {k: [] for k in range(self.nr_patches + 1)} + results = defaultdict(lambda: []) for step in range(self.nr_steps): @@ -388,7 +389,7 @@ def evaluate_instance( patch_sum = float(sum(a_perturbed_patch)) results[ix_patch].append(patch_sum) - return results + return dict(**results) def custom_preprocess( self, From bafbd93305fb4d5002f1c5e38a58fdba587229ee Mon Sep 17 00:00:00 2001 From: aaarrti Date: Sat, 12 Aug 2023 23:47:02 +0200 Subject: [PATCH 03/58] WIP --- quantus/metrics/robustness/continuity.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/quantus/metrics/robustness/continuity.py b/quantus/metrics/robustness/continuity.py index 343608cf1..eded81458 100644 --- a/quantus/metrics/robustness/continuity.py +++ b/quantus/metrics/robustness/continuity.py @@ -388,8 +388,8 @@ def evaluate_instance( # Sum attributions for patch. patch_sum = float(sum(a_perturbed_patch)) results[ix_patch].append(patch_sum) - - return dict(**results) + + return {k: v for k, v in results.items()} def custom_preprocess( self, From 764e4157405c8a75f6bee54d6eb7fa7b699053f3 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Sun, 13 Aug 2023 01:05:46 +0200 Subject: [PATCH 04/58] WIP --- quantus/metrics/axiomatic/input_invariance.py | 7 +-- quantus/metrics/base.py | 39 ++++++++++++-- quantus/metrics/base_perturbed.py | 49 +++++++++++++++++ quantus/metrics/robustness/avg_sensitivity.py | 24 +-------- quantus/metrics/robustness/continuity.py | 37 ++----------- .../robustness/local_lipschitz_estimate.py | 33 +----------- quantus/metrics/robustness/max_sensitivity.py | 32 +----------- .../robustness/relative_input_stability.py | 43 ++------------- .../robustness/relative_output_stability.py | 48 ++--------------- .../relative_representation_stability.py | 52 ++----------------- 10 files changed, 110 insertions(+), 254 deletions(-) diff --git a/quantus/metrics/axiomatic/input_invariance.py b/quantus/metrics/axiomatic/input_invariance.py index f0816cf27..b022dd621 100644 --- a/quantus/metrics/axiomatic/input_invariance.py +++ b/quantus/metrics/axiomatic/input_invariance.py @@ -297,12 +297,7 @@ def evaluate_batch( ) # Generate explanation based on shifted input x. - a_shifted = self.explain_func( - model=shifted_model, - inputs=x_shifted, - targets=y_batch, - **self.explain_func_kwargs, - ) + a_shifted = self.explain_batch(shifted_model, x_shifted, y_batch) # Compute the invaraince. score = np.all( diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index 7073bcb74..854ae3cde 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -802,11 +802,44 @@ def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: return data_batch def explain_batch( - self, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray + self, model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, ) -> np.ndarray: - """Compute explanations, normalize and take absolute (if was configured so during metric initialization.)""" - if hasattr(model, "get_model"): + """ + + Parameters + ------- + + model: + x_batch: + y_batch: + batched: + If set to false, will np.expand_dims inputs. + + + Compute explanations, normalize and take absolute (if was configured so during metric initialization.) + This method should primarily be used if you need to generate additional explanation during + in metrics body. It encapsulates typical for Quantus pre- and postprocessing approach. + + + It will do few things: + - call model.shape_input + - unwrap model + - call explain_func + - expand attribution channel + - (optionally) normalize a_batch + - (optionally) take np.abs of a_batch + """ + + if isinstance(model, ModelInterface): # Sometimes the model is our wrapper, but sometimes raw Keras/Torch model. + x_batch = model.shape_input( + x=x_batch, + shape=x_batch.shape, + channel_first=True, + batched=True, + ) model = model.get_model() a_batch = self.explain_func( diff --git a/quantus/metrics/base_perturbed.py b/quantus/metrics/base_perturbed.py index 3c161501d..b4e34bf32 100644 --- a/quantus/metrics/base_perturbed.py +++ b/quantus/metrics/base_perturbed.py @@ -11,8 +11,13 @@ Callable, Dict, Optional, + List ) +import warnings +import numpy as np + +from quantus.helpers.model.model_interface import ModelInterface from quantus.metrics.base import Metric @@ -100,3 +105,47 @@ def __init__( if perturb_func_kwargs is None: perturb_func_kwargs = {} self.perturb_func_kwargs = perturb_func_kwargs + + def changed_prediction_indices( + self, + model: ModelInterface, + x_batch: np.ndarray, + x_perturbed: np.ndarray + ) -> List[int]: + + """ + Find indices in batch, for which predicted label has changed after applying perturbation. + If metric has no `return_nan_when_prediction_changes` attribute, or it is False, will return empty list. + + Parameters + ---------- + model: + x_batch: + Batch of original inputs provided by user. + x_perturbed: + Batch of inputs after applying perturbation. + + Returns + ------- + + changed_idx: + List of indices in batch, for which predicted label has changed afer. + + """ + + if hasattr(self, "return_nan_when_prediction_changes"): + attr_name = "return_nan_when_prediction_changes" + elif hasattr(self, "_return_nan_when_prediction_changes"): + attr_name = "_return_nan_when_prediction_changes" + else: + warnings.warn("Called changed_prediction_indices(), from a metric, " + "without `return_nan_when_prediction_changes` instance attribute, this is unexpected.") + return [] + + if not getattr(self, attr_name): + return [] + + labels_before = model.predict(x_batch).argmax(axis=-1) + labels_after = model.predict(x_perturbed).argmax(axis=-1) + changed_idx = np.reshape(np.argwhere(labels_before != labels_after), -1) + return changed_idx.tolist() diff --git a/quantus/metrics/robustness/avg_sensitivity.py b/quantus/metrics/robustness/avg_sensitivity.py index bae6c813b..beb11f1fc 100644 --- a/quantus/metrics/robustness/avg_sensitivity.py +++ b/quantus/metrics/robustness/avg_sensitivity.py @@ -337,13 +337,7 @@ def evaluate_batch( else [] ) - x_input = model.shape_input( - x=x_perturbed, - shape=x_batch.shape, - channel_first=True, - batched=True, - ) - + for x_instance, x_instance_perturbed in zip(x_batch, x_perturbed): warn.warn_perturbation_caused_no_change( x=x_instance, @@ -351,21 +345,7 @@ def evaluate_batch( ) # Generate explanation based on perturbed input x. - a_perturbed = self.explain_func( - model=model.get_model(), - inputs=x_input, - targets=y_batch, - **self.explain_func_kwargs, - ) - - if self.normalise: - a_perturbed = self.normalise_func( - a_perturbed, - **self.normalise_func_kwargs, - ) - - if self.abs: - a_perturbed = np.abs(a_perturbed) + a_perturbed = self.explain_batch(model, x_perturbed, y_batch) # Measure similarity for each instance separately. for instance_id in range(batch_size): diff --git a/quantus/metrics/robustness/continuity.py b/quantus/metrics/robustness/continuity.py index eded81458..36636c122 100644 --- a/quantus/metrics/robustness/continuity.py +++ b/quantus/metrics/robustness/continuity.py @@ -318,31 +318,12 @@ def evaluate_instance( ) x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - prediction_changed = ( - model.predict(np.expand_dims(x, 0)).argmax(axis=-1)[0] - != model.predict(x_input).argmax(axis=-1)[0] - if self.return_nan_when_prediction_changes - else False - ) - - # Generate explanations on perturbed input. - a_perturbed = self.explain_func( - model=model.get_model(), - inputs=x_input, - targets=y, - **self.explain_func_kwargs, - ) + prediction_changed = len( + self.changed_prediction_indices(model, np.expand_dims(x, 0), x_input) + ) != 0 # Taking the first element, since a_perturbed will be expanded to a batch dimension # not expected by the current index management functions. - a_perturbed = utils.expand_attribution_channel(a_perturbed, x_input)[0] - - if self.normalise: - a_perturbed = self.normalise_func( - a_perturbed, **self.normalise_func_kwargs - ) - - if self.abs: - a_perturbed = np.abs(a_perturbed) + a_perturbed = self.explain_batch(model, x_input, np.expand_dims(y, 0))[0] # Store the prediction score as the last element of the sub_self.evaluation_scores dictionary. y_pred = float(model.predict(x_input)[:, y]) @@ -377,16 +358,8 @@ def evaluate_instance( # not expected by the current index management functions. # a_perturbed = utils.expand_attribution_channel(a_perturbed, x_input)[0] - if self.normalise: - a_perturbed_patch = self.normalise_func( - a_perturbed_patch.flatten(), **self.normalise_func_kwargs - ) - - if self.abs: - a_perturbed_patch = np.abs(a_perturbed_patch.flatten()) - # Sum attributions for patch. - patch_sum = float(sum(a_perturbed_patch)) + patch_sum = float(np.sum(a_perturbed_patch)) results[ix_patch].append(patch_sum) return {k: v for k, v in results.items()} diff --git a/quantus/metrics/robustness/local_lipschitz_estimate.py b/quantus/metrics/robustness/local_lipschitz_estimate.py index 2d0dbc9c9..fd90af8af 100644 --- a/quantus/metrics/robustness/local_lipschitz_estimate.py +++ b/quantus/metrics/robustness/local_lipschitz_estimate.py @@ -334,21 +334,7 @@ def evaluate_batch( **self.perturb_func_kwargs, ) - changed_prediction_indices = ( - np.argwhere( - model.predict(x_batch).argmax(axis=-1) - != model.predict(x_perturbed).argmax(axis=-1) - ).reshape(-1) - if self.return_nan_when_prediction_changes - else [] - ) - - x_input = model.shape_input( - x=x_perturbed, - shape=x_batch.shape, - channel_first=True, - batched=True, - ) + changed_prediction_indices = self.changed_prediction_indices(model, x_batch, x_perturbed) for x_instance, x_instance_perturbed in zip(x_batch, x_perturbed): warn.warn_perturbation_caused_no_change( @@ -357,22 +343,7 @@ def evaluate_batch( ) # Generate explanation based on perturbed input x. - a_perturbed = self.explain_func( - model=model.get_model(), - inputs=x_input, - targets=y_batch, - **self.explain_func_kwargs, - ) - - if self.normalise: - a_perturbed = self.normalise_func( - a_perturbed, - **self.normalise_func_kwargs, - ) - - if self.abs: - a_perturbed = np.abs(a_perturbed) - + a_perturbed = self.explain_batch(model, x_perturbed, y_batch) # Measure similarity for each instance separately. for instance_id in range(batch_size): if ( diff --git a/quantus/metrics/robustness/max_sensitivity.py b/quantus/metrics/robustness/max_sensitivity.py index 1ff4ffba0..4c42f0082 100644 --- a/quantus/metrics/robustness/max_sensitivity.py +++ b/quantus/metrics/robustness/max_sensitivity.py @@ -326,21 +326,7 @@ def evaluate_batch( **self.perturb_func_kwargs, ) - changed_prediction_indices = ( - np.argwhere( - model.predict(x_batch).argmax(axis=-1) - != model.predict(x_perturbed).argmax(axis=-1) - ).reshape(-1) - if self.return_nan_when_prediction_changes - else [] - ) - - x_input = model.shape_input( - x=x_perturbed, - shape=x_batch.shape, - channel_first=True, - batched=True, - ) + changed_prediction_indices = self.changed_prediction_indices(model, x_batch, x_perturbed) for x_instance, x_instance_perturbed in zip(x_batch, x_perturbed): warn.warn_perturbation_caused_no_change( @@ -349,21 +335,7 @@ def evaluate_batch( ) # Generate explanation based on perturbed input x. - a_perturbed = self.explain_func( - model=model.get_model(), - inputs=x_input, - targets=y_batch, - **self.explain_func_kwargs, - ) - - if self.normalise: - a_perturbed = self.normalise_func( - a_perturbed, - **self.normalise_func_kwargs, - ) - - if self.abs: - a_perturbed = np.abs(a_perturbed) + a_perturbed = self.explain_batch(model, x_perturbed, y_batch) # Measure similarity for each instance separately. for instance_id in range(batch_size): diff --git a/quantus/metrics/robustness/relative_input_stability.py b/quantus/metrics/robustness/relative_input_stability.py index a43a5bb85..0a4271c19 100644 --- a/quantus/metrics/robustness/relative_input_stability.py +++ b/quantus/metrics/robustness/relative_input_stability.py @@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Optional, Callable, Dict, List import numpy as np -from functools import partial if TYPE_CHECKING: import tensorflow as tf @@ -20,7 +19,6 @@ from quantus.helpers.warn import warn_parameterisation from quantus.functions.normalise_func import normalise_by_average_second_moment_estimate from quantus.functions.perturb_func import uniform_noise, perturb_batch -from quantus.helpers.utils import expand_attribution_channel from quantus.helpers.enums import ( ModelType, DataType, @@ -262,33 +260,6 @@ def relative_input_stability_objective( denominator += (denominator == 0) * self._eps_min return nominator / denominator - def generate_normalised_explanations_batch( - self, x_batch: np.ndarray, y_batch: np.ndarray, explain_func: Callable - ) -> np.ndarray: - """ - Generate explanation, apply normalization and take absolute values if configured so during metric instantiation. - - Parameters - ---------- - x_batch: np.ndarray - 4D tensor representing batch of input images. - y_batch: np.ndarray - 1D tensor, representing predicted labels for the x_batch. - explain_func: callable - Function to generate explanations, takes only inputs,targets kwargs. - - Returns - ------- - a_batch: np.ndarray - A batch of explanations. - """ - a_batch = explain_func(inputs=x_batch, targets=y_batch) - if self.normalise: - a_batch = self.normalise_func(a_batch, **self.normalise_func_kwargs) - if self.abs: - a_batch = np.abs(a_batch) - return expand_attribution_channel(a_batch, x_batch) - def evaluate_batch( self, model: ModelInterface, @@ -321,9 +292,6 @@ def evaluate_batch( """ batch_size = x_batch.shape[0] - _explain_func = partial( - self.explain_func, model=model.get_model(), **self.explain_func_kwargs - ) # Prepare output array. ris_batch = np.zeros(shape=[self._nr_samples, x_batch.shape[0]]) @@ -339,9 +307,7 @@ def evaluate_batch( ) # Generate explanations for perturbed input. - a_batch_perturbed = self.generate_normalised_explanations_batch( - x_perturbed, y_batch, _explain_func - ) + a_batch_perturbed = self.explain_batch(model, x_perturbed, y_batch) # Compute maximization's objective. ris = self.relative_input_stability_objective( @@ -354,11 +320,8 @@ def evaluate_batch( continue # If perturbed input caused change in prediction, then it's RIS=nan. - predicted_y = model.predict(x_batch).argmax(axis=-1) - predicted_y_perturbed = model.predict(x_perturbed).argmax(axis=-1) - changed_prediction_indices = np.argwhere( - predicted_y != predicted_y_perturbed - ).reshape(-1) + + changed_prediction_indices = self.changed_prediction_indices(model, x_batch, x_perturbed) if len(changed_prediction_indices) == 0: continue diff --git a/quantus/metrics/robustness/relative_output_stability.py b/quantus/metrics/robustness/relative_output_stability.py index 3f387e165..9a9f62784 100644 --- a/quantus/metrics/robustness/relative_output_stability.py +++ b/quantus/metrics/robustness/relative_output_stability.py @@ -268,33 +268,6 @@ def relative_output_stability_objective( denominator += (denominator == 0) * self._eps_min # prevent division by 0 return nominator / denominator - def generate_normalised_explanations_batch( - self, x_batch: np.ndarray, y_batch: np.ndarray, explain_func: Callable - ) -> np.ndarray: - """ - Generate explanation, apply normalization and take absolute values if configured so during metric instantiation. - - Parameters - ---------- - x_batch: np.ndarray - 4D tensor representing batch of input images. - y_batch: np.ndarray - 1D tensor, representing predicted labels for the x_batch. - explain_func: callable - Function to generate explanations, takes only inputs,targets kwargs. - - Returns - ------- - a_batch: np.ndarray - A batch of explanations. - """ - a_batch = explain_func(inputs=x_batch, targets=y_batch) - if self.normalise: - a_batch = self.normalise_func(a_batch, **self.normalise_func_kwargs) - if self.abs: - a_batch = np.abs(a_batch) - return expand_attribution_channel(a_batch, x_batch) - def evaluate_batch( self, model: ModelInterface, @@ -347,9 +320,7 @@ def evaluate_batch( ) # Generate explanations for perturbed input. - a_batch_perturbed = self.generate_normalised_explanations_batch( - x_perturbed, y_batch, _explain_func - ) + a_batch_perturbed = self.explain_batch(model, x_perturbed, y_batch) # Execute forward pass on perturbed inputs. logits_perturbed = model.predict(x_perturbed) @@ -359,21 +330,12 @@ def evaluate_batch( ) ros_batch[index] = ros - # We're done with this sample if `return_nan_when_prediction_changes`==False. - if not self._return_nan_when_prediction_changes: - continue - + # If perturbed input caused change in prediction, then it's ROS=nan. - predicted_y = model.predict(x_batch).argmax(axis=-1) - predicted_y_perturbed = model.predict(x_perturbed).argmax(axis=-1) - changed_prediction_indices = np.argwhere( - predicted_y != predicted_y_perturbed - ).reshape(-1) - - if len(changed_prediction_indices) == 0: - continue + changed_prediction_indices = self.changed_prediction_indices(model, x_batch, x_perturbed) - ros_batch[index, changed_prediction_indices] = np.nan + if len(changed_prediction_indices) != 0: + ros_batch[index, changed_prediction_indices] = np.nan # Compute ROS. result = np.max(ros_batch, axis=0) diff --git a/quantus/metrics/robustness/relative_representation_stability.py b/quantus/metrics/robustness/relative_representation_stability.py index 204cabe0c..e1e07b344 100644 --- a/quantus/metrics/robustness/relative_representation_stability.py +++ b/quantus/metrics/robustness/relative_representation_stability.py @@ -281,33 +281,6 @@ def relative_representation_stability_objective( denominator += (denominator == 0) * self._eps_min return nominator / denominator - def generate_normalised_explanations_batch( - self, x_batch: np.ndarray, y_batch: np.ndarray, explain_func: Callable - ) -> np.ndarray: - """ - Generate explanation, apply normalization and take absolute values if configured so during metric instantiation. - - Parameters - ---------- - x_batch: np.ndarray - 4D tensor representing batch of input images. - y_batch: np.ndarray - 1D tensor, representing predicted labels for the x_batch. - explain_func: callable - Function to generate explanations, takes only inputs,targets kwargs. - - Returns - ------- - a_batch: np.ndarray - A batch of explanations. - """ - a_batch = explain_func(inputs=x_batch, targets=y_batch) - if self.normalise: - a_batch = self.normalise_func(a_batch, **self.normalise_func_kwargs) - if self.abs: - a_batch = np.abs(a_batch) - return expand_attribution_channel(a_batch, x_batch) - def evaluate_batch( self, model: ModelInterface, @@ -340,9 +313,6 @@ def evaluate_batch( """ batch_size = x_batch.shape[0] - _explain_func = partial( - self.explain_func, model=model.get_model(), **self.explain_func_kwargs - ) # Retrieve internal representation for provided inputs. internal_representations = model.get_hidden_representations( @@ -363,9 +333,7 @@ def evaluate_batch( ) # Generate explanations for perturbed input. - a_batch_perturbed = self.generate_normalised_explanations_batch( - x_perturbed, y_batch, _explain_func - ) + a_batch_perturbed = self.explain_batch(model, x_perturbed, y_batch) # Retrieve internal representation for perturbed inputs. internal_representations_perturbed = model.get_hidden_representations( @@ -380,21 +348,11 @@ def evaluate_batch( a_batch_perturbed, ) rrs_batch[index] = rrs - - # We're done with this sample if `return_nan_when_prediction_changes`==False. - if not self._return_nan_when_prediction_changes: - continue - # If perturbed input caused change in prediction, then it's RRS=nan. - predicted_y = model.predict(x_batch).argmax(axis=-1) - predicted_y_perturbed = model.predict(x_perturbed).argmax(axis=-1) - changed_prediction_indices = np.argwhere( - predicted_y != predicted_y_perturbed - ).reshape(-1) - - if len(changed_prediction_indices) == 0: - continue - rrs_batch[index, changed_prediction_indices] = np.nan + changed_prediction_indices = self.changed_prediction_indices(model, x_batch, x_perturbed) + + if len(changed_prediction_indices) != 0: + rrs_batch[index, changed_prediction_indices] = np.nan # Compute RRS. result = np.max(rrs_batch, axis=0) From e8b9f71e2811ec29d34f9bacc99980f82199ca22 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Sun, 13 Aug 2023 15:49:40 +0200 Subject: [PATCH 05/58] WIP --- quantus/metrics/base.py | 21 ++++++------ quantus/metrics/base_perturbed.py | 32 +++++++------------ quantus/metrics/faithfulness/sufficiency.py | 2 +- quantus/metrics/localisation/focus.py | 2 +- quantus/metrics/localisation/pointing_game.py | 6 ++-- quantus/metrics/robustness/avg_sensitivity.py | 1 - quantus/metrics/robustness/consistency.py | 2 +- quantus/metrics/robustness/continuity.py | 17 ++++++---- .../robustness/local_lipschitz_estimate.py | 4 ++- quantus/metrics/robustness/max_sensitivity.py | 4 ++- .../robustness/relative_input_stability.py | 6 ++-- .../robustness/relative_output_stability.py | 5 +-- .../relative_representation_stability.py | 6 ++-- 13 files changed, 55 insertions(+), 53 deletions(-) diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index 854ae3cde..78ff2d789 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -802,27 +802,28 @@ def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: return data_batch def explain_batch( - self, model: ModelInterface, - x_batch: np.ndarray, - y_batch: np.ndarray, + self, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, ) -> np.ndarray: """ - + Parameters ------- - + model: x_batch: y_batch: batched: If set to false, will np.expand_dims inputs. - - + + Compute explanations, normalize and take absolute (if was configured so during metric initialization.) This method should primarily be used if you need to generate additional explanation during in metrics body. It encapsulates typical for Quantus pre- and postprocessing approach. - - + + It will do few things: - call model.shape_input - unwrap model @@ -831,7 +832,7 @@ def explain_batch( - (optionally) normalize a_batch - (optionally) take np.abs of a_batch """ - + if isinstance(model, ModelInterface): # Sometimes the model is our wrapper, but sometimes raw Keras/Torch model. x_batch = model.shape_input( diff --git a/quantus/metrics/base_perturbed.py b/quantus/metrics/base_perturbed.py index b4e34bf32..70db9f732 100644 --- a/quantus/metrics/base_perturbed.py +++ b/quantus/metrics/base_perturbed.py @@ -6,13 +6,7 @@ # Quantus project URL: . from abc import ABC -from typing import ( - Any, - Callable, - Dict, - Optional, - List -) +from typing import Any, Callable, Dict, Optional, List import warnings import numpy as np @@ -105,18 +99,14 @@ def __init__( if perturb_func_kwargs is None: perturb_func_kwargs = {} self.perturb_func_kwargs = perturb_func_kwargs - + def changed_prediction_indices( - self, - model: ModelInterface, - x_batch: np.ndarray, - x_perturbed: np.ndarray + self, model: ModelInterface, x_batch: np.ndarray, x_perturbed: np.ndarray ) -> List[int]: - """ Find indices in batch, for which predicted label has changed after applying perturbation. If metric has no `return_nan_when_prediction_changes` attribute, or it is False, will return empty list. - + Parameters ---------- model: @@ -127,24 +117,26 @@ def changed_prediction_indices( Returns ------- - + changed_idx: List of indices in batch, for which predicted label has changed afer. """ - + if hasattr(self, "return_nan_when_prediction_changes"): attr_name = "return_nan_when_prediction_changes" elif hasattr(self, "_return_nan_when_prediction_changes"): attr_name = "_return_nan_when_prediction_changes" else: - warnings.warn("Called changed_prediction_indices(), from a metric, " - "without `return_nan_when_prediction_changes` instance attribute, this is unexpected.") + warnings.warn( + "Called changed_prediction_indices(), from a metric, " + "without `return_nan_when_prediction_changes` instance attribute, this is unexpected." + ) return [] - + if not getattr(self, attr_name): return [] - + labels_before = model.predict(x_batch).argmax(axis=-1) labels_after = model.predict(x_perturbed).argmax(axis=-1) changed_idx = np.reshape(np.argwhere(labels_before != labels_after), -1) diff --git a/quantus/metrics/faithfulness/sufficiency.py b/quantus/metrics/faithfulness/sufficiency.py index 0820ae974..b658fdeef 100644 --- a/quantus/metrics/faithfulness/sufficiency.py +++ b/quantus/metrics/faithfulness/sufficiency.py @@ -269,7 +269,7 @@ def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: data_batch.update(custom_batch) return data_batch - + @no_type_check def evaluate_batch( self, *, i_batch, a_sim_vector_batch, y_pred_classes, **_ diff --git a/quantus/metrics/localisation/focus.py b/quantus/metrics/localisation/focus.py index 5f71f5ef0..0d2857eec 100644 --- a/quantus/metrics/localisation/focus.py +++ b/quantus/metrics/localisation/focus.py @@ -335,7 +335,7 @@ def quadrant_bottom_right(self, a: np.ndarray) -> np.ndarray: :, int(self.mosaic_shape[1] / 2) :, int(self.mosaic_shape[2] / 2) : ] return quandrant_a - + @no_type_check def evaluate_batch( self, diff --git a/quantus/metrics/localisation/pointing_game.py b/quantus/metrics/localisation/pointing_game.py index 1deb3f696..4fe20f625 100644 --- a/quantus/metrics/localisation/pointing_game.py +++ b/quantus/metrics/localisation/pointing_game.py @@ -268,8 +268,6 @@ def custom_preprocess( def evaluate_batch( self, *, a_batch: np.ndarray, s_batch: np.ndarray, **_ ) -> List[float]: - - retval = [] for a, s in zip(a_batch, s_batch): if np.sum(s) == 0: @@ -291,5 +289,5 @@ def evaluate_batch( hit = 1 - (np.sum(s) / float(np.prod(s.shape))) retval.append(hit) - - return retval \ No newline at end of file + + return retval diff --git a/quantus/metrics/robustness/avg_sensitivity.py b/quantus/metrics/robustness/avg_sensitivity.py index beb11f1fc..675e98f81 100644 --- a/quantus/metrics/robustness/avg_sensitivity.py +++ b/quantus/metrics/robustness/avg_sensitivity.py @@ -337,7 +337,6 @@ def evaluate_batch( else [] ) - for x_instance, x_instance_perturbed in zip(x_batch, x_perturbed): warn.warn_perturbation_caused_no_change( x=x_instance, diff --git a/quantus/metrics/robustness/consistency.py b/quantus/metrics/robustness/consistency.py index f834a672a..bd459f8df 100644 --- a/quantus/metrics/robustness/consistency.py +++ b/quantus/metrics/robustness/consistency.py @@ -256,7 +256,7 @@ def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: data_batch.update(custom_batch) return data_batch - + @no_type_check def evaluate_batch( self, diff --git a/quantus/metrics/robustness/continuity.py b/quantus/metrics/robustness/continuity.py index 36636c122..be4184da5 100644 --- a/quantus/metrics/robustness/continuity.py +++ b/quantus/metrics/robustness/continuity.py @@ -9,7 +9,6 @@ import itertools from typing import Any, Callable, Dict, List, Optional import numpy as np -from collections import defaultdict from quantus.helpers import asserts from quantus.helpers import utils @@ -303,10 +302,9 @@ def evaluate_instance( dict The evaluation results. """ - results = defaultdict(lambda: []) + results: Dict[int, list] = {k: [] for k in range(self.nr_patches + 1)} for step in range(self.nr_steps): - # Generate explanation based on perturbed input x. dx_step = (step + 1) * self.dx x_perturbed = self.perturb_func( @@ -318,9 +316,14 @@ def evaluate_instance( ) x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - prediction_changed = len( - self.changed_prediction_indices(model, np.expand_dims(x, 0), x_input) - ) != 0 + prediction_changed = ( + len( + self.changed_prediction_indices( + model, np.expand_dims(x, 0), x_input + ) + ) + != 0 + ) # Taking the first element, since a_perturbed will be expanded to a batch dimension # not expected by the current index management functions. a_perturbed = self.explain_batch(model, x_input, np.expand_dims(y, 0))[0] @@ -361,7 +364,7 @@ def evaluate_instance( # Sum attributions for patch. patch_sum = float(np.sum(a_perturbed_patch)) results[ix_patch].append(patch_sum) - + return {k: v for k, v in results.items()} def custom_preprocess( diff --git a/quantus/metrics/robustness/local_lipschitz_estimate.py b/quantus/metrics/robustness/local_lipschitz_estimate.py index fd90af8af..6c722fb09 100644 --- a/quantus/metrics/robustness/local_lipschitz_estimate.py +++ b/quantus/metrics/robustness/local_lipschitz_estimate.py @@ -334,7 +334,9 @@ def evaluate_batch( **self.perturb_func_kwargs, ) - changed_prediction_indices = self.changed_prediction_indices(model, x_batch, x_perturbed) + changed_prediction_indices = self.changed_prediction_indices( + model, x_batch, x_perturbed + ) for x_instance, x_instance_perturbed in zip(x_batch, x_perturbed): warn.warn_perturbation_caused_no_change( diff --git a/quantus/metrics/robustness/max_sensitivity.py b/quantus/metrics/robustness/max_sensitivity.py index 4c42f0082..9fa8935b9 100644 --- a/quantus/metrics/robustness/max_sensitivity.py +++ b/quantus/metrics/robustness/max_sensitivity.py @@ -326,7 +326,9 @@ def evaluate_batch( **self.perturb_func_kwargs, ) - changed_prediction_indices = self.changed_prediction_indices(model, x_batch, x_perturbed) + changed_prediction_indices = self.changed_prediction_indices( + model, x_batch, x_perturbed + ) for x_instance, x_instance_perturbed in zip(x_batch, x_perturbed): warn.warn_perturbation_caused_no_change( diff --git a/quantus/metrics/robustness/relative_input_stability.py b/quantus/metrics/robustness/relative_input_stability.py index 0a4271c19..4cb73cff6 100644 --- a/quantus/metrics/robustness/relative_input_stability.py +++ b/quantus/metrics/robustness/relative_input_stability.py @@ -320,8 +320,10 @@ def evaluate_batch( continue # If perturbed input caused change in prediction, then it's RIS=nan. - - changed_prediction_indices = self.changed_prediction_indices(model, x_batch, x_perturbed) + + changed_prediction_indices = self.changed_prediction_indices( + model, x_batch, x_perturbed + ) if len(changed_prediction_indices) == 0: continue diff --git a/quantus/metrics/robustness/relative_output_stability.py b/quantus/metrics/robustness/relative_output_stability.py index 9a9f62784..5cb2913db 100644 --- a/quantus/metrics/robustness/relative_output_stability.py +++ b/quantus/metrics/robustness/relative_output_stability.py @@ -330,9 +330,10 @@ def evaluate_batch( ) ros_batch[index] = ros - # If perturbed input caused change in prediction, then it's ROS=nan. - changed_prediction_indices = self.changed_prediction_indices(model, x_batch, x_perturbed) + changed_prediction_indices = self.changed_prediction_indices( + model, x_batch, x_perturbed + ) if len(changed_prediction_indices) != 0: ros_batch[index, changed_prediction_indices] = np.nan diff --git a/quantus/metrics/robustness/relative_representation_stability.py b/quantus/metrics/robustness/relative_representation_stability.py index e1e07b344..b94ee5ba0 100644 --- a/quantus/metrics/robustness/relative_representation_stability.py +++ b/quantus/metrics/robustness/relative_representation_stability.py @@ -349,8 +349,10 @@ def evaluate_batch( ) rrs_batch[index] = rrs # If perturbed input caused change in prediction, then it's RRS=nan. - changed_prediction_indices = self.changed_prediction_indices(model, x_batch, x_perturbed) - + changed_prediction_indices = self.changed_prediction_indices( + model, x_batch, x_perturbed + ) + if len(changed_prediction_indices) != 0: rrs_batch[index, changed_prediction_indices] = np.nan From ae8e079413b1d467e8e3f7a534b3c0459661e023 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Sun, 13 Aug 2023 16:24:22 +0200 Subject: [PATCH 06/58] WIP --- quantus/metrics/axiomatic/completeness.py | 35 +++- quantus/metrics/axiomatic/non_sensitivity.py | 33 +++- quantus/metrics/base.py | 30 ++- quantus/metrics/complexity/complexity.py | 30 ++- .../complexity/effective_complexity.py | 25 ++- quantus/metrics/complexity/sparseness.py | 29 ++- .../faithfulness/faithfulness_correlation.py | 30 ++- .../faithfulness/faithfulness_estimate.py | 30 ++- quantus/metrics/faithfulness/infidelity.py | 185 ++++++++++-------- quantus/metrics/faithfulness/irof.py | 131 ++++++++----- 10 files changed, 383 insertions(+), 175 deletions(-) diff --git a/quantus/metrics/axiomatic/completeness.py b/quantus/metrics/axiomatic/completeness.py index 4c43bb3b1..5be750fac 100644 --- a/quantus/metrics/axiomatic/completeness.py +++ b/quantus/metrics/axiomatic/completeness.py @@ -270,8 +270,37 @@ def evaluate_batch( a_batch: np.ndarray, **_, ) -> List[bool]: + + """ + + Checks if sum of attributions is equal to the difference between original prediction and + prediction on baseline value. + + + Parameters + ---------- + model: ModelInterface + A ModelInterface that is subject to explanation. + x_batch: np.ndarray + The input to be evaluated on a batch-basis. + y_batch: np.ndarray + The output to be evaluated on a batch-basis. + a_batch: np.ndarray + The explanation to be evaluated on a batch-basis. + _: + Unused kwargs. + + Returns + ------- + + scores_batch: + List of booleans. + + + """ + # TODO: vectorize - retval = [] + scores_batch = [] for x, y, a in zip(x_batch, y_batch, a_batch): x_baseline = self.perturb_func( arr=x, @@ -288,5 +317,5 @@ def evaluate_batch( x_input = model.shape_input(x_baseline, x.shape, channel_first=True) y_pred_baseline = float(model.predict(x_input)[:, y]) - retval.append(np.sum(a) == self.output_func(y_pred - y_pred_baseline)) - return retval + scores_batch.append(np.sum(a) == self.output_func(y_pred - y_pred_baseline)) + return scores_batch diff --git a/quantus/metrics/axiomatic/non_sensitivity.py b/quantus/metrics/axiomatic/non_sensitivity.py index 09b2e70d6..529a0f327 100644 --- a/quantus/metrics/axiomatic/non_sensitivity.py +++ b/quantus/metrics/axiomatic/non_sensitivity.py @@ -311,7 +311,34 @@ def evaluate_batch( a_batch: np.ndarray, **kwargs, ) -> List[int]: - retval = [] + + """ + + Count the number of features in each explanation, for which model is not sensitive. + + Parameters + ---------- + model: ModelInterface + A ModelInterface that is subject to explanation. + x_batch: np.ndarray + The input to be evaluated on a batch-basis. + y_batch: np.ndarray + The output to be evaluated on a batch-basis. + a_batch: np.ndarray + The explanation to be evaluated on a batch-basis. + + kwargs: + Unused. + + Returns + ------- + + scores_batch: + List of integers. + + """ + + scores_batch = [] for x, y, a in zip(x_batch, y_batch, a_batch): a = a.flatten() @@ -346,6 +373,6 @@ def evaluate_batch( vars.append(np.var(preds)) non_features_vars = set(list(np.argwhere(vars).flatten() < self.eps)) - retval.append(len(non_features_vars.symmetric_difference(non_features))) + scores_batch.append(len(non_features_vars.symmetric_difference(non_features))) - return retval + return scores_batch diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index 78ff2d789..adf8e0f72 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -403,12 +403,7 @@ def general_preprocess( # Normalise with specified keyword arguments if requested. if self.normalise: - # TODO: what is this signature? - a_batch = self.normalise_func( - a_batch, - normalise_axes=list(range(np.ndim(a_batch)))[1:], - **self.normalise_func_kwargs, - ) + a_batch = self.normalise_func(a_batch, **self.normalise_func_kwargs) # Take absolute if requested. if self.abs: @@ -808,6 +803,10 @@ def explain_batch( y_batch: np.ndarray, ) -> np.ndarray: """ + + Compute explanations, normalize and take absolute (if was configured so during metric initialization.) + This method should primarily be used if you need to generate additional explanation + in metrics body. It encapsulates typical for Quantus pre- and postprocessing approach. Parameters ------- @@ -815,22 +814,15 @@ def explain_batch( model: x_batch: y_batch: - batched: - If set to false, will np.expand_dims inputs. - - - Compute explanations, normalize and take absolute (if was configured so during metric initialization.) - This method should primarily be used if you need to generate additional explanation during - in metrics body. It encapsulates typical for Quantus pre- and postprocessing approach. It will do few things: - - call model.shape_input - - unwrap model - - call explain_func - - expand attribution channel - - (optionally) normalize a_batch - - (optionally) take np.abs of a_batch + - call model.shape_input + - unwrap model + - call explain_func + - expand attribution channel + - (optionally) normalize a_batch + - (optionally) take np.abs of a_batch """ if isinstance(model, ModelInterface): diff --git a/quantus/metrics/complexity/complexity.py b/quantus/metrics/complexity/complexity.py index 1d2e12bb9..da6430e60 100644 --- a/quantus/metrics/complexity/complexity.py +++ b/quantus/metrics/complexity/complexity.py @@ -225,11 +225,29 @@ def __call__( **kwargs, ) - def evaluate_batch( - self, *, x_batch: np.ndarray, a_batch: np.ndarray, **_ - ) -> List[float]: + def evaluate_batch(self, *, x_batch: np.ndarray, a_batch: np.ndarray, **_) -> List[float]: + """ + + TODO: what does it compute? + + Parameters + ---------- + x_batch: np.ndarray + The input to be evaluated on a batch-basis. + a_batch: np.ndarray + The explanation to be evaluated on a batch-basis. + _: + Unused. + + Returns + ------- + + scores_batch: + List of floats. + + """ # TODO: vectorize - retval = [] + scores_batch = [] for x, a in zip(x_batch, a_batch): if len(x.shape) == 1: newshape = np.prod(x.shape) @@ -237,6 +255,6 @@ def evaluate_batch( newshape = np.prod(x.shape[1:]) a = np.array(np.reshape(a, newshape), dtype=np.float64) / np.sum(np.abs(a)) - retval.append(scipy.stats.entropy(pk=a)) + scores_batch.append(scipy.stats.entropy(pk=a)) - return retval + return scores_batch diff --git a/quantus/metrics/complexity/effective_complexity.py b/quantus/metrics/complexity/effective_complexity.py index 6eee89e7a..f4f739762 100644 --- a/quantus/metrics/complexity/effective_complexity.py +++ b/quantus/metrics/complexity/effective_complexity.py @@ -229,11 +229,30 @@ def __call__( ) def evaluate_batch(self, *, a_batch: np.ndarray, **_) -> List[int]: + """ + + Count how many attributions exceed the threshold `eps` + + Parameters + ---------- + a_batch: np.ndarray + The explanation to be evaluated on a batch-basis. + + _: + Unused + + Returns + ------- + + scores_batch: + List of integers. + + """ # TODO: vectorize - retval = [] + scores_batch = [] for a in a_batch: a = a.flatten() - retval.append(int(np.sum(a > self.eps))) + scores_batch.append(int(np.sum(a > self.eps))) - return retval + return scores_batch diff --git a/quantus/metrics/complexity/sparseness.py b/quantus/metrics/complexity/sparseness.py index 3cac9324c..35f6d9a76 100644 --- a/quantus/metrics/complexity/sparseness.py +++ b/quantus/metrics/complexity/sparseness.py @@ -233,7 +233,30 @@ def __call__( def evaluate_batch( self, *, x_batch: np.ndarray, a_batch: np.ndarray, **_ ) -> List[float]: - retval = [] + + """ + + TODO: what does it compute? + + Parameters + ---------- + x_batch: np.ndarray + The input to be evaluated on a batch-basis. + a_batch: np.ndarray + The explanation to be evaluated on a batch-basis. + _: + Unused. + + Returns + ------- + + scores_batch: + List of floats. + + """ + + scores_batch = [] + # TODO: vectorize for x, a in zip(x_batch, a_batch): if len(x.shape) == 1: @@ -247,6 +270,6 @@ def evaluate_batch( score = ( np.sum((2 * np.arange(1, a.shape[0] + 1) - a.shape[0] - 1) * a) ) / (a.shape[0] * np.sum(a)) - retval.append(score) + scores_batch.append(score) - return retval + return scores_batch diff --git a/quantus/metrics/faithfulness/faithfulness_correlation.py b/quantus/metrics/faithfulness/faithfulness_correlation.py index 79c3a2b7c..325f540c9 100644 --- a/quantus/metrics/faithfulness/faithfulness_correlation.py +++ b/quantus/metrics/faithfulness/faithfulness_correlation.py @@ -323,7 +323,31 @@ def evaluate_batch( a_batch: np.ndarray, **_, ) -> List[float]: - retval = [] + """ + + TODO: what does it compute? + + Parameters + ---------- + model: ModelInterface + A ModelInterface that is subject to explanation. + x_batch: np.ndarray + The input to be evaluated on a batch-basis. + y_batch: np.ndarray + The output to be evaluated on a batch-basis. + a_batch: np.ndarray + The explanation to be evaluated on a batch-basis. + _: + Unused. + + Returns + ------- + + scores_batch: + List of floats. + + """ + scores_batch = [] for x, y, a in zip(x_batch, y_batch, a_batch): a = a.flatten() @@ -356,6 +380,6 @@ def evaluate_batch( similarity = self.similarity_func(a=att_sums, b=pred_deltas) - retval.append(similarity) + scores_batch.append(similarity) - return retval + return scores_batch diff --git a/quantus/metrics/faithfulness/faithfulness_estimate.py b/quantus/metrics/faithfulness/faithfulness_estimate.py index a4c743902..5014c59d5 100644 --- a/quantus/metrics/faithfulness/faithfulness_estimate.py +++ b/quantus/metrics/faithfulness/faithfulness_estimate.py @@ -309,7 +309,31 @@ def evaluate_batch( a_batch: np.ndarray, **_, ): - rerval = [] + """ + + TODO: what does it compute? + + Parameters + ---------- + model: ModelInterface + A ModelInterface that is subject to explanation. + x_batch: np.ndarray + The input to be evaluated on a batch-basis. + y_batch: np.ndarray + The output to be evaluated on a batch-basis. + a_batch: np.ndarray + The explanation to be evaluated on a batch-basis. + _: + Unused. + + Returns + ------- + + scores_batch: + List of floats. + + """ + scores_batch = [] for x, y, a in zip(x_batch, y_batch, a_batch): # Flatten the attributions. @@ -350,6 +374,6 @@ def evaluate_batch( att_sums[i_ix] = np.sum(a[a_ix]) similarity = self.similarity_func(a=att_sums, b=pred_deltas) - rerval.append(similarity) + scores_batch.append(similarity) - return rerval + return scores_batch diff --git a/quantus/metrics/faithfulness/infidelity.py b/quantus/metrics/faithfulness/infidelity.py index 43cd5010d..e7fe061b9 100644 --- a/quantus/metrics/faithfulness/infidelity.py +++ b/quantus/metrics/faithfulness/infidelity.py @@ -281,6 +281,105 @@ def __call__( device=device, **kwargs, ) + + def evaluate_instance( + self, + model: ModelInterface, + x: np.ndarray, + y: np.ndarray, + a: np.ndarray, + ) -> float: + """ + Evaluate instance gets model and data for a single instance as input and returns the evaluation result. + + Parameters + ---------- + model: ModelInterface + A ModelInterface that is subject to explanation. + x: np.ndarray + The input to be evaluated on an instance-basis. + y: np.ndarray + The output to be evaluated on an instance-basis. + a: np.ndarray + The explanation to be evaluated on an instance-basis. + + Returns + ------- + float + The evaluation results. + """ + + # Predict on input. + x_input = model.shape_input(x, x.shape, channel_first=True) + y_pred = float(model.predict(x_input)[:, y]) + + results = [] + + for _ in range(self.n_perturb_samples): + + sub_results = [] + + for patch_size in self.perturb_patch_sizes: + + pred_deltas = np.zeros( + (int(a.shape[1] / patch_size), int(a.shape[2] / patch_size)) + ) + a_sums = np.zeros( + (int(a.shape[1] / patch_size), int(a.shape[2] / patch_size)) + ) + x_perturbed = x.copy() + pad_width = patch_size - 1 + + for i_x, top_left_x in enumerate(range(0, x.shape[1], patch_size)): + + for i_y, top_left_y in enumerate(range(0, x.shape[2], patch_size)): + + # Perturb input patch-wise. + x_perturbed_pad = utils._pad_array( + x_perturbed, pad_width, mode="edge", padded_axes=self.a_axes + ) + patch_slice = utils.create_patch_slice( + patch_size=patch_size, + coords=[top_left_x, top_left_y], + ) + + x_perturbed_pad = self.perturb_func( + arr=x_perturbed_pad, + indices=patch_slice, + indexed_axes=self.a_axes, + **self.perturb_func_kwargs, + ) + + # Remove padding. + x_perturbed = utils._unpad_array( + x_perturbed_pad, pad_width, padded_axes=self.a_axes + ) + + # Predict on perturbed input x_perturbed. + x_input = model.shape_input( + x_perturbed, x.shape, channel_first=True + ) + warn.warn_perturbation_caused_no_change( + x=x, x_perturbed=x_input + ) + y_pred_perturb = float(model.predict(x_input)[:, y]) + + x_diff = x - x_perturbed + a_diff = np.dot( + np.repeat(a, repeats=self.nr_channels, axis=0), x_diff + ) + + pred_deltas[i_x][i_y] = y_pred - y_pred_perturb + a_sums[i_x][i_y] = np.sum(a_diff) + + assert callable(self.loss_func) + sub_results.append( + self.loss_func(a=pred_deltas.flatten(), b=a_sums.flatten()) + ) + + results.append(np.mean(sub_results)) + + return np.mean(results) def custom_preprocess( self, @@ -323,83 +422,9 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - **kwargs, - ) -> List[List[float]]: - # TODO: can we get rid of any for-loop? - retval = [] - - for x, y, a in zip(x_batch, y_batch, a_batch): - # Predict on input. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) - - results = [] - - for _ in range(self.n_perturb_samples): - sub_results = [] - - for patch_size in self.perturb_patch_sizes: - pred_deltas = np.zeros( - (int(a.shape[1] / patch_size), int(a.shape[2] / patch_size)) - ) - a_sums = np.zeros( - (int(a.shape[1] / patch_size), int(a.shape[2] / patch_size)) - ) - x_perturbed = x.copy() - pad_width = patch_size - 1 - - for i_x, top_left_x in enumerate(range(0, x.shape[1], patch_size)): - for i_y, top_left_y in enumerate( - range(0, x.shape[2], patch_size) - ): - # Perturb input patch-wise. - x_perturbed_pad = utils._pad_array( - x_perturbed, - pad_width, - mode="edge", - padded_axes=self.a_axes, - ) - patch_slice = utils.create_patch_slice( - patch_size=patch_size, - coords=[top_left_x, top_left_y], - ) - - x_perturbed_pad = self.perturb_func( - arr=x_perturbed_pad, - indices=patch_slice, - indexed_axes=self.a_axes, - **self.perturb_func_kwargs, - ) - - # Remove padding. - x_perturbed = utils._unpad_array( - x_perturbed_pad, pad_width, padded_axes=self.a_axes - ) - - # Predict on perturbed input x_perturbed. - x_input = model.shape_input( - x_perturbed, x.shape, channel_first=True - ) - warn.warn_perturbation_caused_no_change( - x=x, x_perturbed=x_input - ) - y_pred_perturb = float(model.predict(x_input)[:, y]) - - x_diff = x - x_perturbed - a_diff = np.dot( - np.repeat(a, repeats=self.nr_channels, axis=0), x_diff - ) - - pred_deltas[i_x][i_y] = y_pred - y_pred_perturb - a_sums[i_x][i_y] = np.sum(a_diff) - - assert callable(self.loss_func) - sub_results.append( - self.loss_func(a=pred_deltas.flatten(), b=a_sums.flatten()) - ) - - results.append(np.mean(sub_results)) - - retval.append(results) - - return retval + **_, + ) -> List[float]: + return [ + self.evaluate_instance(model, x, y, a) + for x, y, a in zip(x_batch, y_batch, a_batch) + ] \ No newline at end of file diff --git a/quantus/metrics/faithfulness/irof.py b/quantus/metrics/faithfulness/irof.py index 73916e596..f36b504de 100644 --- a/quantus/metrics/faithfulness/irof.py +++ b/quantus/metrics/faithfulness/irof.py @@ -261,6 +261,81 @@ def __call__( model_predict_kwargs=model_predict_kwargs, **kwargs, ) + + def evaluate_instance( + self, + model: ModelInterface, + x: np.ndarray, + y: np.ndarray, + a: np.ndarray, + ) -> float: + """ + Evaluate instance gets model and data for a single instance as input and returns the evaluation result. + + Parameters + ---------- + model: ModelInterface + A ModelInteface that is subject to explanation. + x: np.ndarray + The input to be evaluated on an instance-basis. + y: np.ndarray + The output to be evaluated on an instance-basis. + a: np.ndarray + The explanation to be evaluated on an instance-basis. + Returns + ------- + float + The evaluation results. + """ + # Predict on x. + x_input = model.shape_input(x, x.shape, channel_first=True) + y_pred = float(model.predict(x_input)[:, y]) + + # Segment image. + segments = utils.get_superpixel_segments( + img=np.moveaxis(x, 0, -1).astype("double"), + segmentation_method=self.segmentation_method, + ) + nr_segments = len(np.unique(segments)) + asserts.assert_nr_segments(nr_segments=nr_segments) + + # Calculate average attribution of each segment. + att_segs = np.zeros(nr_segments) + for i, s in enumerate(range(nr_segments)): + att_segs[i] = np.mean(a[:, segments == s]) + + # Sort segments based on the mean attribution (descending order). + s_indices = np.argsort(-att_segs) + + preds = [] + x_prev_perturbed = x + + for i_ix, s_ix in enumerate(s_indices): + + # Perturb input by indices of attributions. + a_ix = np.nonzero((segments == s_ix).flatten())[0] + + x_perturbed = self.perturb_func( + arr=x_prev_perturbed, + indices=a_ix, + indexed_axes=self.a_axes, + **self.perturb_func_kwargs, + ) + warn.warn_perturbation_caused_no_change( + x=x_prev_perturbed, x_perturbed=x_perturbed + ) + + # Predict on perturbed input x. + x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) + y_pred_perturb = float(model.predict(x_input)[:, y]) + + # Normalise the scores to be within range [0, 1]. + preds.append(float(y_pred_perturb / y_pred)) + x_prev_perturbed = x_perturbed + + # Calculate the area over the curve (AOC) score. + aoc = len(preds) - utils.calculate_auc(np.array(preds)) + return aoc def custom_preprocess( self, @@ -310,55 +385,7 @@ def evaluate_batch( a_batch: np.ndarray, **_, ) -> List[float]: - retval = [] - for x, y, a in zip(x_batch, y_batch, a_batch): - # Predict on x. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) - - # Segment image. - segments = utils.get_superpixel_segments( - img=np.moveaxis(x, 0, -1).astype("double"), - segmentation_method=self.segmentation_method, - ) - nr_segments = len(np.unique(segments)) - asserts.assert_nr_segments(nr_segments=nr_segments) - - # Calculate average attribution of each segment. - att_segs = np.zeros(nr_segments) - for i, s in enumerate(range(nr_segments)): - att_segs[i] = np.mean(a[:, segments == s]) - - # Sort segments based on the mean attribution (descending order). - s_indices = np.argsort(-att_segs) - - preds = [] - x_prev_perturbed = x - - for i_ix, s_ix in enumerate(s_indices): - # Perturb input by indices of attributions. - a_ix = np.nonzero((segments == s_ix).flatten())[0] - - x_perturbed = self.perturb_func( - arr=x_prev_perturbed, - indices=a_ix, - indexed_axes=self.a_axes, - **self.perturb_func_kwargs, - ) - warn.warn_perturbation_caused_no_change( - x=x_prev_perturbed, x_perturbed=x_perturbed - ) - - # Predict on perturbed input x. - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - - # Normalise the scores to be within range [0, 1]. - preds.append(float(y_pred_perturb / y_pred)) - x_prev_perturbed = x_perturbed - - # Calculate the area over the curve (AOC) score. - aoc = len(preds) - utils.calculate_auc(np.array(preds)) - retval.append(aoc) - - return retval + return [ + self.evaluate_instance(model, x, y, a) + for x, y, a in zip(x_batch, y_batch, a_batch) + ] From c16ff5159926e9d3d099c4e72dfedacaa563e5b8 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Sun, 13 Aug 2023 16:30:58 +0200 Subject: [PATCH 07/58] WIP --- quantus/metrics/axiomatic/completeness.py | 11 +++++------ quantus/metrics/axiomatic/non_sensitivity.py | 15 ++++++++------- quantus/metrics/base.py | 2 +- quantus/metrics/base_batched.py | 6 ++++-- quantus/metrics/complexity/complexity.py | 10 ++++++---- .../metrics/complexity/effective_complexity.py | 8 ++++---- quantus/metrics/complexity/sparseness.py | 9 ++++----- .../faithfulness/faithfulness_correlation.py | 6 +++--- .../metrics/faithfulness/faithfulness_estimate.py | 6 +++--- quantus/metrics/faithfulness/infidelity.py | 8 ++------ quantus/metrics/faithfulness/irof.py | 3 +-- 11 files changed, 41 insertions(+), 43 deletions(-) diff --git a/quantus/metrics/axiomatic/completeness.py b/quantus/metrics/axiomatic/completeness.py index 5be750fac..21faf394e 100644 --- a/quantus/metrics/axiomatic/completeness.py +++ b/quantus/metrics/axiomatic/completeness.py @@ -270,13 +270,12 @@ def evaluate_batch( a_batch: np.ndarray, **_, ) -> List[bool]: - """ - + Checks if sum of attributions is equal to the difference between original prediction and prediction on baseline value. - + Parameters ---------- model: ModelInterface @@ -292,13 +291,13 @@ def evaluate_batch( Returns ------- - + scores_batch: List of booleans. - + """ - + # TODO: vectorize scores_batch = [] for x, y, a in zip(x_batch, y_batch, a_batch): diff --git a/quantus/metrics/axiomatic/non_sensitivity.py b/quantus/metrics/axiomatic/non_sensitivity.py index 529a0f327..14d77c3b0 100644 --- a/quantus/metrics/axiomatic/non_sensitivity.py +++ b/quantus/metrics/axiomatic/non_sensitivity.py @@ -311,11 +311,10 @@ def evaluate_batch( a_batch: np.ndarray, **kwargs, ) -> List[int]: - """ - + Count the number of features in each explanation, for which model is not sensitive. - + Parameters ---------- model: ModelInterface @@ -326,18 +325,18 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - + kwargs: Unused. Returns ------- - + scores_batch: List of integers. """ - + scores_batch = [] for x, y, a in zip(x_batch, y_batch, a_batch): @@ -373,6 +372,8 @@ def evaluate_batch( vars.append(np.var(preds)) non_features_vars = set(list(np.argwhere(vars).flatten() < self.eps)) - scores_batch.append(len(non_features_vars.symmetric_difference(non_features))) + scores_batch.append( + len(non_features_vars.symmetric_difference(non_features)) + ) return scores_batch diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index adf8e0f72..2148824ef 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -803,7 +803,7 @@ def explain_batch( y_batch: np.ndarray, ) -> np.ndarray: """ - + Compute explanations, normalize and take absolute (if was configured so during metric initialization.) This method should primarily be used if you need to generate additional explanation in metrics body. It encapsulates typical for Quantus pre- and postprocessing approach. diff --git a/quantus/metrics/base_batched.py b/quantus/metrics/base_batched.py index a1f2d1885..7ce4fda67 100644 --- a/quantus/metrics/base_batched.py +++ b/quantus/metrics/base_batched.py @@ -18,17 +18,19 @@ class BatchedMetric(Metric, abc.ABC): """Alias to quantus.Metric, will be removed in next major release.""" - def __subclasscheck__(self, subclass): + def __new__(cls, *args, **kwargs): warnings.warn( "BatchedMetric was deprecated, since it is just an alias to Metric. Please subclass Metric directly." ) + super().__new__(*args, **kwargs) class BatchedPerturbationMetric(PerturbationMetric, abc.ABC): """Alias to quantus.PerturbationMetric, will be removed in next major release.""" - def __subclasscheck__(self, subclass): + def __new__(cls, *args, **kwargs): warnings.warn( "BatchedPerturbationMetric was deprecated, " "since it is just an alias to Metric. Please subclass PerturbationMetric directly." ) + super().__new__(*args, **kwargs) diff --git a/quantus/metrics/complexity/complexity.py b/quantus/metrics/complexity/complexity.py index da6430e60..369b6d1b4 100644 --- a/quantus/metrics/complexity/complexity.py +++ b/quantus/metrics/complexity/complexity.py @@ -225,11 +225,13 @@ def __call__( **kwargs, ) - def evaluate_batch(self, *, x_batch: np.ndarray, a_batch: np.ndarray, **_) -> List[float]: + def evaluate_batch( + self, *, x_batch: np.ndarray, a_batch: np.ndarray, **_ + ) -> List[float]: """ - + TODO: what does it compute? - + Parameters ---------- x_batch: np.ndarray @@ -241,7 +243,7 @@ def evaluate_batch(self, *, x_batch: np.ndarray, a_batch: np.ndarray, **_) -> Li Returns ------- - + scores_batch: List of floats. diff --git a/quantus/metrics/complexity/effective_complexity.py b/quantus/metrics/complexity/effective_complexity.py index f4f739762..810319f45 100644 --- a/quantus/metrics/complexity/effective_complexity.py +++ b/quantus/metrics/complexity/effective_complexity.py @@ -230,20 +230,20 @@ def __call__( def evaluate_batch(self, *, a_batch: np.ndarray, **_) -> List[int]: """ - + Count how many attributions exceed the threshold `eps` - + Parameters ---------- a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - + _: Unused Returns ------- - + scores_batch: List of integers. diff --git a/quantus/metrics/complexity/sparseness.py b/quantus/metrics/complexity/sparseness.py index 35f6d9a76..cb86b2090 100644 --- a/quantus/metrics/complexity/sparseness.py +++ b/quantus/metrics/complexity/sparseness.py @@ -233,11 +233,10 @@ def __call__( def evaluate_batch( self, *, x_batch: np.ndarray, a_batch: np.ndarray, **_ ) -> List[float]: - """ - + TODO: what does it compute? - + Parameters ---------- x_batch: np.ndarray @@ -249,12 +248,12 @@ def evaluate_batch( Returns ------- - + scores_batch: List of floats. """ - + scores_batch = [] # TODO: vectorize diff --git a/quantus/metrics/faithfulness/faithfulness_correlation.py b/quantus/metrics/faithfulness/faithfulness_correlation.py index 325f540c9..d4687e285 100644 --- a/quantus/metrics/faithfulness/faithfulness_correlation.py +++ b/quantus/metrics/faithfulness/faithfulness_correlation.py @@ -324,9 +324,9 @@ def evaluate_batch( **_, ) -> List[float]: """ - + TODO: what does it compute? - + Parameters ---------- model: ModelInterface @@ -342,7 +342,7 @@ def evaluate_batch( Returns ------- - + scores_batch: List of floats. diff --git a/quantus/metrics/faithfulness/faithfulness_estimate.py b/quantus/metrics/faithfulness/faithfulness_estimate.py index 5014c59d5..aa4387e89 100644 --- a/quantus/metrics/faithfulness/faithfulness_estimate.py +++ b/quantus/metrics/faithfulness/faithfulness_estimate.py @@ -310,9 +310,9 @@ def evaluate_batch( **_, ): """ - + TODO: what does it compute? - + Parameters ---------- model: ModelInterface @@ -328,7 +328,7 @@ def evaluate_batch( Returns ------- - + scores_batch: List of floats. diff --git a/quantus/metrics/faithfulness/infidelity.py b/quantus/metrics/faithfulness/infidelity.py index e7fe061b9..114ed2699 100644 --- a/quantus/metrics/faithfulness/infidelity.py +++ b/quantus/metrics/faithfulness/infidelity.py @@ -281,7 +281,7 @@ def __call__( device=device, **kwargs, ) - + def evaluate_instance( self, model: ModelInterface, @@ -316,11 +316,9 @@ def evaluate_instance( results = [] for _ in range(self.n_perturb_samples): - sub_results = [] for patch_size in self.perturb_patch_sizes: - pred_deltas = np.zeros( (int(a.shape[1] / patch_size), int(a.shape[2] / patch_size)) ) @@ -331,9 +329,7 @@ def evaluate_instance( pad_width = patch_size - 1 for i_x, top_left_x in enumerate(range(0, x.shape[1], patch_size)): - for i_y, top_left_y in enumerate(range(0, x.shape[2], patch_size)): - # Perturb input patch-wise. x_perturbed_pad = utils._pad_array( x_perturbed, pad_width, mode="edge", padded_axes=self.a_axes @@ -427,4 +423,4 @@ def evaluate_batch( return [ self.evaluate_instance(model, x, y, a) for x, y, a in zip(x_batch, y_batch, a_batch) - ] \ No newline at end of file + ] diff --git a/quantus/metrics/faithfulness/irof.py b/quantus/metrics/faithfulness/irof.py index f36b504de..88ae5731e 100644 --- a/quantus/metrics/faithfulness/irof.py +++ b/quantus/metrics/faithfulness/irof.py @@ -261,7 +261,7 @@ def __call__( model_predict_kwargs=model_predict_kwargs, **kwargs, ) - + def evaluate_instance( self, model: ModelInterface, @@ -311,7 +311,6 @@ def evaluate_instance( x_prev_perturbed = x for i_ix, s_ix in enumerate(s_indices): - # Perturb input by indices of attributions. a_ix = np.nonzero((segments == s_ix).flatten())[0] From 61baaa0cca08246fb5ed1755df2d5f91fc16c5a1 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Sun, 13 Aug 2023 16:45:59 +0200 Subject: [PATCH 08/58] WIP --- quantus/metrics/base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index 2148824ef..be079654e 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -779,6 +779,11 @@ def all_results(self): return self.all_evaluation_scores def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: + """ + Does computationally heavy pre-processing on batch level to avoid OOM. + By default, will only generate explanations if missing, in case metric requires custom + data, it must be overriden in respective metric. + """ x_batch = data_batch["x_batch"] a_batch = data_batch.get("a_batch") From 121f0ef8abfa5eb7a27ed9f9cde370cdfa39eae2 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Wed, 30 Aug 2023 16:04:50 +0200 Subject: [PATCH 09/58] minify git diff --- quantus/metrics/axiomatic/completeness.py | 88 ++++++++-------- quantus/metrics/axiomatic/non_sensitivity.py | 103 ++++++++++++------- 2 files changed, 109 insertions(+), 82 deletions(-) diff --git a/quantus/metrics/axiomatic/completeness.py b/quantus/metrics/axiomatic/completeness.py index 21faf394e..d57ac2b4f 100644 --- a/quantus/metrics/axiomatic/completeness.py +++ b/quantus/metrics/axiomatic/completeness.py @@ -261,60 +261,64 @@ def __call__( **kwargs, ) - def evaluate_batch( + def evaluate_instance( self, - *, model: ModelInterface, - x_batch: np.ndarray, - y_batch: np.ndarray, - a_batch: np.ndarray, - **_, - ) -> List[bool]: + x: np.ndarray, + y: np.ndarray, + a: np.ndarray, + ) -> bool: """ - - Checks if sum of attributions is equal to the difference between original prediction and - prediction on baseline value. - + Evaluate instance gets model and data for a single instance as input and returns the evaluation result. Parameters ---------- model: ModelInterface - A ModelInterface that is subject to explanation. - x_batch: np.ndarray - The input to be evaluated on a batch-basis. - y_batch: np.ndarray - The output to be evaluated on a batch-basis. - a_batch: np.ndarray - The explanation to be evaluated on a batch-basis. - _: - Unused kwargs. + A ModelInteface that is subject to explanation. + x: np.ndarray + The input to be evaluated on an instance-basis. + y: np.ndarray + The output to be evaluated on an instance-basis. + a: np.ndarray + The explanation to be evaluated on an instance-basis. Returns ------- - - scores_batch: - List of booleans. - - + : boolean + The evaluation results. """ + x_baseline = self.perturb_func( + arr=x, + indices=np.arange(0, x.size), + indexed_axes=np.arange(0, x.ndim), + **self.perturb_func_kwargs, + ) - # TODO: vectorize - scores_batch = [] - for x, y, a in zip(x_batch, y_batch, a_batch): - x_baseline = self.perturb_func( - arr=x, - indices=np.arange(0, x.size), - indexed_axes=np.arange(0, x.ndim), - **self.perturb_func_kwargs, - ) + # Predict on input. + x_input = model.shape_input(x, x.shape, channel_first=True) + y_pred = float(model.predict(x_input)[:, y]) - # Predict on input. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) + # Predict on baseline. + x_input = model.shape_input(x_baseline, x.shape, channel_first=True) + y_pred_baseline = float(model.predict(x_input)[:, y]) - # Predict on baseline. - x_input = model.shape_input(x_baseline, x.shape, channel_first=True) - y_pred_baseline = float(model.predict(x_input)[:, y]) + if np.sum(a) == self.output_func(y_pred - y_pred_baseline): + return True + else: + return False - scores_batch.append(np.sum(a) == self.output_func(y_pred - y_pred_baseline)) - return scores_batch + def evaluate_batch( + self, + *, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: np.ndarray, + s_batch: np.ndarray, + **kwargs, + ): + # TODO. For performance gains, replace the for loop below with vectorisation. + return [ + self.evaluate_instance(model, x, y, a) + for x, y, a in zip(x_batch, y_batch, a_batch) + ] diff --git a/quantus/metrics/axiomatic/non_sensitivity.py b/quantus/metrics/axiomatic/non_sensitivity.py index 14d77c3b0..e96f2bfcf 100644 --- a/quantus/metrics/axiomatic/non_sensitivity.py +++ b/quantus/metrics/axiomatic/non_sensitivity.py @@ -265,6 +265,65 @@ def __call__( **kwargs, ) + def evaluate_instance( + self, + model: ModelInterface, + x: np.ndarray, + y: np.ndarray, + a: np.ndarray, + ) -> int: + """ + Evaluate instance gets model and data for a single instance as input and returns the evaluation result. + + Parameters + ---------- + model: ModelInterface + A ModelInteface that is subject to explanation. + x: np.ndarray + The input to be evaluated on an instance-basis. + y: np.ndarray + The output to be evaluated on an instance-basis. + a: np.ndarray + The explanation to be evaluated on an instance-basis. + s: np.ndarray + The segmentation to be evaluated on an instance-basis. + + Returns + ------- + integer + The evaluation results. + """ + a = a.flatten() + + non_features = set(list(np.argwhere(a).flatten() < self.eps)) + + vars = [] + for i_ix, a_ix in enumerate(a[:: self.features_in_step]): + preds = [] + a_ix = a[ + (self.features_in_step * i_ix) : (self.features_in_step * (i_ix + 1)) + ].astype(int) + + for _ in range(self.n_samples): + # Perturb input by indices of attributions. + x_perturbed = self.perturb_func( + arr=x, + indices=a_ix, + indexed_axes=self.a_axes, + **self.perturb_func_kwargs, + ) + + # Predict on perturbed input x. + x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) + y_pred_perturbed = float(model.predict(x_input)[:, y]) + + preds.append(y_pred_perturbed) + vars.append(np.var(preds)) + + non_features_vars = set(list(np.argwhere(vars).flatten() < self.eps)) + + return len(non_features_vars.symmetric_difference(non_features)) + def custom_preprocess( self, model: ModelInterface, @@ -337,43 +396,7 @@ def evaluate_batch( """ - scores_batch = [] - - for x, y, a in zip(x_batch, y_batch, a_batch): - a = a.flatten() - - non_features = set(list(np.argwhere(a).flatten() < self.eps)) - - vars = [] - for i_ix, a_ix in enumerate(a[:: self.features_in_step]): - preds = [] - a_ix = a[ - (self.features_in_step * i_ix) : ( - self.features_in_step * (i_ix + 1) - ) - ].astype(int) - - for _ in range(self.n_samples): - # Perturb input by indices of attributions. - x_perturbed = self.perturb_func( - arr=x, - indices=a_ix, - indexed_axes=self.a_axes, - **self.perturb_func_kwargs, - ) - - # Predict on perturbed input x. - x_input = model.shape_input( - x_perturbed, x.shape, channel_first=True - ) - y_pred_perturbed = float(model.predict(x_input)[:, y]) - - preds.append(y_pred_perturbed) - vars.append(np.var(preds)) - - non_features_vars = set(list(np.argwhere(vars).flatten() < self.eps)) - scores_batch.append( - len(non_features_vars.symmetric_difference(non_features)) - ) - - return scores_batch + return [ + self.evaluate_instance(model, x, y, a) + for x, y, a in zip(x_batch, y_batch, a_batch) + ] From 394d4179c503b7de5da2c76b01939207ecc4a1ea Mon Sep 17 00:00:00 2001 From: aaarrti Date: Wed, 30 Aug 2023 16:10:55 +0200 Subject: [PATCH 10/58] minify git diff --- quantus/metrics/axiomatic/completeness.py | 32 ++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/quantus/metrics/axiomatic/completeness.py b/quantus/metrics/axiomatic/completeness.py index d57ac2b4f..b7814d13b 100644 --- a/quantus/metrics/axiomatic/completeness.py +++ b/quantus/metrics/axiomatic/completeness.py @@ -314,9 +314,35 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - s_batch: np.ndarray, - **kwargs, - ): + **_, + ) -> List[bool]: + """ + + Checks if sum of attributions is equal to the difference between original prediction and + prediction on baseline value. + + + Parameters + ---------- + model: ModelInterface + A ModelInterface that is subject to explanation. + x_batch: np.ndarray + The input to be evaluated on a batch-basis. + y_batch: np.ndarray + The output to be evaluated on a batch-basis. + a_batch: np.ndarray + The explanation to be evaluated on a batch-basis. + _: + Unused kwargs. + + Returns + ------- + + scores_batch: + List of booleans. + + + """ # TODO. For performance gains, replace the for loop below with vectorisation. return [ self.evaluate_instance(model, x, y, a) From 41b8051ee826c27827045f7f162438e7da4080b8 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Wed, 30 Aug 2023 16:19:59 +0200 Subject: [PATCH 11/58] minify git diff --- quantus/metrics/complexity/complexity.py | 45 ++++++++++----- .../complexity/effective_complexity.py | 30 +++++++--- quantus/metrics/complexity/sparseness.py | 55 ++++++++++++------- 3 files changed, 91 insertions(+), 39 deletions(-) diff --git a/quantus/metrics/complexity/complexity.py b/quantus/metrics/complexity/complexity.py index 369b6d1b4..a43370d11 100644 --- a/quantus/metrics/complexity/complexity.py +++ b/quantus/metrics/complexity/complexity.py @@ -225,12 +225,41 @@ def __call__( **kwargs, ) + def evaluate_instance( + self, + x: np.ndarray, + a: np.ndarray, + ) -> float: + """ + Evaluate instance gets model and data for a single instance as input and returns the evaluation result. + + Parameters + ---------- + x: np.ndarray + The input to be evaluated on an instance-basis. + a: np.ndarray + The explanation to be evaluated on an instance-basis. + + Returns + ------- + float + The evaluation results. + """ + + if len(x.shape) == 1: + newshape = np.prod(x.shape) + else: + newshape = np.prod(x.shape[1:]) + + a = np.array(np.reshape(a, newshape), dtype=np.float64) / np.sum(np.abs(a)) + return scipy.stats.entropy(pk=a) + def evaluate_batch( self, *, x_batch: np.ndarray, a_batch: np.ndarray, **_ ) -> List[float]: """ - TODO: what does it compute? + TODO: write meaningful docstring about what does it compute. Parameters ---------- @@ -248,15 +277,5 @@ def evaluate_batch( List of floats. """ - # TODO: vectorize - scores_batch = [] - for x, a in zip(x_batch, a_batch): - if len(x.shape) == 1: - newshape = np.prod(x.shape) - else: - newshape = np.prod(x.shape[1:]) - - a = np.array(np.reshape(a, newshape), dtype=np.float64) / np.sum(np.abs(a)) - scores_batch.append(scipy.stats.entropy(pk=a)) - - return scores_batch + # TODO. For performance gains, replace the for loop below with vectorisation. + return [self.evaluate_instance(x, a) for x, a in zip(x_batch, a_batch)] diff --git a/quantus/metrics/complexity/effective_complexity.py b/quantus/metrics/complexity/effective_complexity.py index 810319f45..878f35ea1 100644 --- a/quantus/metrics/complexity/effective_complexity.py +++ b/quantus/metrics/complexity/effective_complexity.py @@ -228,6 +228,27 @@ def __call__( **kwargs, ) + def evaluate_instance( + self, + a: np.ndarray, + ) -> int: + """ + Evaluate instance gets model and data for a single instance as input and returns the evaluation result. + + Parameters + ---------- + a: np.ndarray + The explanation to be evaluated on an instance-basis. + + Returns + ------- + integer + The evaluation results. + """ + + a = a.flatten() + return int(np.sum(a > self.eps)) + def evaluate_batch(self, *, a_batch: np.ndarray, **_) -> List[int]: """ @@ -248,11 +269,6 @@ def evaluate_batch(self, *, a_batch: np.ndarray, **_) -> List[int]: List of integers. """ - # TODO: vectorize - scores_batch = [] - - for a in a_batch: - a = a.flatten() - scores_batch.append(int(np.sum(a > self.eps))) + # TODO. For performance gains, replace the for loop below with vectorisation. - return scores_batch + return [self.evaluate_instance(a) for a in a_batch] diff --git a/quantus/metrics/complexity/sparseness.py b/quantus/metrics/complexity/sparseness.py index cb86b2090..b683ad584 100644 --- a/quantus/metrics/complexity/sparseness.py +++ b/quantus/metrics/complexity/sparseness.py @@ -230,12 +230,44 @@ def __call__( **kwargs, ) + def evaluate_instance( + self, + x: np.ndarray, + a: np.ndarray, + ) -> float: + """ + Evaluate instance gets model and data for a single instance as input and returns the evaluation result. + + Parameters + ---------- + x: np.ndarray + The input to be evaluated on an instance-basis. + a: np.ndarray + The explanation to be evaluated on an instance-basis. + + Returns + ------- + float + The evaluation results. + """ + if len(x.shape) == 1: + newshape = np.prod(x.shape) + else: + newshape = np.prod(x.shape[1:]) + + a = np.array(np.reshape(a, newshape), dtype=np.float64) + a += 0.0000001 + a = np.sort(a) + score = (np.sum((2 * np.arange(1, a.shape[0] + 1) - a.shape[0] - 1) * a)) / ( + a.shape[0] * np.sum(a) + ) + return score + def evaluate_batch( self, *, x_batch: np.ndarray, a_batch: np.ndarray, **_ ) -> List[float]: """ - - TODO: what does it compute? + TODO: write meaningful docstring about what does it compute. Parameters ---------- @@ -254,21 +286,6 @@ def evaluate_batch( """ - scores_batch = [] - # TODO: vectorize - - for x, a in zip(x_batch, a_batch): - if len(x.shape) == 1: - newshape = np.prod(x.shape) - else: - newshape = np.prod(x.shape[1:]) - - a = np.array(np.reshape(a, newshape), dtype=np.float64) - a += 0.0000001 - a = np.sort(a) - score = ( - np.sum((2 * np.arange(1, a.shape[0] + 1) - a.shape[0] - 1) * a) - ) / (a.shape[0] * np.sum(a)) - scores_batch.append(score) + # TODO. For performance gains, replace the for loop below with vectorisation. - return scores_batch + return [self.evaluate_instance(x, a) for x, a in zip(x_batch, a_batch)] From 2dff4803c490048a5699aca730f2642b2438a2e6 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Wed, 30 Aug 2023 16:37:59 +0200 Subject: [PATCH 12/58] minify git diff --- quantus/metrics/complexity/complexity.py | 9 +- quantus/metrics/complexity/sparseness.py | 9 +- .../faithfulness/faithfulness_correlation.py | 102 +++++++---- .../faithfulness/faithfulness_estimate.py | 118 ++++++++----- quantus/metrics/faithfulness/infidelity.py | 24 +++ quantus/metrics/faithfulness/irof.py | 23 +++ quantus/metrics/faithfulness/monotonicity.py | 131 ++++++++++----- .../faithfulness/monotonicity_correlation.py | 159 ++++++++++++------ .../metrics/faithfulness/pixel_flipping.py | 123 +++++++++----- .../faithfulness/region_perturbation.py | 19 +++ quantus/metrics/faithfulness/road.py | 107 +++++++++--- quantus/metrics/faithfulness/selectivity.py | 19 +++ quantus/metrics/faithfulness/sensitivity_n.py | 131 ++++++++++----- quantus/metrics/faithfulness/sufficiency.py | 56 ++++-- 14 files changed, 717 insertions(+), 313 deletions(-) diff --git a/quantus/metrics/complexity/complexity.py b/quantus/metrics/complexity/complexity.py index a43370d11..73cd77ab4 100644 --- a/quantus/metrics/complexity/complexity.py +++ b/quantus/metrics/complexity/complexity.py @@ -224,12 +224,9 @@ def __call__( model_predict_kwargs=model_predict_kwargs, **kwargs, ) - - def evaluate_instance( - self, - x: np.ndarray, - a: np.ndarray, - ) -> float: + + @staticmethod + def evaluate_instance(x: np.ndarray, a: np.ndarray) -> float: """ Evaluate instance gets model and data for a single instance as input and returns the evaluation result. diff --git a/quantus/metrics/complexity/sparseness.py b/quantus/metrics/complexity/sparseness.py index b683ad584..eccd2c858 100644 --- a/quantus/metrics/complexity/sparseness.py +++ b/quantus/metrics/complexity/sparseness.py @@ -229,12 +229,9 @@ def __call__( model_predict_kwargs=model_predict_kwargs, **kwargs, ) - - def evaluate_instance( - self, - x: np.ndarray, - a: np.ndarray, - ) -> float: + + @staticmethod + def evaluate_instance(x: np.ndarray, a: np.ndarray) -> float: """ Evaluate instance gets model and data for a single instance as input and returns the evaluation result. diff --git a/quantus/metrics/faithfulness/faithfulness_correlation.py b/quantus/metrics/faithfulness/faithfulness_correlation.py index d4687e285..1d9c43f6c 100644 --- a/quantus/metrics/faithfulness/faithfulness_correlation.py +++ b/quantus/metrics/faithfulness/faithfulness_correlation.py @@ -276,6 +276,66 @@ def __call__( **kwargs, ) + def evaluate_instance( + self, + model: ModelInterface, + x: np.ndarray, + y: np.ndarray, + a: np.ndarray, + ) -> float: + """ + Evaluate instance gets model and data for a single instance as input and returns the evaluation result. + + Parameters + ---------- + model: ModelInterface + A ModelInteface that is subject to explanation. + x: np.ndarray + The input to be evaluated on an instance-basis. + y: np.ndarray + The output to be evaluated on an instance-basis. + a: np.ndarray + The explanation to be evaluated on an instance-basis. + + Returns + ------- + float + The evaluation results. + """ + # Flatten the attributions. + a = a.flatten() + + # Predict on input. + x_input = model.shape_input(x, x.shape, channel_first=True) + y_pred = float(model.predict(x_input)[:, y]) + + pred_deltas = [] + att_sums = [] + + # For each test data point, execute a couple of runs. + for i_ix in range(self.nr_runs): + # Randomly mask by subset size. + a_ix = np.random.choice(a.shape[0], self.subset_size, replace=False) + x_perturbed = self.perturb_func( + arr=x, + indices=a_ix, + indexed_axes=self.a_axes, + **self.perturb_func_kwargs, + ) + warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) + + # Predict on perturbed input x. + x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) + y_pred_perturb = float(model.predict(x_input)[:, y]) + pred_deltas.append(float(y_pred - y_pred_perturb)) + + # Sum attributions of the random subset. + att_sums.append(np.sum(a[a_ix])) + + similarity = self.similarity_func(a=att_sums, b=pred_deltas) + + return similarity + def custom_preprocess( self, model: ModelInterface, @@ -325,7 +385,7 @@ def evaluate_batch( ) -> List[float]: """ - TODO: what does it compute? + TODO: write meaningful docstring about what does it compute. Parameters ---------- @@ -347,39 +407,7 @@ def evaluate_batch( List of floats. """ - scores_batch = [] - for x, y, a in zip(x_batch, y_batch, a_batch): - a = a.flatten() - - # Predict on input. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) - - pred_deltas = [] - att_sums = [] - - # For each test data point, execute a couple of runs. - for i_ix in range(self.nr_runs): - # Randomly mask by subset size. - a_ix = np.random.choice(a.shape[0], self.subset_size, replace=False) - x_perturbed = self.perturb_func( - arr=x, - indices=a_ix, - indexed_axes=self.a_axes, - **self.perturb_func_kwargs, - ) - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Predict on perturbed input x. - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - pred_deltas.append(float(y_pred - y_pred_perturb)) - - # Sum attributions of the random subset. - att_sums.append(np.sum(a[a_ix])) - - similarity = self.similarity_func(a=att_sums, b=pred_deltas) - - scores_batch.append(similarity) - - return scores_batch + return [ + self.evaluate_instance(model, x, y, a) + for x, y, a in zip(x_batch, y_batch, a_batch) + ] diff --git a/quantus/metrics/faithfulness/faithfulness_estimate.py b/quantus/metrics/faithfulness/faithfulness_estimate.py index aa4387e89..1a6b04d19 100644 --- a/quantus/metrics/faithfulness/faithfulness_estimate.py +++ b/quantus/metrics/faithfulness/faithfulness_estimate.py @@ -261,6 +261,71 @@ def __call__( **kwargs, ) + def evaluate_instance( + self, + model: ModelInterface, + x: np.ndarray, + y: np.ndarray, + a: np.ndarray, + ) -> float: + """ + Evaluate instance gets model and data for a single instance as input and returns the evaluation result. + + Parameters + ---------- + model: ModelInterface + A ModelInteface that is subject to explanation. + x: np.ndarray + The input to be evaluated on an instance-basis. + y: np.ndarray + The output to be evaluated on an instance-basis. + a: np.ndarray + The explanation to be evaluated on an instance-basis. + + Returns + ------- + float + The evaluation results. + """ + + # Flatten the attributions. + a = a.flatten() + + # Get indices of sorted attributions (descending). + a_indices = np.argsort(-a) + + # Predict on input. + x_input = model.shape_input(x, x.shape, channel_first=True) + y_pred = float(model.predict(x_input)[:, y]) + + n_perturbations = len(range(0, len(a_indices), self.features_in_step)) + pred_deltas = [None for _ in range(n_perturbations)] + att_sums = [None for _ in range(n_perturbations)] + + for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): + # Perturb input by indices of attributions. + a_ix = a_indices[ + (self.features_in_step * i_ix) : (self.features_in_step * (i_ix + 1)) + ] + x_perturbed = self.perturb_func( + arr=x, + indices=a_ix, + indexed_axes=self.a_axes, + **self.perturb_func_kwargs, + ) + warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) + + # Predict on perturbed input x. + x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) + y_pred_perturb = float(model.predict(x_input)[:, y]) + pred_deltas[i_ix] = float(y_pred - y_pred_perturb) + + # Sum attributions. + att_sums[i_ix] = np.sum(a[a_ix]) + + similarity = self.similarity_func(a=att_sums, b=pred_deltas) + return similarity + def custom_preprocess( self, model: ModelInterface, @@ -308,10 +373,10 @@ def evaluate_batch( y_batch: np.ndarray, a_batch: np.ndarray, **_, - ): + ) -> List[float]: """ - TODO: what does it compute? + TODO: write meaningful docstring about what does it compute. Parameters ---------- @@ -333,47 +398,8 @@ def evaluate_batch( List of floats. """ - scores_batch = [] - - for x, y, a in zip(x_batch, y_batch, a_batch): - # Flatten the attributions. - a = a.flatten() - - # Get indices of sorted attributions (descending). - a_indices = np.argsort(-a) - - # Predict on input. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) - - n_perturbations = len(range(0, len(a_indices), self.features_in_step)) - pred_deltas = [None for _ in range(n_perturbations)] - att_sums = [None for _ in range(n_perturbations)] - - for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): - # Perturb input by indices of attributions. - a_ix = a_indices[ - (self.features_in_step * i_ix) : ( - self.features_in_step * (i_ix + 1) - ) - ] - x_perturbed = self.perturb_func( - arr=x, - indices=a_ix, - indexed_axes=self.a_axes, - **self.perturb_func_kwargs, - ) - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Predict on perturbed input x. - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - pred_deltas[i_ix] = float(y_pred - y_pred_perturb) - - # Sum attributions. - att_sums[i_ix] = np.sum(a[a_ix]) - - similarity = self.similarity_func(a=att_sums, b=pred_deltas) - scores_batch.append(similarity) - - return scores_batch + + return [ + self.evaluate_instance(model, x, y, a) + for x, y, a in zip(x_batch, y_batch, a_batch) + ] diff --git a/quantus/metrics/faithfulness/infidelity.py b/quantus/metrics/faithfulness/infidelity.py index 114ed2699..b801544ee 100644 --- a/quantus/metrics/faithfulness/infidelity.py +++ b/quantus/metrics/faithfulness/infidelity.py @@ -420,6 +420,30 @@ def evaluate_batch( a_batch: np.ndarray, **_, ) -> List[float]: + """ + TODO: write meaningful docstring about what does it compute. + + Parameters + ---------- + model: ModelInterface + A ModelInterface that is subject to explanation. + x_batch: np.ndarray + The input to be evaluated on a batch-basis. + y_batch: np.ndarray + The output to be evaluated on a batch-basis. + a_batch: np.ndarray + The explanation to be evaluated on a batch-basis. + _: + Unused. + + Returns + ------- + + scores_batch: + List of floats. + + """ + return [ self.evaluate_instance(model, x, y, a) for x, y, a in zip(x_batch, y_batch, a_batch) diff --git a/quantus/metrics/faithfulness/irof.py b/quantus/metrics/faithfulness/irof.py index 88ae5731e..4a8d7f843 100644 --- a/quantus/metrics/faithfulness/irof.py +++ b/quantus/metrics/faithfulness/irof.py @@ -384,6 +384,29 @@ def evaluate_batch( a_batch: np.ndarray, **_, ) -> List[float]: + """ + TODO: write meaningful docstring about what does it compute. + + Parameters + ---------- + model: ModelInterface + A ModelInterface that is subject to explanation. + x_batch: np.ndarray + The input to be evaluated on a batch-basis. + y_batch: np.ndarray + The output to be evaluated on a batch-basis. + a_batch: np.ndarray + The explanation to be evaluated on a batch-basis. + _: + Unused. + + Returns + ------- + + scores_batch: + List of floats. + + """ return [ self.evaluate_instance(model, x, y, a) for x, y, a in zip(x_batch, y_batch, a_batch) diff --git a/quantus/metrics/faithfulness/monotonicity.py b/quantus/metrics/faithfulness/monotonicity.py index 7acfd318a..49937271e 100644 --- a/quantus/metrics/faithfulness/monotonicity.py +++ b/quantus/metrics/faithfulness/monotonicity.py @@ -258,6 +258,68 @@ def __call__( **kwargs, ) + def evaluate_instance( + self, + model: ModelInterface, + x: np.ndarray, + y: np.ndarray, + a: np.ndarray, + ) -> float: + """ + Evaluate instance gets model and data for a single instance as input and returns the evaluation result. + + Parameters + ---------- + model: ModelInterface + A ModelInteface that is subject to explanation. + x: np.ndarray + The input to be evaluated on an instance-basis. + y: np.ndarray + The output to be evaluated on an instance-basis. + a: np.ndarray + The explanation to be evaluated on an instance-basis. + + Returns + ------- + float + The evaluation results. + """ + # Prepare shapes. + a = a.flatten() + + # Get indices of sorted attributions (ascending). + a_indices = np.argsort(a) + + n_perturbations = len(range(0, len(a_indices), self.features_in_step)) + preds = [None for _ in range(n_perturbations)] + + # Copy the input x but fill with baseline values. + baseline_value = utils.get_baseline_value( + value=self.perturb_func_kwargs["perturb_baseline"], + arr=x, + return_shape=x.shape, # TODO. Double-check this over using = (1,). + ) + x_baseline = np.full(x.shape, baseline_value) + + for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): + # Perturb input by indices of attributions. + a_ix = a_indices[ + (self.features_in_step * i_ix) : (self.features_in_step * (i_ix + 1)) + ] + x_baseline = self.perturb_func( + arr=x_baseline, + indices=a_ix, + indexed_axes=self.a_axes, + **self.perturb_func_kwargs, + ) + + # Predict on perturbed input x (that was initially filled with a constant 'perturb_baseline' value). + x_input = model.shape_input(x_baseline, x.shape, channel_first=True) + y_pred_perturb = float(model.predict(x_input)[:, y]) + preds[i_ix] = y_pred_perturb + + return np.all(np.diff(preds) >= 0) + def custom_preprocess( self, model: ModelInterface, @@ -304,44 +366,31 @@ def evaluate_batch( a_batch: np.ndarray, **_, ) -> List[bool]: - retval = [] - for x, y, a in zip(x_batch, y_batch, a_batch): - # Prepare shapes. - a = a.flatten() - - # Get indices of sorted attributions (ascending). - a_indices = np.argsort(a) - - n_perturbations = len(range(0, len(a_indices), self.features_in_step)) - preds = [None for _ in range(n_perturbations)] - - # Copy the input x but fill with baseline values. - baseline_value = utils.get_baseline_value( - value=self.perturb_func_kwargs["perturb_baseline"], - arr=x, - return_shape=x.shape, # TODO. Double-check this over using = (1,). - ) - x_baseline = np.full(x.shape, baseline_value) - - for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): - # Perturb input by indices of attributions. - a_ix = a_indices[ - (self.features_in_step * i_ix) : ( - self.features_in_step * (i_ix + 1) - ) - ] - x_baseline = self.perturb_func( - arr=x_baseline, - indices=a_ix, - indexed_axes=self.a_axes, - **self.perturb_func_kwargs, - ) - - # Predict on perturbed input x (that was initially filled with a constant 'perturb_baseline' value). - x_input = model.shape_input(x_baseline, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - preds[i_ix] = y_pred_perturb - - retval.append(np.all(np.diff(preds) >= 0)) - - return retval + """ + TODO: write meaningful docstring about what does it compute. + + Parameters + ---------- + model: ModelInterface + A ModelInterface that is subject to explanation. + x_batch: np.ndarray + The input to be evaluated on a batch-basis. + y_batch: np.ndarray + The output to be evaluated on a batch-basis. + a_batch: np.ndarray + The explanation to be evaluated on a batch-basis. + _: + Unused. + + Returns + ------- + + scores_batch: + List of floats. + + """ + + return [ + self.evaluate_instance(model, x, y, a) + for x, y, a, s in zip(x_batch, y_batch, a_batch) + ] diff --git a/quantus/metrics/faithfulness/monotonicity_correlation.py b/quantus/metrics/faithfulness/monotonicity_correlation.py index f52896055..73049d9f2 100644 --- a/quantus/metrics/faithfulness/monotonicity_correlation.py +++ b/quantus/metrics/faithfulness/monotonicity_correlation.py @@ -275,6 +275,78 @@ def __call__( **kwargs, ) + def evaluate_instance( + self, + model: ModelInterface, + x: np.ndarray, + y: np.ndarray, + a: np.ndarray, + ) -> float: + """ + Evaluate instance gets model and data for a single instance as input and returns the evaluation result. + + Parameters + ---------- + model: ModelInterface + A ModelInteface that is subject to explanation. + x: np.ndarray + The input to be evaluated on an instance-basis. + y: np.ndarray + The output to be evaluated on an instance-basis. + a: np.ndarray + The explanation to be evaluated on an instance-basis. + + Returns + ------- + float + The evaluation results. + """ + # Predict on input x. + x_input = model.shape_input(x, x.shape, channel_first=True) + y_pred = float(model.predict(x_input)[:, y]) + + inv_pred = 1.0 if np.abs(y_pred) < self.eps else 1.0 / np.abs(y_pred) + inv_pred = inv_pred**2 + + # Reshape attributions. + a = a.flatten() + + # Get indices of sorted attributions (ascending). + a_indices = np.argsort(a) + + n_perturbations = len(range(0, len(a_indices), self.features_in_step)) + atts = [None for _ in range(n_perturbations)] + vars = [None for _ in range(n_perturbations)] + + for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): + # Perturb input by indices of attributions. + a_ix = a_indices[ + (self.features_in_step * i_ix) : (self.features_in_step * (i_ix + 1)) + ] + + y_pred_perturbs = [] + + for s_ix in range(self.nr_samples): + x_perturbed = self.perturb_func( + arr=x, + indices=a_ix, + indexed_axes=self.a_axes, + **self.perturb_func_kwargs, + ) + warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) + + # Predict on perturbed input x. + x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) + y_pred_perturb = float(model.predict(x_input)[:, y]) + y_pred_perturbs.append(y_pred_perturb) + + vars[i_ix] = float( + np.mean((np.array(y_pred_perturbs) - np.array(y_pred)) ** 2) * inv_pred + ) + atts[i_ix] = float(sum(a[a_ix])) + + return self.similarity_func(a=atts, b=vars) + def custom_preprocess( self, model: ModelInterface, @@ -321,59 +393,38 @@ def evaluate_batch( a_batch: np.ndarray, **_, ) -> List[float]: - retval = [] - for x, y, a in zip(x_batch, y_batch, a_batch): - # Predict on input x. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) - - inv_pred = 1.0 if np.abs(y_pred) < self.eps else 1.0 / np.abs(y_pred) - inv_pred = inv_pred**2 - - # Reshape attributions. - a = a.flatten() - - # Get indices of sorted attributions (ascending). - a_indices = np.argsort(a) - - n_perturbations = len(range(0, len(a_indices), self.features_in_step)) - atts = [None for _ in range(n_perturbations)] - vars = [None for _ in range(n_perturbations)] - - for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): - # Perturb input by indices of attributions. - a_ix = a_indices[ - (self.features_in_step * i_ix) : ( - self.features_in_step * (i_ix + 1) - ) - ] - - y_pred_perturbs = [] - - for s_ix in range(self.nr_samples): - x_perturbed = self.perturb_func( - arr=x, - indices=a_ix, - indexed_axes=self.a_axes, - **self.perturb_func_kwargs, - ) - warn.warn_perturbation_caused_no_change( - x=x, x_perturbed=x_perturbed - ) - - # Predict on perturbed input x. - x_input = model.shape_input( - x_perturbed, x.shape, channel_first=True - ) - y_pred_perturb = float(model.predict(x_input)[:, y]) - y_pred_perturbs.append(y_pred_perturb) - - vars[i_ix] = float( - np.mean((np.array(y_pred_perturbs) - np.array(y_pred)) ** 2) - * inv_pred - ) - atts[i_ix] = float(sum(a[a_ix])) + """ + TODO: write meaningful docstring about what does it compute. - retval.append(self.similarity_func(a=atts, b=vars)) + Parameters + ---------- + model: ModelInterface + A ModelInterface that is subject to explanation. + x_batch: np.ndarray + The input to be evaluated on a batch-basis. + y_batch: np.ndarray + The output to be evaluated on a batch-basis. + a_batch: np.ndarray + The explanation to be evaluated on a batch-basis. - return retval + Returns + ------- + List[float] + The evaluation results. + """ + # Asserts. + asserts.assert_features_in_step( + features_in_step=self.features_in_step, + input_shape=x_batch.shape[2:], + ) + + # Evaluate explanations. + return [ + self.evaluate_instance( + model=model, + x=x, + y=y, + a=a, + ) + for x, y, a in zip(x_batch, y_batch, a_batch) + ] diff --git a/quantus/metrics/faithfulness/pixel_flipping.py b/quantus/metrics/faithfulness/pixel_flipping.py index 4540591b9..8a962fb97 100644 --- a/quantus/metrics/faithfulness/pixel_flipping.py +++ b/quantus/metrics/faithfulness/pixel_flipping.py @@ -260,6 +260,67 @@ def __call__( **kwargs, ) + def evaluate_instance( + self, + model: ModelInterface, + x: np.ndarray, + y: np.ndarray, + a: np.ndarray, + ) -> List[float]: + """ + Evaluate instance gets model and data for a single instance as input and returns the evaluation result. + + Parameters + ---------- + model: ModelInterface + A ModelInteface that is subject to explanation. + x: np.ndarray + The input to be evaluated on an instance-basis. + y: np.ndarray + The output to be evaluated on an instance-basis. + a: np.ndarray + The explanation to be evaluated on an instance-basis. + + Returns + ------- + list + The evaluation results. + """ + + # Reshape attributions. + a = a.flatten() + + # Get indices of sorted attributions (descending). + a_indices = np.argsort(-a) + + # Prepare lists. + n_perturbations = len(range(0, len(a_indices), self.features_in_step)) + preds = [None for _ in range(n_perturbations)] + x_perturbed = x.copy() + + for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): + # Perturb input by indices of attributions. + a_ix = a_indices[ + (self.features_in_step * i_ix) : (self.features_in_step * (i_ix + 1)) + ] + x_perturbed = self.perturb_func( + arr=x_perturbed, + indices=a_ix, + indexed_axes=self.a_axes, + **self.perturb_func_kwargs, + ) + warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) + + # Predict on perturbed input x. + x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) + y_pred_perturb = float(model.predict(x_input)[:, y]) + preds[i_ix] = y_pred_perturb + + if self.return_auc_per_sample: + return utils.calculate_auc(preds) + + return preds + def custom_preprocess( self, model: ModelInterface, @@ -313,42 +374,26 @@ def evaluate_batch( a_batch: np.ndarray, **_, ) -> List[float | np.ndarray]: - retval = [] - for x, y, a in zip(x_batch, y_batch, a_batch): - # Reshape attributions. - a = a.flatten() - - # Get indices of sorted attributions (descending). - a_indices = np.argsort(-a) - - # Prepare lists. - n_perturbations = len(range(0, len(a_indices), self.features_in_step)) - preds = [None for _ in range(n_perturbations)] - x_perturbed = x.copy() - - for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): - # Perturb input by indices of attributions. - a_ix = a_indices[ - (self.features_in_step * i_ix) : ( - self.features_in_step * (i_ix + 1) - ) - ] - x_perturbed = self.perturb_func( - arr=x_perturbed, - indices=a_ix, - indexed_axes=self.a_axes, - **self.perturb_func_kwargs, - ) - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Predict on perturbed input x. - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - preds[i_ix] = y_pred_perturb - - if self.return_auc_per_sample: - retval.append(utils.calculate_auc(preds)) - else: - retval.append(preds) - - return retval + """ + TODO: write meaningful docstring about what does it compute. + + Parameters + ---------- + model: ModelInterface + A ModelInteface that is subject to explanation. + x_batch: np.ndarray + The input to be evaluated on a batch-basis. + y_batch: np.ndarray + The output to be evaluated on a batch-basis. + a_batch: np.ndarray + The explanation to be evaluated on a batch-basis. + + Returns + ------- + list + The evaluation results. + """ + return [ + self.evaluate_instance(model, x, y, a) + for x, y, a in zip(x_batch, y_batch, a_batch) + ] diff --git a/quantus/metrics/faithfulness/region_perturbation.py b/quantus/metrics/faithfulness/region_perturbation.py index 412e8a3ae..4d4d01aeb 100644 --- a/quantus/metrics/faithfulness/region_perturbation.py +++ b/quantus/metrics/faithfulness/region_perturbation.py @@ -430,6 +430,25 @@ def evaluate_batch( a_batch: np.ndarray, **_, ) -> List[List[float]]: + """ + TODO: write meaningful docstring about what does it compute. + + Parameters + ---------- + model: ModelInterface + A ModelInteface that is subject to explanation. + x_batch: np.ndarray + The input to be evaluated on a batch-basis. + y_batch: np.ndarray + The output to be evaluated on a batch-basis. + a_batch: np.ndarray + The explanation to be evaluated on a batch-basis. + + Returns + ------- + list + The evaluation results. + """ return [ self.evaluate_instance(model, x, y, a) for x, y, a in zip(x_batch, y_batch, a_batch) diff --git a/quantus/metrics/faithfulness/road.py b/quantus/metrics/faithfulness/road.py index a5bcec557..4f1986588 100644 --- a/quantus/metrics/faithfulness/road.py +++ b/quantus/metrics/faithfulness/road.py @@ -289,6 +289,65 @@ def evaluate_instance( : list The evaluation results. """ + # Order indices. + ordered_indices = np.argsort(a, axis=None)[::-1] + + results_instance = np.array([None for _ in self.percentages]) + + for p_ix, p in enumerate(self.percentages): + top_k_indices = ordered_indices[: int(self.a_size * p / 100)] + + x_perturbed = self.perturb_func( + arr=x, + indices=top_k_indices, + **self.perturb_func_kwargs, + ) + + warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) + + # Predict on perturbed input x and store the difference from predicting on unperturbed input. + x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) + class_pred_perturb = np.argmax(model.predict(x_input)) + + # Write a boolean into the percentage results. + results_instance[p_ix] = int(y == class_pred_perturb) + + # Return list of booleans for each percentage. + return results_instance + + def custom_preprocess( + self, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: Optional[np.ndarray], + a_batch: Optional[np.ndarray], + s_batch: np.ndarray, + custom_batch: Optional[np.ndarray], + ) -> None: + """ + Implementation of custom_preprocess_batch. + + Parameters + ---------- + model: torch.nn.Module, tf.keras.Model + A torch or tensorflow model e.g., torchvision.models that is subject to explanation. + x_batch: np.ndarray + A np.ndarray which contains the input data that are explained. + y_batch: np.ndarray + A np.ndarray which contains the output labels that are explained. + a_batch: np.ndarray, optional + A np.ndarray which contains pre-computed attributions i.e., explanations. + s_batch: np.ndarray, optional + A np.ndarray which contains segmentation masks that matches the input. + custom_batch: any + Gives flexibility ot the user to use for evaluation, can hold any variable. + + Returns + ------- + None + """ + # Infer the size of attributions. + self.a_size = a_batch[0, :, :].size def custom_postprocess( self, @@ -341,32 +400,26 @@ def evaluate_batch( a_batch: np.ndarray, **kwargs, ): - retval = [] - for x, y, a in zip(x_batch, y_batch, a_batch): - # Order indices. - ordered_indices = np.argsort(a, axis=None)[::-1] - - results_instance = np.array([None for _ in self.percentages]) - - for p_ix, p in enumerate(self.percentages): - top_k_indices = ordered_indices[: int(self.a_size * p / 100)] - - x_perturbed = self.perturb_func( - arr=x, - indices=top_k_indices, - **self.perturb_func_kwargs, - ) - - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Predict on perturbed input x and store the difference from predicting on unperturbed input. - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - class_pred_perturb = np.argmax(model.predict(x_input)) - - # Write a boolean into the percentage results. - results_instance[p_ix] = int(y == class_pred_perturb) + """ + TODO: write meaningful docstring about what does it compute. - # Return list of booleans for each percentage. - retval.append(results_instance) + Parameters + ---------- + model: ModelInterface + A ModelInteface that is subject to explanation. + x_batch: np.ndarray + The input to be evaluated on a batch-basis. + y_batch: np.ndarray + The output to be evaluated on a batch-basis. + a_batch: np.ndarray + The explanation to be evaluated on a batch-basis. - return retval + Returns + ------- + list + The evaluation results. + """ + return [ + self.evaluate_instance(model, x, y, a) + for x, y, a in zip(x_batch, y_batch, a_batch) + ] diff --git a/quantus/metrics/faithfulness/selectivity.py b/quantus/metrics/faithfulness/selectivity.py index d4d4a9ba1..8c4f96853 100644 --- a/quantus/metrics/faithfulness/selectivity.py +++ b/quantus/metrics/faithfulness/selectivity.py @@ -389,6 +389,25 @@ def evaluate_batch( a_batch: np.ndarray, **_, ) -> List[List[float]]: + """ + TODO: write meaningful docstring about what does it compute. + + Parameters + ---------- + model: ModelInterface + A ModelInteface that is subject to explanation. + x_batch: np.ndarray + The input to be evaluated on a batch-basis. + y_batch: np.ndarray + The output to be evaluated on a batch-basis. + a_batch: np.ndarray + The explanation to be evaluated on a batch-basis. + + Returns + ------- + list + The evaluation results. + """ return [ self.evaluate_instance(model, x, y, a) for x, y, a in zip(x_batch, y_batch, a_batch) diff --git a/quantus/metrics/faithfulness/sensitivity_n.py b/quantus/metrics/faithfulness/sensitivity_n.py index df2685405..2d265ed46 100644 --- a/quantus/metrics/faithfulness/sensitivity_n.py +++ b/quantus/metrics/faithfulness/sensitivity_n.py @@ -273,6 +273,70 @@ def __call__( **kwargs, ) + def evaluate_instance( + self, + model: ModelInterface, + x: np.ndarray, + y: np.ndarray, + a: np.ndarray, + ) -> Dict[str, List[float]]: + """ + Evaluate instance gets model and data for a single instance as input and returns the evaluation result. + + Parameters + ---------- + model: ModelInterface + A ModelInteface that is subject to explanation. + x: np.ndarray + The input to be evaluated on an instance-basis. + y: np.ndarray + The output to be evaluated on an instance-basis. + a: np.ndarray + The explanation to be evaluated on an instance-basis. + + Returns + ------- + (Dict[str, List[float]]): The evaluation results. + """ + + # Reshape the attributions. + a = a.flatten() + + # Get indices of sorted attributions (descending). + a_indices = np.argsort(-a) + + # Predict on x. + x_input = model.shape_input(x, x.shape, channel_first=True) + y_pred = float(model.predict(x_input)[:, y]) + + att_sums = [] + pred_deltas = [] + x_perturbed = x.copy() + + for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): + # Perturb input by indices of attributions. + a_ix = a_indices[ + (self.features_in_step * i_ix) : (self.features_in_step * (i_ix + 1)) + ] + x_perturbed = self.perturb_func( + arr=x_perturbed, + indices=a_ix, + indexed_axes=self.a_axes, + **self.perturb_func_kwargs, + ) + warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) + + # Sum attributions. + att_sums.append(float(a[a_ix].sum())) + + x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) + y_pred_perturb = float(model.predict(x_input)[:, y]) + pred_deltas.append(y_pred - y_pred_perturb) + + # Each list-element of self.evaluation_scores will be such a dictionary + # We will unpack that later in custom_postprocess(). + return {"att_sums": att_sums, "pred_deltas": pred_deltas} + def custom_preprocess( self, model: ModelInterface, @@ -378,46 +442,27 @@ def evaluate_batch( a_batch: np.ndarray, **_, ) -> List[Dict[str, List[float]]]: - retval = [] - for x, y, a in zip(x_batch, y_batch, a_batch): - # Reshape the attributions. - a = a.flatten() - - # Get indices of sorted attributions (descending). - a_indices = np.argsort(-a) - - # Predict on x. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) - - att_sums = [] - pred_deltas = [] - x_perturbed = x.copy() - - for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): - # Perturb input by indices of attributions. - a_ix = a_indices[ - (self.features_in_step * i_ix) : ( - self.features_in_step * (i_ix + 1) - ) - ] - x_perturbed = self.perturb_func( - arr=x_perturbed, - indices=a_ix, - indexed_axes=self.a_axes, - **self.perturb_func_kwargs, - ) - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Sum attributions. - att_sums.append(float(a[a_ix].sum())) - - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - pred_deltas.append(y_pred - y_pred_perturb) - - # Each list-element of self.evaluation_scores will be such a dictionary - # We will unpack that later in custom_postprocess(). - retval.append({"att_sums": att_sums, "pred_deltas": pred_deltas}) - - return retval + """ + TODO: write meaningful docstring about what does it compute. + + Parameters + ---------- + model: ModelInterface + A ModelInteface that is subject to explanation. + x_batch: np.ndarray + The input to be evaluated on a batch-basis. + y_batch: np.ndarray + The output to be evaluated on a batch-basis. + a_batch: np.ndarray + The explanation to be evaluated on a batch-basis. + + Returns + ------- + list + The evaluation results. + """ + + return [ + self.evaluate_instance(model, x, y, a) + for x, y, a in zip(x_batch, y_batch, a_batch) + ] diff --git a/quantus/metrics/faithfulness/sufficiency.py b/quantus/metrics/faithfulness/sufficiency.py index b658fdeef..0af6ff069 100644 --- a/quantus/metrics/faithfulness/sufficiency.py +++ b/quantus/metrics/faithfulness/sufficiency.py @@ -244,6 +244,40 @@ def __call__( **kwargs, ) + @staticmethod + def evaluate_instance( + i: int = None, + a_sim_vector: np.ndarray = None, + y_pred_classes: np.ndarray = None, + ) -> float: + """ + Evaluate instance gets model and data for a single instance as input and returns the evaluation result. + + Parameters + ---------- + i: int + The index of the current instance. + a_sim_vector: any + The custom input to be evaluated on an instance-basis. + y_pred_classes: np,ndarray + The class predictions of the complete input dataset. + + Returns + ------- + float + The evaluation results. + """ + + # Metric logic. + pred_a = y_pred_classes[i] + low_dist_a = np.argwhere(a_sim_vector == 1.0).flatten() + low_dist_a = low_dist_a[low_dist_a != i] + pred_low_dist_a = y_pred_classes[low_dist_a] + + if len(low_dist_a) == 0: + return 0 + return np.sum(pred_low_dist_a == pred_a) / len(low_dist_a) + def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: data_batch = super().batch_preprocess(data_batch) model = data_batch["model"] @@ -274,17 +308,11 @@ def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: def evaluate_batch( self, *, i_batch, a_sim_vector_batch, y_pred_classes, **_ ) -> List[float]: - retval = [] - for i, a_sim_vector in zip(i_batch, a_sim_vector_batch): - # Metric logic. - pred_a = y_pred_classes[i] - low_dist_a = np.argwhere(a_sim_vector == 1.0).flatten() - low_dist_a = low_dist_a[low_dist_a != i] - pred_low_dist_a = y_pred_classes[low_dist_a] - - if len(low_dist_a) == 0: - retval.append(0.0) - else: - retval.append((pred_low_dist_a == pred_a) / len(low_dist_a)) - - return retval + """ + TODO: write meaningful docstring about what does it compute. + """ + + return [ + self.evaluate_instance(i, a_sim_vector, y_pred_classes) + for i, a_sim_vector in zip(i_batch, a_sim_vector_batch) + ] From e96889612d62ea03ad2c2e1888e09e5d28f7a869 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Wed, 30 Aug 2023 16:52:39 +0200 Subject: [PATCH 13/58] minify git diff --- quantus/metrics/faithfulness/monotonicity.py | 2 +- quantus/metrics/faithfulness/road.py | 34 ------ .../localisation/attribution_localisation.py | 105 ++++++++++++------ quantus/metrics/localisation/auc.py | 50 ++++++--- quantus/metrics/localisation/focus.py | 84 ++++++++------ quantus/metrics/localisation/pointing_game.py | 65 +++++++---- .../localisation/relevance_mass_accuracy.py | 61 ++++++---- .../localisation/relevance_rank_accuracy.py | 77 ++++++++----- .../localisation/top_k_intersection.py | 72 +++++++----- 9 files changed, 328 insertions(+), 222 deletions(-) diff --git a/quantus/metrics/faithfulness/monotonicity.py b/quantus/metrics/faithfulness/monotonicity.py index 49937271e..7dec97bce 100644 --- a/quantus/metrics/faithfulness/monotonicity.py +++ b/quantus/metrics/faithfulness/monotonicity.py @@ -392,5 +392,5 @@ def evaluate_batch( return [ self.evaluate_instance(model, x, y, a) - for x, y, a, s in zip(x_batch, y_batch, a_batch) + for x, y, a in zip(x_batch, y_batch, a_batch) ] diff --git a/quantus/metrics/faithfulness/road.py b/quantus/metrics/faithfulness/road.py index 4f1986588..c864eb484 100644 --- a/quantus/metrics/faithfulness/road.py +++ b/quantus/metrics/faithfulness/road.py @@ -315,40 +315,6 @@ def evaluate_instance( # Return list of booleans for each percentage. return results_instance - def custom_preprocess( - self, - model: ModelInterface, - x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: np.ndarray, - custom_batch: Optional[np.ndarray], - ) -> None: - """ - Implementation of custom_preprocess_batch. - - Parameters - ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. - x_batch: np.ndarray - A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - custom_batch: any - Gives flexibility ot the user to use for evaluation, can hold any variable. - - Returns - ------- - None - """ - # Infer the size of attributions. - self.a_size = a_batch[0, :, :].size - def custom_postprocess( self, model: ModelInterface, diff --git a/quantus/metrics/localisation/attribution_localisation.py b/quantus/metrics/localisation/attribution_localisation.py index c1d0c8201..edb3de75a 100644 --- a/quantus/metrics/localisation/attribution_localisation.py +++ b/quantus/metrics/localisation/attribution_localisation.py @@ -240,6 +240,58 @@ def __call__( **kwargs, ) + def evaluate_instance( + self, + x: np.ndarray, + a: np.ndarray, + s: np.ndarray, + ) -> float: + """ + Evaluate instance gets model and data for a single instance as input and returns the evaluation result. + + Parameters + ---------- + x: np.ndarray + The input to be evaluated on an instance-basis. + a: np.ndarray + The explanation to be evaluated on an instance-basis. + s: np.ndarray + The segmentation to be evaluated on an instance-basis. + + Returns + ------- + float + The evaluation results. + """ + + if np.sum(s) == 0: + warn.warn_empty_segmentation() + return np.nan + + # Prepare shapes. + a = a.flatten() + s = s.flatten().astype(bool) + + # Compute ratio. + size_bbox = float(np.sum(s)) + size_data = np.prod(x.shape[1:]) + ratio = size_bbox / size_data + + # Compute inside/outside ratio. + inside_attribution = np.sum(a[s]) + total_attribution = np.sum(a) + inside_attribution_ratio = float(inside_attribution / total_attribution) + + if not ratio <= self.max_size: + warn.warn_max_size() + if inside_attribution_ratio > 1.0: + warn.warn_segmentation(inside_attribution, total_attribution) + return np.nan + if not self.weighted: + return inside_attribution_ratio + else: + return float(inside_attribution_ratio * ratio) + def custom_preprocess( self, model: ModelInterface, @@ -277,35 +329,24 @@ def custom_preprocess( def evaluate_batch( self, *, x_batch: np.ndarray, a_batch: np.ndarray, s_batch: np.ndarray, **_ ) -> List[float]: - retval = [] - for x, a, s in zip(x_batch, a_batch, s_batch): - if np.sum(s) == 0: - warn.warn_empty_segmentation() - retval.append(np.nan) - continue - - # Prepare shapes. - a = a.flatten() - s = s.flatten().astype(bool) - - # Compute ratio. - size_bbox = float(np.sum(s)) - size_data = np.prod(x.shape[1:]) - ratio = size_bbox / size_data - - # Compute inside/outside ratio. - inside_attribution = np.sum(a[s]) - total_attribution = np.sum(a) - inside_attribution_ratio = float(inside_attribution / total_attribution) - - if not ratio <= self.max_size: - warn.warn_max_size() - if inside_attribution_ratio > 1.0: - warn.warn_segmentation(inside_attribution, total_attribution) - retval.append(np.nan) - elif not self.weighted: - retval.append(inside_attribution_ratio) - else: - retval.append(float(inside_attribution_ratio * ratio)) - - return retval + """ + TODO: write meaningful docstring about what does it compute. + + Parameters + ---------- + x_batch: np.ndarray + A np.ndarray which contains the input data that are explained. + a_batch: np.ndarray + A np.ndarray which contains pre-computed attributions i.e., explanations. + s_batch: np.ndarray + A np.ndarray which contains segmentation masks that matches the input. + + Returns + ------- + evaluation_scores: list + a list of floats. + """ + return [ + self.evaluate_instance(x, a, s) + for x, a, s in zip(x_batch, a_batch, s_batch) + ] diff --git a/quantus/metrics/localisation/auc.py b/quantus/metrics/localisation/auc.py index f32b882e3..8e2ffd240 100644 --- a/quantus/metrics/localisation/auc.py +++ b/quantus/metrics/localisation/auc.py @@ -218,6 +218,37 @@ def __call__( **kwargs, ) + @staticmethod + def evaluate_instance(a: np.ndarray, s: np.ndarray) -> float: + """ + Evaluate instance gets model and data for a single instance as input and returns the evaluation result. + + Parameters + ---------- + a: np.ndarray + The explanation to be evaluated on an instance-basis. + s: np.ndarray + The segmentation to be evaluated on an instance-basis. + + Returns + ------- + : float + The evaluation results. + """ + # Return np.nan as result if segmentation map is empty. + if np.sum(s) == 0: + warn.warn_empty_segmentation() + return np.nan + + # Prepare shapes. + a = a.flatten() + s = s.flatten().astype(bool) + + fpr, tpr, _ = roc_curve(y_true=s, y_score=a) + score = auc(x=fpr, y=tpr) + + return score + def custom_preprocess( self, model: ModelInterface, @@ -255,20 +286,5 @@ def custom_preprocess( def evaluate_batch( self, *, a_batch: np.ndarray, s_batch: np.ndarray, **_ ) -> List[float]: - retval = [] - # TODO: vectorize - for a, s in zip(a_batch, s_batch): - if np.sum(s) == 0: - warn.warn_empty_segmentation() - retval.append(np.nan) - continue - - # Prepare shapes. - a = a.flatten() - s = s.flatten().astype(bool) - - fpr, tpr, _ = roc_curve(y_true=s, y_score=a) - score = auc(x=fpr, y=tpr) - retval.append(score) - - return retval + # TODO: for performance reasons replace for-loop with vectorized dispatch. + return [self.evaluate_instance(a, s) for a, s in zip(a_batch, s_batch)] diff --git a/quantus/metrics/localisation/focus.py b/quantus/metrics/localisation/focus.py index 0d2857eec..6437b98e4 100644 --- a/quantus/metrics/localisation/focus.py +++ b/quantus/metrics/localisation/focus.py @@ -262,6 +262,52 @@ def __call__( **kwargs, ) + def evaluate_instance( + self, + a: np.ndarray, + c: np.ndarray = None, + ) -> float: + """ + Evaluate instance gets model and data for a single instance as input and returns the evaluation result. + + Parameters + ---------- + a: np.ndarray + The explanation to be evaluated on an instance-basis. + c: any + The custom input to be evaluated on an instance-basis. + + Returns + ------- + float + The evaluation results. + """ + + # Prepare shapes for mosaics. + self.mosaic_shape = a.shape + + total_positive_relevance = np.sum(a[a > 0], dtype=np.float64) + target_positive_relevance = 0 + + quadrant_functions_list = [ + self.quadrant_top_left, + self.quadrant_top_right, + self.quadrant_bottom_left, + self.quadrant_bottom_right, + ] + + for quadrant_p, quadrant_func in zip(c, quadrant_functions_list): + if not bool(quadrant_p): + continue + quadrant_relevance = quadrant_func(a=a) + target_positive_relevance += np.sum( + quadrant_relevance[quadrant_relevance > 0] + ) + + focus_score = target_positive_relevance / total_positive_relevance + + return focus_score + def custom_preprocess( self, model: ModelInterface, @@ -336,39 +382,7 @@ def quadrant_bottom_right(self, a: np.ndarray) -> np.ndarray: ] return quandrant_a - @no_type_check def evaluate_batch( - self, - *, - a_batch: np.ndarray, - c_batch: np.ndarray, - **_, - ): - retval = [] - for a, c in zip(a_batch, c_batch): - # Prepare shapes for mosaics. - self.mosaic_shape = a.shape - - total_positive_relevance = np.sum(a[a > 0], dtype=np.float64) - target_positive_relevance = 0 - - quadrant_functions_list = [ - self.quadrant_top_left, - self.quadrant_top_right, - self.quadrant_bottom_left, - self.quadrant_bottom_right, - ] - - for quadrant_p, quadrant_func in zip(c, quadrant_functions_list): - if not bool(quadrant_p): - continue - quadrant_relevance = quadrant_func(a=a) - target_positive_relevance += np.sum( - quadrant_relevance[quadrant_relevance > 0] - ) - - focus_score = target_positive_relevance / total_positive_relevance - - retval.append(focus_score) - - return retval + self, *, a_batch: np.ndarray, c_batch: np.ndarray, **_ + ) -> List[float]: + return [self.evaluate_instance(a, c) for a, c in zip(a_batch, c_batch)] diff --git a/quantus/metrics/localisation/pointing_game.py b/quantus/metrics/localisation/pointing_game.py index 4fe20f625..b39268ed8 100644 --- a/quantus/metrics/localisation/pointing_game.py +++ b/quantus/metrics/localisation/pointing_game.py @@ -230,6 +230,47 @@ def __call__( **kwargs, ) + def evaluate_instance( + self, + a: np.ndarray, + s: np.ndarray, + ) -> bool: + """ + Evaluate instance gets model and data for a single instance as input and returns the evaluation result. + + Parameters + ---------- + a: np.ndarray + The explanation to be evaluated on an instance-basis. + s: np.ndarray + The segmentation to be evaluated on an instance-basis. + + Returns + ------- + boolean + The evaluation results. + """ + + # Return np.nan as result if segmentation map is empty. + if np.sum(s) == 0: + warn.warn_empty_segmentation() + return np.nan + + # Prepare shapes. + a = a.flatten() + s = s.flatten().astype(bool) + + # Find indices with max value. + max_index = np.argwhere(a == np.max(a)) + + # Check if maximum of explanation is on target object class. + hit = np.any(s[max_index]) + + if self.weighted and hit: + hit = 1 - (np.sum(s) / float(np.prod(s.shape))) + + return hit + def custom_preprocess( self, model: ModelInterface, @@ -268,26 +309,4 @@ def custom_preprocess( def evaluate_batch( self, *, a_batch: np.ndarray, s_batch: np.ndarray, **_ ) -> List[float]: - retval = [] - for a, s in zip(a_batch, s_batch): - if np.sum(s) == 0: - warn.warn_empty_segmentation() - retval.append(np.nan) - continue - - # Prepare shapes. - a = a.flatten() - s = s.flatten().astype(bool) - - # Find indices with max value. - max_index = np.argwhere(a == np.max(a)) - - # Check if maximum of explanation is on target object class. - hit = np.any(s[max_index]) - - if self.weighted and hit: - hit = 1 - (np.sum(s) / float(np.prod(s.shape))) - - retval.append(hit) - - return retval + return [self.evaluate_instance(a, s) for a, s in zip(a_batch, s_batch)] diff --git a/quantus/metrics/localisation/relevance_mass_accuracy.py b/quantus/metrics/localisation/relevance_mass_accuracy.py index 88cec2914..4016327e4 100644 --- a/quantus/metrics/localisation/relevance_mass_accuracy.py +++ b/quantus/metrics/localisation/relevance_mass_accuracy.py @@ -224,6 +224,44 @@ def __call__( **kwargs, ) + @staticmethod + def evaluate_instance( + a: np.ndarray, + s: np.ndarray, + ) -> float: + """ + Evaluate instance gets model and data for a single instance as input and returns the evaluation result. + + Parameters + ---------- + a: np.ndarray + The explanation to be evaluated on an instance-basis. + s: np.ndarray + The segmentation to be evaluated on an instance-basis. + + Returns + ------- + float + The evaluation results. + """ + # Return np.nan as result if segmentation map is empty. + if np.sum(s) == 0: + warn.warn_empty_segmentation() + return np.nan + + # Prepare shapes. + a = a.flatten() + s = s.flatten().astype(bool) + + # Compute inside/outside ratio. + r_within = np.sum(a[s]) + r_total = np.sum(a) + + # Calculate mass accuracy. + mass_accuracy = r_within / r_total + + return mass_accuracy + def custom_preprocess( self, model: ModelInterface, @@ -261,25 +299,4 @@ def custom_preprocess( def evaluate_batch( self, *, a_batch: np.ndarray, s_batch: np.ndarray, **_ ) -> List[float]: - retval = [] - for a, s in zip(a_batch, s_batch): - # Return np.nan as result if segmentation map is empty. - if np.sum(s) == 0: - warn.warn_empty_segmentation() - retval.append(np.nan) - continue - - # Prepare shapes. - a = a.flatten() - s = s.flatten().astype(bool) - - # Compute inside/outside ratio. - r_within = np.sum(a[s]) - r_total = np.sum(a) - - # Calculate mass accuracy. - mass_accuracy = r_within / r_total - - retval.append(mass_accuracy) - - return retval + return [self.evaluate_instance(a, s) for a, s in zip(a_batch, s_batch)] diff --git a/quantus/metrics/localisation/relevance_rank_accuracy.py b/quantus/metrics/localisation/relevance_rank_accuracy.py index 9d87b604f..c0323b425 100644 --- a/quantus/metrics/localisation/relevance_rank_accuracy.py +++ b/quantus/metrics/localisation/relevance_rank_accuracy.py @@ -226,6 +226,51 @@ def __call__( **kwargs, ) + @staticmethod + def evaluate_instance( + a: np.ndarray, + s: np.ndarray, + ) -> float: + """ + Evaluate instance gets model and data for a single instance as input and returns the evaluation result. + + Parameters + ---------- + a: np.ndarray + The explanation to be evaluated on an instance-basis. + s: np.ndarray + The segmentation to be evaluated on an instance-basis. + + Returns + ------- + float + The evaluation results. + """ + # Return np.nan as result if segmentation map is empty. + if np.sum(s) == 0: + warn.warn_empty_segmentation() + return np.nan + + # Prepare shapes. + a = a.flatten() + s = np.where(s.flatten().astype(bool))[0] + + # Size of the ground truth mask. + k = len(s) + + # Sort in descending order. + a_sorted = np.argsort(a)[-int(k) :] + + # Calculate hits. + hits = len(np.intersect1d(s, a_sorted)) + + if hits != 0: + rank_accuracy = hits / float(k) + else: + rank_accuracy = 0.0 + + return rank_accuracy + def custom_preprocess( self, model: ModelInterface, @@ -263,33 +308,5 @@ def custom_preprocess( def evaluate_batch( self, *, a_batch: np.ndarray, s_batch: np.ndarray, **_ ) -> List[float]: - retval = [] - - for a, s in zip(a_batch, s_batch): - # Return np.nan as result if segmentation map is empty. - if np.sum(s) == 0: - warn.warn_empty_segmentation() - retval.append(np.nan) - continue - - # Prepare shapes. - a = a.flatten() - s = np.where(s.flatten().astype(bool))[0] - - # Size of the ground truth mask. - k = len(s) - - # Sort in descending order. - a_sorted = np.argsort(a)[-int(k) :] - - # Calculate hits. - hits = len(np.intersect1d(s, a_sorted)) - - if hits != 0: - rank_accuracy = hits / float(k) - else: - rank_accuracy = 0.0 - - retval.append(rank_accuracy) - - return retval + # TODO: for performance reasons, this method should be vectorized. + return [self.evaluate_instance(a, s) for a, s in zip(a_batch, s_batch)] diff --git a/quantus/metrics/localisation/top_k_intersection.py b/quantus/metrics/localisation/top_k_intersection.py index 91340f55b..f9af43d3b 100644 --- a/quantus/metrics/localisation/top_k_intersection.py +++ b/quantus/metrics/localisation/top_k_intersection.py @@ -235,6 +235,49 @@ def __call__( **kwargs, ) + def evaluate_instance( + self, + a: np.ndarray, + s: np.ndarray, + ): + """ + Evaluate instance gets model and data for a single instance as input and returns the evaluation result. + + Parameters + ---------- + a: np.ndarray + The explanation to be evaluated on an instance-basis. + s: np.ndarray + The segmentation to be evaluated on an instance-basis. + + Returns + ------- + float + The evaluation results. + """ + + if np.sum(s) == 0: + warn.warn_empty_segmentation() + return np.nan + + # Prepare shapes. + s = s.astype(bool) + top_k_binary_mask = np.zeros(a.shape) + + # Sort and create masks. + sorted_indices = np.argsort(a, axis=None) + np.put_along_axis(top_k_binary_mask, sorted_indices[-self.k :], 1, axis=None) + top_k_binary_mask = top_k_binary_mask.astype(bool) + + # Top-k intersection. + tki = 1.0 / self.k * np.sum(np.logical_and(s, top_k_binary_mask)) + + # Concept influence (with size of object normalised tki score). + if self.concept_influence: + tki = np.prod(s.shape) / np.sum(s) * tki + + return tki + def custom_preprocess( self, model: ModelInterface, @@ -275,31 +318,4 @@ def custom_preprocess( def evaluate_batch( self, *, a_batch: np.ndarray, s_batch: np.ndarray, **_ ) -> List[float]: - retval = [] - for a, s in zip(a_batch, s_batch): - if np.sum(s) == 0: - warn.warn_empty_segmentation() - retval.append(np.nan) - continue - - # Prepare shapes. - s = s.astype(bool) - top_k_binary_mask = np.zeros(a.shape) - - # Sort and create masks. - sorted_indices = np.argsort(a, axis=None) - np.put_along_axis( - top_k_binary_mask, sorted_indices[-self.k :], 1, axis=None - ) - top_k_binary_mask = top_k_binary_mask.astype(bool) - - # Top-k intersection. - tki = 1.0 / self.k * np.sum(np.logical_and(s, top_k_binary_mask)) - - # Concept influence (with size of object normalised tki score). - if self.concept_influence: - tki = np.prod(s.shape) / np.sum(s) * tki - - retval.append(tki) - - return retval + return [self.evaluate_instance(a, s) for a, s in zip(a_batch, s_batch)] From 7acfd857defec57dd6b2f3cb7979266b2032bb38 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Wed, 30 Aug 2023 16:59:07 +0200 Subject: [PATCH 14/58] minify git diff --- quantus/metrics/complexity/complexity.py | 2 +- quantus/metrics/complexity/sparseness.py | 2 +- quantus/metrics/randomisation/random_logit.py | 52 ++++++++++------- quantus/metrics/robustness/consistency.py | 57 ++++++++++++++----- 4 files changed, 75 insertions(+), 38 deletions(-) diff --git a/quantus/metrics/complexity/complexity.py b/quantus/metrics/complexity/complexity.py index 73cd77ab4..1d9a0b575 100644 --- a/quantus/metrics/complexity/complexity.py +++ b/quantus/metrics/complexity/complexity.py @@ -224,7 +224,7 @@ def __call__( model_predict_kwargs=model_predict_kwargs, **kwargs, ) - + @staticmethod def evaluate_instance(x: np.ndarray, a: np.ndarray) -> float: """ diff --git a/quantus/metrics/complexity/sparseness.py b/quantus/metrics/complexity/sparseness.py index eccd2c858..7de93fae8 100644 --- a/quantus/metrics/complexity/sparseness.py +++ b/quantus/metrics/complexity/sparseness.py @@ -229,7 +229,7 @@ def __call__( model_predict_kwargs=model_predict_kwargs, **kwargs, ) - + @staticmethod def evaluate_instance(x: np.ndarray, a: np.ndarray) -> float: """ diff --git a/quantus/metrics/randomisation/random_logit.py b/quantus/metrics/randomisation/random_logit.py index adf02ec36..3986fc554 100644 --- a/quantus/metrics/randomisation/random_logit.py +++ b/quantus/metrics/randomisation/random_logit.py @@ -262,6 +262,32 @@ def evaluate_instance( float The evaluation results. """ + # Randomly select off-class labels. + np.random.seed(self.seed) + y_off = np.array( + [ + np.random.choice( + [y_ for y_ in list(np.arange(0, self.num_classes)) if y_ != y] + ) + ] + ) + + # Explain against a random class. + a_perturbed = self.explain_func( + model=model.get_model(), + inputs=np.expand_dims(x, axis=0), + targets=y_off, + **self.explain_func_kwargs, + ) + + # Normalise and take absolute values of the attributions, if True. + if self.normalise: + a_perturbed = self.normalise_func(a_perturbed, **self.normalise_func_kwargs) + + if self.abs: + a_perturbed = np.abs(a_perturbed) + + return self.similarity_func(a.flatten(), a_perturbed.flatten()) def custom_preprocess( self, @@ -307,24 +333,8 @@ def evaluate_batch( a_batch: np.ndarray, **_, ) -> List[float]: - # TODO: vectorize - retval = [] - for x, y, a in zip(x_batch, y_batch, a_batch): - # Randomly select off-class labels. - np.random.seed(self.seed) - y_off = np.array( - [ - np.random.choice( - [y_ for y_ in list(np.arange(0, self.num_classes)) if y_ != y] - ) - ] - ) - # Explain against a random class. - a_perturbed = self.explain_batch( - model.get_model(), - np.expand_dims(x, axis=0), - y_off, - ) - retval.append(self.similarity_func(a.flatten(), a_perturbed.flatten())) - - return retval + # TODO: for performance reasons vectorize this for-loop + return [ + self.evaluate_instance(model, x, y, a) + for x, y, a in zip(x_batch, y_batch, a_batch) + ] diff --git a/quantus/metrics/robustness/consistency.py b/quantus/metrics/robustness/consistency.py index bd459f8df..b4f0ddb93 100644 --- a/quantus/metrics/robustness/consistency.py +++ b/quantus/metrics/robustness/consistency.py @@ -232,6 +232,42 @@ def __call__( **kwargs, ) + @staticmethod + def evaluate_instance( + a: np.ndarray, + i: int = None, + a_label: np.ndarray = None, + y_pred_classes: np.ndarray = None, + ) -> float: + """ + Evaluate instance gets model and data for a single instance as input and returns the evaluation result. + + Parameters + ---------- + a: np.ndarray + The explanation to be evaluated on an instance-basis. + i: int + The index of the current instance. + a_label: np.ndarray + The discretised attribution labels. + y_pred_classes: np,ndarray + The class predictions of the complete input dataset. + + Returns + ------- + float + The evaluation results. + """ + # Metric logic. + pred_a = y_pred_classes[i] + same_a = np.argwhere(a == a_label).flatten() + diff_a = same_a[same_a != i] + pred_same_a = y_pred_classes[diff_a] + + if len(same_a) == 0: + return 0 + return np.sum(pred_same_a == pred_a) / len(diff_a) + def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: data_batch = super().batch_preprocess(data_batch) @@ -257,7 +293,6 @@ def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: data_batch.update(custom_batch) return data_batch - @no_type_check def evaluate_batch( self, *, @@ -267,17 +302,9 @@ def evaluate_batch( y_pred_classes, **_, ) -> List[float]: - # TODO: vectorize - retval = [] - for a, i, a_label in zip(a_batch, i_batch, a_label_batch): - pred_a = y_pred_classes[i] - same_a = np.argwhere(a == a_label).flatten() - diff_a = same_a[same_a != i] - pred_same_a = y_pred_classes[diff_a] - - if len(same_a) == 0: - retval.append(0.0) - else: - retval.append(np.sum(pred_same_a == pred_a) / len(diff_a)) - - return retval + # TODO: for performance reasons vectorize this for-loop + + return [ + self.evaluate_instance(a, i, a_label, y_pred_classes) + for a, i, a_label in zip(a_batch, i_batch, a_label_batch) + ] From d08dc0fec43c0f1b3d0c9cd3ebf5c889099450a3 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Wed, 30 Aug 2023 17:04:06 +0200 Subject: [PATCH 15/58] minify git diff --- quantus/metrics/robustness/continuity.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/quantus/metrics/robustness/continuity.py b/quantus/metrics/robustness/continuity.py index be4184da5..349ddf411 100644 --- a/quantus/metrics/robustness/continuity.py +++ b/quantus/metrics/robustness/continuity.py @@ -317,12 +317,9 @@ def evaluate_instance( x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) prediction_changed = ( - len( - self.changed_prediction_indices( - model, np.expand_dims(x, 0), x_input - ) - ) - != 0 + self.return_nan_when_prediction_changes + and model.predict(np.expand_dims(x, 0)).argmax(axis=-1)[0] + != model.predict(x_input).argmax(axis=-1)[0] ) # Taking the first element, since a_perturbed will be expanded to a batch dimension # not expected by the current index management functions. @@ -361,11 +358,19 @@ def evaluate_instance( # not expected by the current index management functions. # a_perturbed = utils.expand_attribution_channel(a_perturbed, x_input)[0] + if self.normalise: + a_perturbed_patch = self.normalise_func( + a_perturbed_patch.flatten(), **self.normalise_func_kwargs + ) + + if self.abs: + a_perturbed_patch = np.abs(a_perturbed_patch.flatten()) + # Sum attributions for patch. - patch_sum = float(np.sum(a_perturbed_patch)) + patch_sum = float(sum(a_perturbed_patch)) results[ix_patch].append(patch_sum) - return {k: v for k, v in results.items()} + return results def custom_preprocess( self, From 69897489d1aa6935197369755a89003eb24b8126 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Wed, 30 Aug 2023 17:22:33 +0200 Subject: [PATCH 16/58] mypy fixes --- quantus/metrics/base.py | 108 +++++++++++----------- quantus/metrics/localisation/focus.py | 1 + quantus/metrics/robustness/consistency.py | 1 + quantus/metrics/robustness/continuity.py | 2 +- 4 files changed, 57 insertions(+), 55 deletions(-) diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index be079654e..75582a8c1 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -8,7 +8,7 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any, Callable, Dict, Sequence, Optional, ClassVar, Generator, Set +from typing import Any, Callable, Dict, Sequence, ClassVar, Generator, Set import matplotlib.pyplot as plt import numpy as np @@ -44,10 +44,10 @@ def __init__( abs: bool, normalise: bool, normalise_func: Callable, - normalise_func_kwargs: Optional[Dict[str, Any]], + normalise_func_kwargs: Dict[str, ...] | None, return_aggregate: bool, aggregate_func: Callable, - default_plot_func: Optional[Callable], + default_plot_func: Callable[[...], None] | None, disable_warnings: bool, display_progressbar: bool, **kwargs, @@ -118,19 +118,19 @@ def __call__( self, model, x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: Optional[np.ndarray], - channel_first: Optional[bool], - explain_func: Optional[Callable], - explain_func_kwargs: Optional[Dict[str, Any]], - model_predict_kwargs: Optional[Dict], - softmax: Optional[bool], - device: Optional[str] = None, + y_batch: np.ndarray | None, + a_batch: np.ndarray | None, + s_batch: np.ndarray | None, + channel_first: bool | None, + explain_func: Callable[[...], None] | None, + explain_func_kwargs: Dict[str, ...] | None, + model_predict_kwargs: Dict[str, ...] | None, + softmax: bool | None, + device: str | None = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, + custom_batch: Any = None, **kwargs, - ) -> Any: + ): """ This implementation represents the main logic of the metric and makes the class object callable. It completes batch-wise evaluation of explanations (a_batch) with respect to input data (x_batch), @@ -306,17 +306,17 @@ def general_preprocess( self, model, x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: Optional[np.ndarray], - channel_first: Optional[bool], + y_batch: np.ndarray | None, + a_batch: np.ndarray | None, + s_batch: np.ndarray | None, + channel_first: bool | None, explain_func: Callable, - explain_func_kwargs: Optional[Dict[str, Any]], - model_predict_kwargs: Optional[Dict], + explain_func_kwargs: Dict[str, ...] | None, + model_predict_kwargs: Dict[str, ...] | None, softmax: bool, - device: Optional[str], - custom_batch: Optional[np.ndarray], - ) -> Dict[str, Any]: + device: str | None, + custom_batch: np.ndarray | None, + ) -> Dict[str, ...]: """ Prepares all necessary variables for evaluation. @@ -438,11 +438,11 @@ def custom_preprocess( self, model: ModelInterface, x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], + y_batch: np.ndarray | None, + a_batch: np.ndarray | None, s_batch: np.ndarray, - custom_batch: Optional[np.ndarray], - ) -> Optional[Dict[str, Any]]: + custom_batch: np.ndarray | None, + ) -> Dict[str, ...] | None: """ Implement this method if you need custom preprocessing of data, model alteration or simply for creating/initialising additional @@ -486,10 +486,10 @@ def custom_preprocess( >>> self, >>> model: ModelInterface, >>> x_batch: np.ndarray, - >>> y_batch: Optional[np.ndarray], - >>> a_batch: Optional[np.ndarray], + >>> y_batch: np.ndarray | None, + >>> a_batch: np.ndarray | None, >>> s_batch: np.ndarray, - >>> custom_batch: Optional[np.ndarray], + >>> custom_batch: np.ndarray | None, >>> ) -> Dict[str, Any]: >>> return {'my_new_variable': np.mean(x_batch)} >>> @@ -497,8 +497,8 @@ def custom_preprocess( >>> self, >>> model: ModelInterface, >>> x: np.ndarray, - >>> y: Optional[np.ndarray], - >>> a: Optional[np.ndarray], + >>> y: np.ndarray | None, + >>> a: np.ndarray | None, >>> s: np.ndarray, >>> my_new_variable: np.float, >>> ) -> float: @@ -508,10 +508,10 @@ def custom_preprocess( >>> self, >>> model: ModelInterface, >>> x_batch: np.ndarray, - >>> y_batch: Optional[np.ndarray], - >>> a_batch: Optional[np.ndarray], + >>> y_batch: np.ndarray | None, + >>> a_batch: np.ndarray | None, >>> s_batch: np.ndarray, - >>> custom_batch: Optional[np.ndarray], + >>> custom_batch: np.ndarray | None, >>> ) -> Dict[str, Any]: >>> return {'my_new_variable_batch': np.arange(len(x_batch))} >>> @@ -519,8 +519,8 @@ def custom_preprocess( >>> self, >>> model: ModelInterface, >>> x: np.ndarray, - >>> y: Optional[np.ndarray], - >>> a: Optional[np.ndarray], + >>> y: np.ndarray | None, + >>> a: np.ndarray | None, >>> s: np.ndarray, >>> my_new_variable: np.int, >>> ) -> float: @@ -531,10 +531,10 @@ def custom_preprocess( >>> self, >>> model: ModelInterface, >>> x_batch: np.ndarray, - >>> y_batch: Optional[np.ndarray], - >>> a_batch: Optional[np.ndarray], + >>> y_batch: np.ndarray | None, + >>> a_batch: np.ndarray | None, >>> s_batch: np.ndarray, - >>> custom_batch: Optional[np.ndarray], + >>> custom_batch: np.ndarray | None, >>> ) -> Dict[str, Any]: >>> return {'x_batch': x_batch - np.mean(x_batch, axis=0)} >>> @@ -542,8 +542,8 @@ def custom_preprocess( >>> self, >>> model: ModelInterface, >>> x: np.ndarray, - >>> y: Optional[np.ndarray], - >>> a: Optional[np.ndarray], + >>> y: np.ndarray | None, + >>> a: np.ndarray | None, >>> s: np.ndarray, >>> ) -> float: @@ -553,10 +553,10 @@ def custom_preprocess( >>> self, >>> model: ModelInterface, >>> x_batch: np.ndarray, - >>> y_batch: Optional[np.ndarray], - >>> a_batch: Optional[np.ndarray], + >>> y_batch: np.ndarray | None, + >>> a_batch: np.ndarray | None, >>> s_batch: np.ndarray, - >>> custom_batch: Optional[np.ndarray], + >>> custom_batch: np.ndarray | None, >>> ) -> None: >>> if np.any(np.all(a_batch < 0, axis=0)): >>> raise ValueError("Attributions must not be all negative") @@ -569,8 +569,8 @@ def custom_preprocess( >>> self, >>> model: ModelInterface, >>> x: np.ndarray, - >>> y: Optional[np.ndarray], - >>> a: Optional[np.ndarray], + >>> y: np.ndarray | None, + >>> a: np.ndarray | None, >>> s: np.ndarray, >>> ) -> float: @@ -581,11 +581,11 @@ def custom_postprocess( self, model: ModelInterface, x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], + y_batch: np.ndarray | None, + a_batch: np.ndarray | None, s_batch: np.ndarray, **kwargs, - ) -> Optional[Any]: + ): """ Implement this method if you need custom postprocessing of results or additional attributes. @@ -614,9 +614,9 @@ def custom_postprocess( def generate_batches( self, - data: Dict[str, Any], + data: Dict[str, ...], batch_size: int, - ) -> Generator[Dict[str, Any], None, None]: + ) -> Generator[Dict[str, ...], None, None]: """ Creates iterator to iterate over all batched instances in data dictionary. Each iterator output element is a keyword argument dictionary with @@ -693,7 +693,7 @@ def generate_batches( def plot( self, - plot_func: Optional[Callable] = None, + plot_func: Callable[[...], None] | None = None, show: bool = True, path_to_save: str | None = None, *args, @@ -778,7 +778,7 @@ def all_results(self): ) return self.all_evaluation_scores - def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: + def batch_preprocess(self, data_batch: Dict[str, ...]) -> Dict[str, ...]: """ Does computationally heavy pre-processing on batch level to avoid OOM. By default, will only generate explanations if missing, in case metric requires custom diff --git a/quantus/metrics/localisation/focus.py b/quantus/metrics/localisation/focus.py index 6437b98e4..18354af91 100644 --- a/quantus/metrics/localisation/focus.py +++ b/quantus/metrics/localisation/focus.py @@ -382,6 +382,7 @@ def quadrant_bottom_right(self, a: np.ndarray) -> np.ndarray: ] return quandrant_a + @no_type_check def evaluate_batch( self, *, a_batch: np.ndarray, c_batch: np.ndarray, **_ ) -> List[float]: diff --git a/quantus/metrics/robustness/consistency.py b/quantus/metrics/robustness/consistency.py index b4f0ddb93..641473379 100644 --- a/quantus/metrics/robustness/consistency.py +++ b/quantus/metrics/robustness/consistency.py @@ -293,6 +293,7 @@ def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: data_batch.update(custom_batch) return data_batch + @no_type_check def evaluate_batch( self, *, diff --git a/quantus/metrics/robustness/continuity.py b/quantus/metrics/robustness/continuity.py index 349ddf411..38fbd2ab9 100644 --- a/quantus/metrics/robustness/continuity.py +++ b/quantus/metrics/robustness/continuity.py @@ -284,7 +284,7 @@ def __call__( def evaluate_instance( self, model: ModelInterface, x: np.ndarray, y: np.ndarray - ) -> Dict[str, int | float]: + ) -> Dict[int, List[Any]]: """ Evaluate instance gets model and data for a single instance as input and returns the evaluation result. From 4d1d25eb223104db4490f4ed457623a5f354d3f9 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Wed, 30 Aug 2023 19:19:19 +0200 Subject: [PATCH 17/58] add missing docstrings --- .../faithfulness/monotonicity_correlation.py | 15 +++++--------- quantus/metrics/faithfulness/sufficiency.py | 15 ++++++++++++++ quantus/metrics/localisation/auc.py | 15 ++++++++++++++ quantus/metrics/localisation/focus.py | 14 +++++++++++++ quantus/metrics/localisation/pointing_game.py | 15 ++++++++++++++ .../localisation/relevance_mass_accuracy.py | 14 +++++++++++++ .../localisation/relevance_rank_accuracy.py | 15 ++++++++++++++ .../localisation/top_k_intersection.py | 15 ++++++++++++++ quantus/metrics/randomisation/random_logit.py | 20 ++++++++++++++++++- quantus/metrics/robustness/continuity.py | 17 ++++++++++++++++ 10 files changed, 144 insertions(+), 11 deletions(-) diff --git a/quantus/metrics/faithfulness/monotonicity_correlation.py b/quantus/metrics/faithfulness/monotonicity_correlation.py index 73049d9f2..4be495761 100644 --- a/quantus/metrics/faithfulness/monotonicity_correlation.py +++ b/quantus/metrics/faithfulness/monotonicity_correlation.py @@ -399,7 +399,7 @@ def evaluate_batch( Parameters ---------- model: ModelInterface - A ModelInterface that is subject to explanation. + A model that is subject to explanation. x_batch: np.ndarray The input to be evaluated on a batch-basis. y_batch: np.ndarray @@ -412,19 +412,14 @@ def evaluate_batch( List[float] The evaluation results. """ - # Asserts. - asserts.assert_features_in_step( - features_in_step=self.features_in_step, - input_shape=x_batch.shape[2:], - ) # Evaluate explanations. return [ self.evaluate_instance( - model=model, - x=x, - y=y, - a=a, + model, + x, + y, + a, ) for x, y, a in zip(x_batch, y_batch, a_batch) ] diff --git a/quantus/metrics/faithfulness/sufficiency.py b/quantus/metrics/faithfulness/sufficiency.py index 0af6ff069..a45ee33a4 100644 --- a/quantus/metrics/faithfulness/sufficiency.py +++ b/quantus/metrics/faithfulness/sufficiency.py @@ -309,7 +309,22 @@ def evaluate_batch( self, *, i_batch, a_sim_vector_batch, y_pred_classes, **_ ) -> List[float]: """ + TODO: write meaningful docstring about what does it compute. + Parameters + ---------- + i_batch: + The index of the current instance. + a_sim_vector_batch: + The custom input to be evaluated on an instance-basis. + y_pred_classes: + The class predictions of the complete input dataset. + _: + unused. + + Returns + ------- + """ return [ diff --git a/quantus/metrics/localisation/auc.py b/quantus/metrics/localisation/auc.py index 8e2ffd240..401962c8c 100644 --- a/quantus/metrics/localisation/auc.py +++ b/quantus/metrics/localisation/auc.py @@ -286,5 +286,20 @@ def custom_preprocess( def evaluate_batch( self, *, a_batch: np.ndarray, s_batch: np.ndarray, **_ ) -> List[float]: + """ + + Parameters + ---------- + a_batch: + A np.ndarray which contains pre-computed attributions i.e., explanations. + s_batch: + A np.ndarray which contains segmentation masks that matches the input. + _: + unused. + + Returns + ------- + + """ # TODO: for performance reasons replace for-loop with vectorized dispatch. return [self.evaluate_instance(a, s) for a, s in zip(a_batch, s_batch)] diff --git a/quantus/metrics/localisation/focus.py b/quantus/metrics/localisation/focus.py index 18354af91..e78d86f69 100644 --- a/quantus/metrics/localisation/focus.py +++ b/quantus/metrics/localisation/focus.py @@ -386,4 +386,18 @@ def quadrant_bottom_right(self, a: np.ndarray) -> np.ndarray: def evaluate_batch( self, *, a_batch: np.ndarray, c_batch: np.ndarray, **_ ) -> List[float]: + """ + Parameters + ---------- + a_batch: + A np.ndarray which contains pre-computed attributions i.e., explanations. + c_batch: + The custom input to be evaluated on an batch-basis. + _: + unused. + + Returns + ------- + + """ return [self.evaluate_instance(a, c) for a, c in zip(a_batch, c_batch)] diff --git a/quantus/metrics/localisation/pointing_game.py b/quantus/metrics/localisation/pointing_game.py index b39268ed8..cd9c3685e 100644 --- a/quantus/metrics/localisation/pointing_game.py +++ b/quantus/metrics/localisation/pointing_game.py @@ -309,4 +309,19 @@ def custom_preprocess( def evaluate_batch( self, *, a_batch: np.ndarray, s_batch: np.ndarray, **_ ) -> List[float]: + """ + + Parameters + ---------- + a_batch: + A np.ndarray which contains pre-computed attributions i.e., explanations. + s_batch: + A np.ndarray which contains segmentation masks that matches the input. + _: + unused. + + Returns + ------- + + """ return [self.evaluate_instance(a, s) for a, s in zip(a_batch, s_batch)] diff --git a/quantus/metrics/localisation/relevance_mass_accuracy.py b/quantus/metrics/localisation/relevance_mass_accuracy.py index 4016327e4..6cc7a3a67 100644 --- a/quantus/metrics/localisation/relevance_mass_accuracy.py +++ b/quantus/metrics/localisation/relevance_mass_accuracy.py @@ -299,4 +299,18 @@ def custom_preprocess( def evaluate_batch( self, *, a_batch: np.ndarray, s_batch: np.ndarray, **_ ) -> List[float]: + """ + Parameters + ---------- + a_batch: + A np.ndarray which contains pre-computed attributions i.e., explanations. + s_batch: + A np.ndarray which contains segmentation masks that matches the input. + _: + unused. + + Returns + ------- + + """ return [self.evaluate_instance(a, s) for a, s in zip(a_batch, s_batch)] diff --git a/quantus/metrics/localisation/relevance_rank_accuracy.py b/quantus/metrics/localisation/relevance_rank_accuracy.py index c0323b425..b2d068e24 100644 --- a/quantus/metrics/localisation/relevance_rank_accuracy.py +++ b/quantus/metrics/localisation/relevance_rank_accuracy.py @@ -308,5 +308,20 @@ def custom_preprocess( def evaluate_batch( self, *, a_batch: np.ndarray, s_batch: np.ndarray, **_ ) -> List[float]: + """ + + Parameters + ---------- + a_batch: + A np.ndarray which contains pre-computed attributions i.e., explanations. + s_batch: + A np.ndarray which contains segmentation masks that matches the input. + _: + unused. + + Returns + ------- + + """ # TODO: for performance reasons, this method should be vectorized. return [self.evaluate_instance(a, s) for a, s in zip(a_batch, s_batch)] diff --git a/quantus/metrics/localisation/top_k_intersection.py b/quantus/metrics/localisation/top_k_intersection.py index f9af43d3b..1b184b5f0 100644 --- a/quantus/metrics/localisation/top_k_intersection.py +++ b/quantus/metrics/localisation/top_k_intersection.py @@ -318,4 +318,19 @@ def custom_preprocess( def evaluate_batch( self, *, a_batch: np.ndarray, s_batch: np.ndarray, **_ ) -> List[float]: + """ + + Parameters + ---------- + a_batch: + A np.ndarray which contains pre-computed attributions i.e., explanations. + s_batch: + A np.ndarray which contains segmentation masks that matches the input. + _: + unused. + + Returns + ------- + + """ return [self.evaluate_instance(a, s) for a, s in zip(a_batch, s_batch)] diff --git a/quantus/metrics/randomisation/random_logit.py b/quantus/metrics/randomisation/random_logit.py index 3986fc554..79f97e544 100644 --- a/quantus/metrics/randomisation/random_logit.py +++ b/quantus/metrics/randomisation/random_logit.py @@ -333,7 +333,25 @@ def evaluate_batch( a_batch: np.ndarray, **_, ) -> List[float]: - # TODO: for performance reasons vectorize this for-loop + """ + + Parameters + ---------- + model: + A model that is subject to explanation. + x_batch: + A np.ndarray which contains the input data that are explained. + y_batch: + A np.ndarray which contains the output labels that are explained. + a_batch: + A np.ndarray which contains pre-computed attributions i.e., explanations. + _: + unused. + + Returns + ------- + + """ return [ self.evaluate_instance(model, x, y, a) for x, y, a in zip(x_batch, y_batch, a_batch) diff --git a/quantus/metrics/robustness/continuity.py b/quantus/metrics/robustness/continuity.py index 38fbd2ab9..f7f8abc08 100644 --- a/quantus/metrics/robustness/continuity.py +++ b/quantus/metrics/robustness/continuity.py @@ -445,4 +445,21 @@ def evaluate_batch( y_batch: np.ndarray, **_, ) -> List[Dict[str, int]]: + """ + + Parameters + ---------- + model: + A model that is subject to explanation. + x_batch: + A np.ndarray which contains the input data that are explained. + y_batch: + A np.ndarray which contains the output labels that are explained. + _: + unused. + + Returns + ------- + + """ return [self.evaluate_instance(model, x, y) for x, y in zip(x_batch, y_batch)] From 58d19ff683112ce591568df3773f53e0247cbd04 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Thu, 5 Oct 2023 17:05:37 +0200 Subject: [PATCH 18/58] * run black --- quantus/__init__.py | 2 +- quantus/evaluation.py | 12 +++-------- quantus/functions/explanation_func.py | 1 - quantus/functions/loss_func.py | 2 +- quantus/functions/normalise_func.py | 4 ++-- quantus/functions/similarity_func.py | 2 +- quantus/helpers/__init__.py | 2 +- quantus/helpers/model/models.py | 1 - quantus/helpers/model/pytorch_model.py | 29 +++++++++++++++----------- quantus/helpers/model/tf_model.py | 1 - quantus/helpers/plotting.py | 1 - quantus/helpers/utils.py | 3 --- 12 files changed, 26 insertions(+), 34 deletions(-) diff --git a/quantus/__init__.py b/quantus/__init__.py index c29dc0a35..485b9d851 100644 --- a/quantus/__init__.py +++ b/quantus/__init__.py @@ -26,4 +26,4 @@ from quantus.helpers.model import * # Expose the helpers utils. -from quantus.helpers.utils import * \ No newline at end of file +from quantus.helpers.utils import * diff --git a/quantus/evaluation.py b/quantus/evaluation.py index 75d628ee3..8c37a24a7 100644 --- a/quantus/evaluation.py +++ b/quantus/evaluation.py @@ -81,7 +81,7 @@ def evaluate( return None if call_kwargs is None: - call_kwargs = {'call_kwargs_empty': {}} + call_kwargs = {"call_kwargs_empty": {}} elif not isinstance(call_kwargs, Dict): raise TypeError("xai_methods type is not Dict[str, Dict].") @@ -92,11 +92,9 @@ def evaluate( "xai_methods type is not in: Dict[str, Callable], Dict[str, Dict], Dict[str, np.ndarray]." for method, value in xai_methods.items(): - results[method] = {} if callable(value): - explain_funcs[method] = value explain_func = value @@ -116,7 +114,6 @@ def evaluate( asserts.assert_attributions(a_batch=a_batch, x_batch=x_batch) elif isinstance(value, Dict): - if explain_func_kwargs is not None: warnings.warn( "Passed explain_func_kwargs will be ignored when passing type Dict[str, Dict] as xai_methods." @@ -140,7 +137,6 @@ def evaluate( a_batch = value else: - raise TypeError( "xai_methods type is not in: Dict[str, Callable], Dict[str, Dict], Dict[str, np.ndarray]." ) @@ -148,12 +144,10 @@ def evaluate( if explain_func_kwargs is None: explain_func_kwargs = {} - for (metric, metric_func) in metrics.items(): - + for metric, metric_func in metrics.items(): results[method][metric] = {} - for (call_kwarg_str, call_kwarg) in call_kwargs.items(): - + for call_kwarg_str, call_kwarg in call_kwargs.items(): if progress: print( f"Evaluating {method} explanations on {metric} metric on set of call parameters {call_kwarg_str}..." diff --git a/quantus/functions/explanation_func.py b/quantus/functions/explanation_func.py index d34260f3f..fbc4de9ba 100644 --- a/quantus/functions/explanation_func.py +++ b/quantus/functions/explanation_func.py @@ -385,7 +385,6 @@ def generate_tf_explanation( ) elif method == "SmoothGrad": - num_samples = kwargs.get("num_samples", 5) noise = kwargs.get("noise", 0.1) explainer = tf_explain.core.smoothgrad.SmoothGrad() diff --git a/quantus/functions/loss_func.py b/quantus/functions/loss_func.py index 69181e1f3..bd0723645 100644 --- a/quantus/functions/loss_func.py +++ b/quantus/functions/loss_func.py @@ -34,7 +34,7 @@ def mse(a: np.array, b: np.array, **kwargs) -> float: if normalise: # Calculate MSE in its polynomial expansion (a-b)^2 = a^2 - 2ab + b^2. - return np.average(((a ** 2) - (2 * (a * b)) + (b ** 2)), axis=0) + return np.average(((a**2) - (2 * (a * b)) + (b**2)), axis=0) # If no need to normalise, return (a-b)^2. return np.average(((a - b) ** 2), axis=0) diff --git a/quantus/functions/normalise_func.py b/quantus/functions/normalise_func.py index 2d518aad4..5ad7f8145 100644 --- a/quantus/functions/normalise_func.py +++ b/quantus/functions/normalise_func.py @@ -231,13 +231,13 @@ def normalise_by_average_second_moment_estimate( # Check that square root of the second momment estimatte is nonzero. second_moment_sqrt = np.sqrt( - np.sum(a ** 2, axis=normalise_axes, keepdims=True) + np.sum(a**2, axis=normalise_axes, keepdims=True) / np.prod([a.shape[n] for n in normalise_axes]) ) if all(second_moment_sqrt != 0): a /= np.sqrt( - np.sum(a ** 2, axis=normalise_axes, keepdims=True) + np.sum(a**2, axis=normalise_axes, keepdims=True) / np.prod([a.shape[n] for n in normalise_axes]) ) else: diff --git a/quantus/functions/similarity_func.py b/quantus/functions/similarity_func.py index 88d19a9a7..93ffe8832 100644 --- a/quantus/functions/similarity_func.py +++ b/quantus/functions/similarity_func.py @@ -145,7 +145,7 @@ def lipschitz_constant( b: np.array, c: Union[np.array, None], d: Union[np.array, None], - **kwargs + **kwargs, ) -> float: """ Calculate non-negative local Lipschitz abs(||a-b||/||c-d||), where a,b can be f(x) or a(x) and c,d is x. diff --git a/quantus/helpers/__init__.py b/quantus/helpers/__init__.py index eafc68d63..e4d00f57e 100644 --- a/quantus/helpers/__init__.py +++ b/quantus/helpers/__init__.py @@ -8,4 +8,4 @@ # Import files dependent on package installations. __EXTRAS__ = util.find_spec("captum") or util.find_spec("tf_explain") -__MODELS__ = util.find_spec("torch") or util.find_spec("tensorflow") \ No newline at end of file +__MODELS__ = util.find_spec("torch") or util.find_spec("tensorflow") diff --git a/quantus/helpers/model/models.py b/quantus/helpers/model/models.py index 38d0ca004..2feb32c55 100644 --- a/quantus/helpers/model/models.py +++ b/quantus/helpers/model/models.py @@ -12,7 +12,6 @@ # Import different models depending on which deep learning framework is installed. if util.find_spec("torch"): - import torch import torch.nn as nn diff --git a/quantus/helpers/model/pytorch_model.py b/quantus/helpers/model/pytorch_model.py index cd422931f..b79442cb8 100644 --- a/quantus/helpers/model/pytorch_model.py +++ b/quantus/helpers/model/pytorch_model.py @@ -86,8 +86,11 @@ def _get_model_with_linear_top(self) -> torch.nn: if isinstance(named_module[1], torch.nn.Softmax): setattr(linear_model, named_module[0], torch.nn.Identity()) - logging.info("Argument softmax=False passed, but the passed model contains a module of type " - "torch.nn.Softmax. Module {} has been replaced with torch.nn.Identity().", named_module[0]) + logging.info( + "Argument softmax=False passed, but the passed model contains a module of type " + "torch.nn.Softmax. Module {} has been replaced with torch.nn.Identity().", + named_module[0], + ) break return linear_model @@ -118,8 +121,10 @@ def get_softmax_arg_model(self) -> torch.nn: return self.model # Case 1 if self.softmax and not last_softmax: - logging.info("Argument softmax=True passed, but the passed model contains no module of type " - "torch.nn.Softmax. torch.nn.Softmax module is added as the output layer.") + logging.info( + "Argument softmax=True passed, but the passed model contains no module of type " + "torch.nn.Softmax. torch.nn.Softmax module is added as the output layer." + ) return torch.nn.Sequential(self.model, torch.nn.Softmax(dim=-1)) # Case 3 if not self.softmax and not last_softmax: @@ -133,12 +138,14 @@ def get_softmax_arg_model(self) -> torch.nn: ) # Warning for cases 2, 4, 5 if self.softmax and last_softmax != -1: - logging.info("Argument softmax=True passed. The passed model contains a module of type " - "torch.nn.Softmax, but it is not the last in the list of model's children (" - "self.model.modules()). torch.nn.Softmax module is added as the output layer." - "Make sure that the torch.nn.Softmax layer is the last module in the list " - "of model's children (self.model.modules()) if and only if it is the actual last module " - "applied before output.") + logging.info( + "Argument softmax=True passed. The passed model contains a module of type " + "torch.nn.Softmax, but it is not the last in the list of model's children (" + "self.model.modules()). torch.nn.Softmax module is added as the output layer." + "Make sure that the torch.nn.Softmax layer is the last module in the list " + "of model's children (self.model.modules()) if and only if it is the actual last module " + "applied before output." + ) return torch.nn.Sequential(self.model, torch.nn.Softmax(dim=-1)) # Case 2 @@ -337,7 +344,6 @@ def add_mean_shift_to_first_layer( The resulting model with a shifted first layer. """ with torch.no_grad(): - new_model = deepcopy(self.model) modules = [l for l in new_model.named_modules()] @@ -364,7 +370,6 @@ def get_hidden_representations( layer_names: Optional[List[str]] = None, layer_indices: Optional[List[int]] = None, ) -> np.ndarray: - """ Compute the model's internal representation of input x. In practice, this means, executing a forward pass and then, capturing the output of layers (of interest). diff --git a/quantus/helpers/model/tf_model.py b/quantus/helpers/model/tf_model.py index 7a1769147..c72b06860 100644 --- a/quantus/helpers/model/tf_model.py +++ b/quantus/helpers/model/tf_model.py @@ -360,7 +360,6 @@ def get_hidden_representations( layer_indices: Optional[List[int]] = None, **kwargs, ) -> np.ndarray: - """ Compute the model's internal representation of input x. In practice, this means, executing a forward pass and then, capturing the output of layers (of interest). diff --git a/quantus/helpers/plotting.py b/quantus/helpers/plotting.py index 4fa305713..fd504556a 100644 --- a/quantus/helpers/plotting.py +++ b/quantus/helpers/plotting.py @@ -245,7 +245,6 @@ def plot_model_parameter_randomisation_experiment( plt.plot(layers, [np.mean(v) for k, v in scores.items()], label=method) else: - layers = list(results.keys()) scores = {k: [] for k in layers} # samples = len(results) diff --git a/quantus/helpers/utils.py b/quantus/helpers/utils.py index 196bdb33f..0ec29da91 100644 --- a/quantus/helpers/utils.py +++ b/quantus/helpers/utils.py @@ -764,7 +764,6 @@ def infer_attribution_axes(a_batch: np.ndarray, x_batch: np.ndarray) -> Sequence for start in range(0, len(x_shape) - len(a_shape) + 1) ] if x_subshapes.count(a_shape) < 1: - # Check that attribution dimensions are (consecutive) subdimensions of inputs raise ValueError( "Attribution dimensions are not (consecutive) subdimensions of inputs: " @@ -773,7 +772,6 @@ def infer_attribution_axes(a_batch: np.ndarray, x_batch: np.ndarray) -> Sequence ) ) elif x_subshapes.count(a_shape) > 1: - # Check that attribution dimensions are (unique) subdimensions of inputs. # Consider potentially expanded dims in attributions. @@ -783,7 +781,6 @@ def infer_attribution_axes(a_batch: np.ndarray, x_batch: np.ndarray) -> Sequence for start in range(0, len(np.shape(a_batch)[1:]) - len(a_shape) + 1) ] if a_subshapes.count(a_shape) == 1: - # Inferring channel shape. for dim in range(len(np.shape(a_batch)[1:]) + 1): if a_shape == np.shape(a_batch)[1:][dim:]: From 07790d52de665cfd4d60e2321130d6a11ccb75e6 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Thu, 5 Oct 2023 17:13:51 +0200 Subject: [PATCH 19/58] Revert "* run black" This reverts commit 58d19ff683112ce591568df3773f53e0247cbd04. --- quantus/__init__.py | 2 +- quantus/evaluation.py | 12 ++++++++--- quantus/functions/explanation_func.py | 1 + quantus/functions/loss_func.py | 2 +- quantus/functions/normalise_func.py | 4 ++-- quantus/functions/similarity_func.py | 2 +- quantus/helpers/__init__.py | 2 +- quantus/helpers/model/models.py | 1 + quantus/helpers/model/pytorch_model.py | 29 +++++++++++--------------- quantus/helpers/model/tf_model.py | 1 + quantus/helpers/plotting.py | 1 + quantus/helpers/utils.py | 3 +++ 12 files changed, 34 insertions(+), 26 deletions(-) diff --git a/quantus/__init__.py b/quantus/__init__.py index 485b9d851..c29dc0a35 100644 --- a/quantus/__init__.py +++ b/quantus/__init__.py @@ -26,4 +26,4 @@ from quantus.helpers.model import * # Expose the helpers utils. -from quantus.helpers.utils import * +from quantus.helpers.utils import * \ No newline at end of file diff --git a/quantus/evaluation.py b/quantus/evaluation.py index 8c37a24a7..75d628ee3 100644 --- a/quantus/evaluation.py +++ b/quantus/evaluation.py @@ -81,7 +81,7 @@ def evaluate( return None if call_kwargs is None: - call_kwargs = {"call_kwargs_empty": {}} + call_kwargs = {'call_kwargs_empty': {}} elif not isinstance(call_kwargs, Dict): raise TypeError("xai_methods type is not Dict[str, Dict].") @@ -92,9 +92,11 @@ def evaluate( "xai_methods type is not in: Dict[str, Callable], Dict[str, Dict], Dict[str, np.ndarray]." for method, value in xai_methods.items(): + results[method] = {} if callable(value): + explain_funcs[method] = value explain_func = value @@ -114,6 +116,7 @@ def evaluate( asserts.assert_attributions(a_batch=a_batch, x_batch=x_batch) elif isinstance(value, Dict): + if explain_func_kwargs is not None: warnings.warn( "Passed explain_func_kwargs will be ignored when passing type Dict[str, Dict] as xai_methods." @@ -137,6 +140,7 @@ def evaluate( a_batch = value else: + raise TypeError( "xai_methods type is not in: Dict[str, Callable], Dict[str, Dict], Dict[str, np.ndarray]." ) @@ -144,10 +148,12 @@ def evaluate( if explain_func_kwargs is None: explain_func_kwargs = {} - for metric, metric_func in metrics.items(): + for (metric, metric_func) in metrics.items(): + results[method][metric] = {} - for call_kwarg_str, call_kwarg in call_kwargs.items(): + for (call_kwarg_str, call_kwarg) in call_kwargs.items(): + if progress: print( f"Evaluating {method} explanations on {metric} metric on set of call parameters {call_kwarg_str}..." diff --git a/quantus/functions/explanation_func.py b/quantus/functions/explanation_func.py index fbc4de9ba..d34260f3f 100644 --- a/quantus/functions/explanation_func.py +++ b/quantus/functions/explanation_func.py @@ -385,6 +385,7 @@ def generate_tf_explanation( ) elif method == "SmoothGrad": + num_samples = kwargs.get("num_samples", 5) noise = kwargs.get("noise", 0.1) explainer = tf_explain.core.smoothgrad.SmoothGrad() diff --git a/quantus/functions/loss_func.py b/quantus/functions/loss_func.py index bd0723645..69181e1f3 100644 --- a/quantus/functions/loss_func.py +++ b/quantus/functions/loss_func.py @@ -34,7 +34,7 @@ def mse(a: np.array, b: np.array, **kwargs) -> float: if normalise: # Calculate MSE in its polynomial expansion (a-b)^2 = a^2 - 2ab + b^2. - return np.average(((a**2) - (2 * (a * b)) + (b**2)), axis=0) + return np.average(((a ** 2) - (2 * (a * b)) + (b ** 2)), axis=0) # If no need to normalise, return (a-b)^2. return np.average(((a - b) ** 2), axis=0) diff --git a/quantus/functions/normalise_func.py b/quantus/functions/normalise_func.py index 5ad7f8145..2d518aad4 100644 --- a/quantus/functions/normalise_func.py +++ b/quantus/functions/normalise_func.py @@ -231,13 +231,13 @@ def normalise_by_average_second_moment_estimate( # Check that square root of the second momment estimatte is nonzero. second_moment_sqrt = np.sqrt( - np.sum(a**2, axis=normalise_axes, keepdims=True) + np.sum(a ** 2, axis=normalise_axes, keepdims=True) / np.prod([a.shape[n] for n in normalise_axes]) ) if all(second_moment_sqrt != 0): a /= np.sqrt( - np.sum(a**2, axis=normalise_axes, keepdims=True) + np.sum(a ** 2, axis=normalise_axes, keepdims=True) / np.prod([a.shape[n] for n in normalise_axes]) ) else: diff --git a/quantus/functions/similarity_func.py b/quantus/functions/similarity_func.py index 93ffe8832..88d19a9a7 100644 --- a/quantus/functions/similarity_func.py +++ b/quantus/functions/similarity_func.py @@ -145,7 +145,7 @@ def lipschitz_constant( b: np.array, c: Union[np.array, None], d: Union[np.array, None], - **kwargs, + **kwargs ) -> float: """ Calculate non-negative local Lipschitz abs(||a-b||/||c-d||), where a,b can be f(x) or a(x) and c,d is x. diff --git a/quantus/helpers/__init__.py b/quantus/helpers/__init__.py index e4d00f57e..eafc68d63 100644 --- a/quantus/helpers/__init__.py +++ b/quantus/helpers/__init__.py @@ -8,4 +8,4 @@ # Import files dependent on package installations. __EXTRAS__ = util.find_spec("captum") or util.find_spec("tf_explain") -__MODELS__ = util.find_spec("torch") or util.find_spec("tensorflow") +__MODELS__ = util.find_spec("torch") or util.find_spec("tensorflow") \ No newline at end of file diff --git a/quantus/helpers/model/models.py b/quantus/helpers/model/models.py index 2feb32c55..38d0ca004 100644 --- a/quantus/helpers/model/models.py +++ b/quantus/helpers/model/models.py @@ -12,6 +12,7 @@ # Import different models depending on which deep learning framework is installed. if util.find_spec("torch"): + import torch import torch.nn as nn diff --git a/quantus/helpers/model/pytorch_model.py b/quantus/helpers/model/pytorch_model.py index b79442cb8..cd422931f 100644 --- a/quantus/helpers/model/pytorch_model.py +++ b/quantus/helpers/model/pytorch_model.py @@ -86,11 +86,8 @@ def _get_model_with_linear_top(self) -> torch.nn: if isinstance(named_module[1], torch.nn.Softmax): setattr(linear_model, named_module[0], torch.nn.Identity()) - logging.info( - "Argument softmax=False passed, but the passed model contains a module of type " - "torch.nn.Softmax. Module {} has been replaced with torch.nn.Identity().", - named_module[0], - ) + logging.info("Argument softmax=False passed, but the passed model contains a module of type " + "torch.nn.Softmax. Module {} has been replaced with torch.nn.Identity().", named_module[0]) break return linear_model @@ -121,10 +118,8 @@ def get_softmax_arg_model(self) -> torch.nn: return self.model # Case 1 if self.softmax and not last_softmax: - logging.info( - "Argument softmax=True passed, but the passed model contains no module of type " - "torch.nn.Softmax. torch.nn.Softmax module is added as the output layer." - ) + logging.info("Argument softmax=True passed, but the passed model contains no module of type " + "torch.nn.Softmax. torch.nn.Softmax module is added as the output layer.") return torch.nn.Sequential(self.model, torch.nn.Softmax(dim=-1)) # Case 3 if not self.softmax and not last_softmax: @@ -138,14 +133,12 @@ def get_softmax_arg_model(self) -> torch.nn: ) # Warning for cases 2, 4, 5 if self.softmax and last_softmax != -1: - logging.info( - "Argument softmax=True passed. The passed model contains a module of type " - "torch.nn.Softmax, but it is not the last in the list of model's children (" - "self.model.modules()). torch.nn.Softmax module is added as the output layer." - "Make sure that the torch.nn.Softmax layer is the last module in the list " - "of model's children (self.model.modules()) if and only if it is the actual last module " - "applied before output." - ) + logging.info("Argument softmax=True passed. The passed model contains a module of type " + "torch.nn.Softmax, but it is not the last in the list of model's children (" + "self.model.modules()). torch.nn.Softmax module is added as the output layer." + "Make sure that the torch.nn.Softmax layer is the last module in the list " + "of model's children (self.model.modules()) if and only if it is the actual last module " + "applied before output.") return torch.nn.Sequential(self.model, torch.nn.Softmax(dim=-1)) # Case 2 @@ -344,6 +337,7 @@ def add_mean_shift_to_first_layer( The resulting model with a shifted first layer. """ with torch.no_grad(): + new_model = deepcopy(self.model) modules = [l for l in new_model.named_modules()] @@ -370,6 +364,7 @@ def get_hidden_representations( layer_names: Optional[List[str]] = None, layer_indices: Optional[List[int]] = None, ) -> np.ndarray: + """ Compute the model's internal representation of input x. In practice, this means, executing a forward pass and then, capturing the output of layers (of interest). diff --git a/quantus/helpers/model/tf_model.py b/quantus/helpers/model/tf_model.py index c72b06860..7a1769147 100644 --- a/quantus/helpers/model/tf_model.py +++ b/quantus/helpers/model/tf_model.py @@ -360,6 +360,7 @@ def get_hidden_representations( layer_indices: Optional[List[int]] = None, **kwargs, ) -> np.ndarray: + """ Compute the model's internal representation of input x. In practice, this means, executing a forward pass and then, capturing the output of layers (of interest). diff --git a/quantus/helpers/plotting.py b/quantus/helpers/plotting.py index fd504556a..4fa305713 100644 --- a/quantus/helpers/plotting.py +++ b/quantus/helpers/plotting.py @@ -245,6 +245,7 @@ def plot_model_parameter_randomisation_experiment( plt.plot(layers, [np.mean(v) for k, v in scores.items()], label=method) else: + layers = list(results.keys()) scores = {k: [] for k in layers} # samples = len(results) diff --git a/quantus/helpers/utils.py b/quantus/helpers/utils.py index 0ec29da91..196bdb33f 100644 --- a/quantus/helpers/utils.py +++ b/quantus/helpers/utils.py @@ -764,6 +764,7 @@ def infer_attribution_axes(a_batch: np.ndarray, x_batch: np.ndarray) -> Sequence for start in range(0, len(x_shape) - len(a_shape) + 1) ] if x_subshapes.count(a_shape) < 1: + # Check that attribution dimensions are (consecutive) subdimensions of inputs raise ValueError( "Attribution dimensions are not (consecutive) subdimensions of inputs: " @@ -772,6 +773,7 @@ def infer_attribution_axes(a_batch: np.ndarray, x_batch: np.ndarray) -> Sequence ) ) elif x_subshapes.count(a_shape) > 1: + # Check that attribution dimensions are (unique) subdimensions of inputs. # Consider potentially expanded dims in attributions. @@ -781,6 +783,7 @@ def infer_attribution_axes(a_batch: np.ndarray, x_batch: np.ndarray) -> Sequence for start in range(0, len(np.shape(a_batch)[1:]) - len(a_shape) + 1) ] if a_subshapes.count(a_shape) == 1: + # Inferring channel shape. for dim in range(len(np.shape(a_batch)[1:]) + 1): if a_shape == np.shape(a_batch)[1:][dim:]: From 858d6befce43aef3cc737928bcf6d2f1704f32a3 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Thu, 5 Oct 2023 17:17:02 +0200 Subject: [PATCH 20/58] * run black --- quantus/evaluation.py | 1 - quantus/helpers/__init__.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/quantus/evaluation.py b/quantus/evaluation.py index 75d628ee3..0da4bbac0 100644 --- a/quantus/evaluation.py +++ b/quantus/evaluation.py @@ -116,7 +116,6 @@ def evaluate( asserts.assert_attributions(a_batch=a_batch, x_batch=x_batch) elif isinstance(value, Dict): - if explain_func_kwargs is not None: warnings.warn( "Passed explain_func_kwargs will be ignored when passing type Dict[str, Dict] as xai_methods." diff --git a/quantus/helpers/__init__.py b/quantus/helpers/__init__.py index eafc68d63..e4d00f57e 100644 --- a/quantus/helpers/__init__.py +++ b/quantus/helpers/__init__.py @@ -8,4 +8,4 @@ # Import files dependent on package installations. __EXTRAS__ = util.find_spec("captum") or util.find_spec("tf_explain") -__MODELS__ = util.find_spec("torch") or util.find_spec("tensorflow") \ No newline at end of file +__MODELS__ = util.find_spec("torch") or util.find_spec("tensorflow") From c4ab6f8233693269309f48520cfbb6c3012ec32d Mon Sep 17 00:00:00 2001 From: aaarrti Date: Thu, 5 Oct 2023 17:25:08 +0200 Subject: [PATCH 21/58] * --- quantus/metrics/robustness/continuity.py | 1 - 1 file changed, 1 deletion(-) diff --git a/quantus/metrics/robustness/continuity.py b/quantus/metrics/robustness/continuity.py index f7f8abc08..e1b237e4e 100644 --- a/quantus/metrics/robustness/continuity.py +++ b/quantus/metrics/robustness/continuity.py @@ -5,7 +5,6 @@ # Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from __future__ import annotations import itertools from typing import Any, Callable, Dict, List, Optional import numpy as np From 6eef978ac34b913891ccc45666b27bd7a78c8f86 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Thu, 5 Oct 2023 17:28:34 +0200 Subject: [PATCH 22/58] * --- quantus/metrics/randomisation/model_parameter_randomisation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/quantus/metrics/randomisation/model_parameter_randomisation.py b/quantus/metrics/randomisation/model_parameter_randomisation.py index 4991bf821..6d4bcac32 100644 --- a/quantus/metrics/randomisation/model_parameter_randomisation.py +++ b/quantus/metrics/randomisation/model_parameter_randomisation.py @@ -391,4 +391,4 @@ def compute_correlation_per_sample( return corr_coeffs def evaluate_batch(self, *args, **kwargs): - raise RuntimeError("This is unexpected.") + raise RuntimeError("`evaluate_batch` must never be called for `ModelParameterRandomisation`.") From 9ca2b7dd331a5c610a9ed6acdf8f6956c3a6ef46 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Thu, 5 Oct 2023 17:30:50 +0200 Subject: [PATCH 23/58] * code review comments --- .../metrics/randomisation/model_parameter_randomisation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/quantus/metrics/randomisation/model_parameter_randomisation.py b/quantus/metrics/randomisation/model_parameter_randomisation.py index 6d4bcac32..ca598d269 100644 --- a/quantus/metrics/randomisation/model_parameter_randomisation.py +++ b/quantus/metrics/randomisation/model_parameter_randomisation.py @@ -295,9 +295,9 @@ def __call__( # Generate an explanation with perturbed model. a_batch_perturbed = self.explain_batch( - random_layer_model, - x_batch, - y_batch, + model=random_layer_model, + x_batch=x_batch, + y_batch=y_batch, ) batch_iterator = enumerate(zip(a_batch, a_batch_perturbed)) From 9c4b34ee521fdfe60ef1bacec494c3f444eec275 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Thu, 5 Oct 2023 17:38:08 +0200 Subject: [PATCH 24/58] * code review comments --- quantus/metrics/axiomatic/completeness.py | 4 ++-- quantus/metrics/axiomatic/non_sensitivity.py | 1 - quantus/metrics/complexity/complexity.py | 4 ++-- quantus/metrics/complexity/effective_complexity.py | 4 ++-- quantus/metrics/complexity/sparseness.py | 5 +++-- quantus/metrics/faithfulness/faithfulness_estimate.py | 1 - quantus/metrics/localisation/auc.py | 3 ++- quantus/metrics/localisation/relevance_rank_accuracy.py | 3 ++- quantus/metrics/robustness/consistency.py | 3 ++- 9 files changed, 15 insertions(+), 13 deletions(-) diff --git a/quantus/metrics/axiomatic/completeness.py b/quantus/metrics/axiomatic/completeness.py index b7814d13b..d069200c9 100644 --- a/quantus/metrics/axiomatic/completeness.py +++ b/quantus/metrics/axiomatic/completeness.py @@ -317,7 +317,6 @@ def evaluate_batch( **_, ) -> List[bool]: """ - Checks if sum of attributions is equal to the difference between original prediction and prediction on baseline value. @@ -343,7 +342,8 @@ def evaluate_batch( """ - # TODO. For performance gains, replace the for loop below with vectorisation. + # TODO: For performance gains, replace the for loop below with vectorisation. + # https://github.com/understandable-machine-intelligence-lab/Quantus/issues/299 return [ self.evaluate_instance(model, x, y, a) for x, y, a in zip(x_batch, y_batch, a_batch) diff --git a/quantus/metrics/axiomatic/non_sensitivity.py b/quantus/metrics/axiomatic/non_sensitivity.py index e96f2bfcf..0b2a20d43 100644 --- a/quantus/metrics/axiomatic/non_sensitivity.py +++ b/quantus/metrics/axiomatic/non_sensitivity.py @@ -371,7 +371,6 @@ def evaluate_batch( **kwargs, ) -> List[int]: """ - Count the number of features in each explanation, for which model is not sensitive. Parameters diff --git a/quantus/metrics/complexity/complexity.py b/quantus/metrics/complexity/complexity.py index 1d9a0b575..ea154972a 100644 --- a/quantus/metrics/complexity/complexity.py +++ b/quantus/metrics/complexity/complexity.py @@ -255,7 +255,6 @@ def evaluate_batch( self, *, x_batch: np.ndarray, a_batch: np.ndarray, **_ ) -> List[float]: """ - TODO: write meaningful docstring about what does it compute. Parameters @@ -274,5 +273,6 @@ def evaluate_batch( List of floats. """ - # TODO. For performance gains, replace the for loop below with vectorisation. + # TODO: For performance gains, replace the for loop below with vectorisation. + # https://github.com/understandable-machine-intelligence-lab/Quantus/issues/299 return [self.evaluate_instance(x, a) for x, a in zip(x_batch, a_batch)] diff --git a/quantus/metrics/complexity/effective_complexity.py b/quantus/metrics/complexity/effective_complexity.py index 878f35ea1..ed26473af 100644 --- a/quantus/metrics/complexity/effective_complexity.py +++ b/quantus/metrics/complexity/effective_complexity.py @@ -251,7 +251,6 @@ def evaluate_instance( def evaluate_batch(self, *, a_batch: np.ndarray, **_) -> List[int]: """ - Count how many attributions exceed the threshold `eps` Parameters @@ -269,6 +268,7 @@ def evaluate_batch(self, *, a_batch: np.ndarray, **_) -> List[int]: List of integers. """ - # TODO. For performance gains, replace the for loop below with vectorisation. + # TODO: For performance gains, replace the for loop below with vectorisation. + # https://github.com/understandable-machine-intelligence-lab/Quantus/issues/299 return [self.evaluate_instance(a) for a in a_batch] diff --git a/quantus/metrics/complexity/sparseness.py b/quantus/metrics/complexity/sparseness.py index 7de93fae8..7712368d9 100644 --- a/quantus/metrics/complexity/sparseness.py +++ b/quantus/metrics/complexity/sparseness.py @@ -282,7 +282,8 @@ def evaluate_batch( List of floats. """ - - # TODO. For performance gains, replace the for loop below with vectorisation. + + # TODO: For performance gains, replace the for loop below with vectorisation. + # https://github.com/understandable-machine-intelligence-lab/Quantus/issues/299 return [self.evaluate_instance(x, a) for x, a in zip(x_batch, a_batch)] diff --git a/quantus/metrics/faithfulness/faithfulness_estimate.py b/quantus/metrics/faithfulness/faithfulness_estimate.py index 1a6b04d19..8d54ba1e0 100644 --- a/quantus/metrics/faithfulness/faithfulness_estimate.py +++ b/quantus/metrics/faithfulness/faithfulness_estimate.py @@ -375,7 +375,6 @@ def evaluate_batch( **_, ) -> List[float]: """ - TODO: write meaningful docstring about what does it compute. Parameters diff --git a/quantus/metrics/localisation/auc.py b/quantus/metrics/localisation/auc.py index 401962c8c..2457d71e6 100644 --- a/quantus/metrics/localisation/auc.py +++ b/quantus/metrics/localisation/auc.py @@ -301,5 +301,6 @@ def evaluate_batch( ------- """ - # TODO: for performance reasons replace for-loop with vectorized dispatch. + # TODO: For performance gains, replace the for loop below with vectorisation. + # https://github.com/understandable-machine-intelligence-lab/Quantus/issues/299 return [self.evaluate_instance(a, s) for a, s in zip(a_batch, s_batch)] diff --git a/quantus/metrics/localisation/relevance_rank_accuracy.py b/quantus/metrics/localisation/relevance_rank_accuracy.py index b2d068e24..812ce6af7 100644 --- a/quantus/metrics/localisation/relevance_rank_accuracy.py +++ b/quantus/metrics/localisation/relevance_rank_accuracy.py @@ -323,5 +323,6 @@ def evaluate_batch( ------- """ - # TODO: for performance reasons, this method should be vectorized. + # TODO: For performance gains, replace the for loop below with vectorisation. + # https://github.com/understandable-machine-intelligence-lab/Quantus/issues/299 return [self.evaluate_instance(a, s) for a, s in zip(a_batch, s_batch)] diff --git a/quantus/metrics/robustness/consistency.py b/quantus/metrics/robustness/consistency.py index 641473379..bb2cad597 100644 --- a/quantus/metrics/robustness/consistency.py +++ b/quantus/metrics/robustness/consistency.py @@ -303,7 +303,8 @@ def evaluate_batch( y_pred_classes, **_, ) -> List[float]: - # TODO: for performance reasons vectorize this for-loop + # TODO: For performance gains, replace the for loop below with vectorisation. + # https://github.com/understandable-machine-intelligence-lab/Quantus/issues/299 return [ self.evaluate_instance(a, i, a_label, y_pred_classes) From c1eeae39acf61c3fc33453f69f517b3a92f21dd7 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Thu, 5 Oct 2023 17:43:26 +0200 Subject: [PATCH 25/58] * code review comments --- quantus/metrics/faithfulness/faithfulness_correlation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/quantus/metrics/faithfulness/faithfulness_correlation.py b/quantus/metrics/faithfulness/faithfulness_correlation.py index 1d9c43f6c..a2061667f 100644 --- a/quantus/metrics/faithfulness/faithfulness_correlation.py +++ b/quantus/metrics/faithfulness/faithfulness_correlation.py @@ -384,7 +384,6 @@ def evaluate_batch( **_, ) -> List[float]: """ - TODO: write meaningful docstring about what does it compute. Parameters From 77337fbc0cd766f0f19947527d98482c70ddc875 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Thu, 5 Oct 2023 17:49:20 +0200 Subject: [PATCH 26/58] * code review comments --- quantus/metrics/faithfulness/road.py | 1 + 1 file changed, 1 insertion(+) diff --git a/quantus/metrics/faithfulness/road.py b/quantus/metrics/faithfulness/road.py index c864eb484..d4d6d78f2 100644 --- a/quantus/metrics/faithfulness/road.py +++ b/quantus/metrics/faithfulness/road.py @@ -352,6 +352,7 @@ def custom_postprocess( } def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: + """ROAD requires `a_size` property to be set to `image_height` * `image_width` of an explanation.""" data_batch = super().batch_preprocess(data_batch) # Infer the size of attributions. self.a_size = data_batch["a_batch"][0, :, :].size From 02850eff429896196b85cb756b874cf8baabd5df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anna=20Hedstr=C3=B6m?= Date: Mon, 9 Oct 2023 13:04:40 +0200 Subject: [PATCH 27/58] Update base.py --- quantus/metrics/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index 75582a8c1..165c851ed 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -261,8 +261,8 @@ def __call__( "Specify an 'aggregate_func' (Callable) to aggregate evaluation scores." ) - # Append content of last results to all results. - self.all_evaluation_scores.append(self.evaluation_scores) + # Append the content of the last results to all results. + self.all_evaluation_scores.extend(self.evaluation_scores) return self.evaluation_scores From ecda3c54555848cdf49d91b58c998d6383421ec2 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Tue, 10 Oct 2023 17:20:30 +0200 Subject: [PATCH 28/58] * code review comments --- quantus/__init__.py | 2 +- quantus/evaluation.py | 11 +- quantus/functions/explanation_func.py | 1 - quantus/functions/loss_func.py | 2 +- quantus/functions/normalise_func.py | 4 +- quantus/functions/similarity_func.py | 2 +- quantus/helpers/model/models.py | 1 - quantus/helpers/model/pytorch_model.py | 29 ++-- quantus/helpers/model/tf_model.py | 1 - quantus/helpers/perturbation_utils.py | 78 ++++++++++ quantus/helpers/plotting.py | 1 - quantus/helpers/utils.py | 3 - quantus/metrics/axiomatic/completeness.py | 41 ++--- quantus/metrics/axiomatic/input_invariance.py | 53 ++----- quantus/metrics/axiomatic/non_sensitivity.py | 26 ++-- quantus/metrics/base.py | 55 +++++-- quantus/metrics/base_batched.py | 12 -- quantus/metrics/base_perturbed.py | 143 ------------------ quantus/metrics/complexity/complexity.py | 13 +- .../complexity/effective_complexity.py | 20 +-- quantus/metrics/complexity/sparseness.py | 16 +- .../faithfulness/faithfulness_correlation.py | 35 ++--- .../faithfulness/faithfulness_estimate.py | 35 ++--- quantus/metrics/faithfulness/infidelity.py | 34 ++--- quantus/metrics/faithfulness/irof.py | 34 ++--- quantus/metrics/faithfulness/monotonicity.py | 38 ++--- .../faithfulness/monotonicity_correlation.py | 38 ++--- .../metrics/faithfulness/pixel_flipping.py | 32 ++-- .../faithfulness/region_perturbation.py | 32 ++-- quantus/metrics/faithfulness/road.py | 44 +++--- quantus/metrics/faithfulness/selectivity.py | 34 ++--- quantus/metrics/faithfulness/sensitivity_n.py | 34 ++--- quantus/metrics/faithfulness/sufficiency.py | 29 +++- .../localisation/attribution_localisation.py | 20 ++- quantus/metrics/localisation/auc.py | 17 ++- quantus/metrics/localisation/focus.py | 17 ++- quantus/metrics/localisation/pointing_game.py | 15 +- .../localisation/relevance_mass_accuracy.py | 16 +- .../localisation/relevance_rank_accuracy.py | 17 ++- .../localisation/top_k_intersection.py | 13 +- .../model_parameter_randomisation.py | 4 +- quantus/metrics/robustness/avg_sensitivity.py | 53 +++---- quantus/metrics/robustness/consistency.py | 35 ++++- quantus/metrics/robustness/continuity.py | 37 ++--- .../robustness/local_lipschitz_estimate.py | 44 +++--- quantus/metrics/robustness/max_sensitivity.py | 41 ++--- .../robustness/relative_input_stability.py | 36 ++--- .../robustness/relative_output_stability.py | 31 ++-- .../relative_representation_stability.py | 32 ++-- 49 files changed, 647 insertions(+), 714 deletions(-) create mode 100644 quantus/helpers/perturbation_utils.py delete mode 100644 quantus/metrics/base_perturbed.py diff --git a/quantus/__init__.py b/quantus/__init__.py index c29dc0a35..485b9d851 100644 --- a/quantus/__init__.py +++ b/quantus/__init__.py @@ -26,4 +26,4 @@ from quantus.helpers.model import * # Expose the helpers utils. -from quantus.helpers.utils import * \ No newline at end of file +from quantus.helpers.utils import * diff --git a/quantus/evaluation.py b/quantus/evaluation.py index 0da4bbac0..8c37a24a7 100644 --- a/quantus/evaluation.py +++ b/quantus/evaluation.py @@ -81,7 +81,7 @@ def evaluate( return None if call_kwargs is None: - call_kwargs = {'call_kwargs_empty': {}} + call_kwargs = {"call_kwargs_empty": {}} elif not isinstance(call_kwargs, Dict): raise TypeError("xai_methods type is not Dict[str, Dict].") @@ -92,11 +92,9 @@ def evaluate( "xai_methods type is not in: Dict[str, Callable], Dict[str, Dict], Dict[str, np.ndarray]." for method, value in xai_methods.items(): - results[method] = {} if callable(value): - explain_funcs[method] = value explain_func = value @@ -139,7 +137,6 @@ def evaluate( a_batch = value else: - raise TypeError( "xai_methods type is not in: Dict[str, Callable], Dict[str, Dict], Dict[str, np.ndarray]." ) @@ -147,12 +144,10 @@ def evaluate( if explain_func_kwargs is None: explain_func_kwargs = {} - for (metric, metric_func) in metrics.items(): - + for metric, metric_func in metrics.items(): results[method][metric] = {} - for (call_kwarg_str, call_kwarg) in call_kwargs.items(): - + for call_kwarg_str, call_kwarg in call_kwargs.items(): if progress: print( f"Evaluating {method} explanations on {metric} metric on set of call parameters {call_kwarg_str}..." diff --git a/quantus/functions/explanation_func.py b/quantus/functions/explanation_func.py index d34260f3f..fbc4de9ba 100644 --- a/quantus/functions/explanation_func.py +++ b/quantus/functions/explanation_func.py @@ -385,7 +385,6 @@ def generate_tf_explanation( ) elif method == "SmoothGrad": - num_samples = kwargs.get("num_samples", 5) noise = kwargs.get("noise", 0.1) explainer = tf_explain.core.smoothgrad.SmoothGrad() diff --git a/quantus/functions/loss_func.py b/quantus/functions/loss_func.py index 69181e1f3..bd0723645 100644 --- a/quantus/functions/loss_func.py +++ b/quantus/functions/loss_func.py @@ -34,7 +34,7 @@ def mse(a: np.array, b: np.array, **kwargs) -> float: if normalise: # Calculate MSE in its polynomial expansion (a-b)^2 = a^2 - 2ab + b^2. - return np.average(((a ** 2) - (2 * (a * b)) + (b ** 2)), axis=0) + return np.average(((a**2) - (2 * (a * b)) + (b**2)), axis=0) # If no need to normalise, return (a-b)^2. return np.average(((a - b) ** 2), axis=0) diff --git a/quantus/functions/normalise_func.py b/quantus/functions/normalise_func.py index 2d518aad4..5ad7f8145 100644 --- a/quantus/functions/normalise_func.py +++ b/quantus/functions/normalise_func.py @@ -231,13 +231,13 @@ def normalise_by_average_second_moment_estimate( # Check that square root of the second momment estimatte is nonzero. second_moment_sqrt = np.sqrt( - np.sum(a ** 2, axis=normalise_axes, keepdims=True) + np.sum(a**2, axis=normalise_axes, keepdims=True) / np.prod([a.shape[n] for n in normalise_axes]) ) if all(second_moment_sqrt != 0): a /= np.sqrt( - np.sum(a ** 2, axis=normalise_axes, keepdims=True) + np.sum(a**2, axis=normalise_axes, keepdims=True) / np.prod([a.shape[n] for n in normalise_axes]) ) else: diff --git a/quantus/functions/similarity_func.py b/quantus/functions/similarity_func.py index 88d19a9a7..93ffe8832 100644 --- a/quantus/functions/similarity_func.py +++ b/quantus/functions/similarity_func.py @@ -145,7 +145,7 @@ def lipschitz_constant( b: np.array, c: Union[np.array, None], d: Union[np.array, None], - **kwargs + **kwargs, ) -> float: """ Calculate non-negative local Lipschitz abs(||a-b||/||c-d||), where a,b can be f(x) or a(x) and c,d is x. diff --git a/quantus/helpers/model/models.py b/quantus/helpers/model/models.py index 38d0ca004..2feb32c55 100644 --- a/quantus/helpers/model/models.py +++ b/quantus/helpers/model/models.py @@ -12,7 +12,6 @@ # Import different models depending on which deep learning framework is installed. if util.find_spec("torch"): - import torch import torch.nn as nn diff --git a/quantus/helpers/model/pytorch_model.py b/quantus/helpers/model/pytorch_model.py index cd422931f..b79442cb8 100644 --- a/quantus/helpers/model/pytorch_model.py +++ b/quantus/helpers/model/pytorch_model.py @@ -86,8 +86,11 @@ def _get_model_with_linear_top(self) -> torch.nn: if isinstance(named_module[1], torch.nn.Softmax): setattr(linear_model, named_module[0], torch.nn.Identity()) - logging.info("Argument softmax=False passed, but the passed model contains a module of type " - "torch.nn.Softmax. Module {} has been replaced with torch.nn.Identity().", named_module[0]) + logging.info( + "Argument softmax=False passed, but the passed model contains a module of type " + "torch.nn.Softmax. Module {} has been replaced with torch.nn.Identity().", + named_module[0], + ) break return linear_model @@ -118,8 +121,10 @@ def get_softmax_arg_model(self) -> torch.nn: return self.model # Case 1 if self.softmax and not last_softmax: - logging.info("Argument softmax=True passed, but the passed model contains no module of type " - "torch.nn.Softmax. torch.nn.Softmax module is added as the output layer.") + logging.info( + "Argument softmax=True passed, but the passed model contains no module of type " + "torch.nn.Softmax. torch.nn.Softmax module is added as the output layer." + ) return torch.nn.Sequential(self.model, torch.nn.Softmax(dim=-1)) # Case 3 if not self.softmax and not last_softmax: @@ -133,12 +138,14 @@ def get_softmax_arg_model(self) -> torch.nn: ) # Warning for cases 2, 4, 5 if self.softmax and last_softmax != -1: - logging.info("Argument softmax=True passed. The passed model contains a module of type " - "torch.nn.Softmax, but it is not the last in the list of model's children (" - "self.model.modules()). torch.nn.Softmax module is added as the output layer." - "Make sure that the torch.nn.Softmax layer is the last module in the list " - "of model's children (self.model.modules()) if and only if it is the actual last module " - "applied before output.") + logging.info( + "Argument softmax=True passed. The passed model contains a module of type " + "torch.nn.Softmax, but it is not the last in the list of model's children (" + "self.model.modules()). torch.nn.Softmax module is added as the output layer." + "Make sure that the torch.nn.Softmax layer is the last module in the list " + "of model's children (self.model.modules()) if and only if it is the actual last module " + "applied before output." + ) return torch.nn.Sequential(self.model, torch.nn.Softmax(dim=-1)) # Case 2 @@ -337,7 +344,6 @@ def add_mean_shift_to_first_layer( The resulting model with a shifted first layer. """ with torch.no_grad(): - new_model = deepcopy(self.model) modules = [l for l in new_model.named_modules()] @@ -364,7 +370,6 @@ def get_hidden_representations( layer_names: Optional[List[str]] = None, layer_indices: Optional[List[int]] = None, ) -> np.ndarray: - """ Compute the model's internal representation of input x. In practice, this means, executing a forward pass and then, capturing the output of layers (of interest). diff --git a/quantus/helpers/model/tf_model.py b/quantus/helpers/model/tf_model.py index 7a1769147..c72b06860 100644 --- a/quantus/helpers/model/tf_model.py +++ b/quantus/helpers/model/tf_model.py @@ -360,7 +360,6 @@ def get_hidden_representations( layer_indices: Optional[List[int]] = None, **kwargs, ) -> np.ndarray: - """ Compute the model's internal representation of input x. In practice, this means, executing a forward pass and then, capturing the output of layers (of interest). diff --git a/quantus/helpers/perturbation_utils.py b/quantus/helpers/perturbation_utils.py new file mode 100644 index 000000000..841633d21 --- /dev/null +++ b/quantus/helpers/perturbation_utils.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from typing import List, TYPE_CHECKING, Callable, Mapping, Protocol +import numpy as np +import functools + + +if TYPE_CHECKING: + from quantus.helpers.model.model_interface import ModelInterface + + class PerturbFunc(Protocol): + def __call__( + self, + arr: np.ndarray, + indices: np.ndarray, + indexed_axes: np.ndarray, + **kwargs, + ) -> np.ndarray: + ... + + +def make_perturb_func( + perturb_func: PerturbFunc, perturb_func_kwargs: Mapping[str, ...] | None, **kwargs +) -> PerturbFunc | functools.partial: + """A utility function to save few lines of code during perturbation metric initialization.""" + if perturb_func_kwargs is None: + perturb_func_kwargs = {} + + kwargs.update(perturb_func_kwargs) + + return functools.partial(perturb_func, **perturb_func_kwargs) + + +def make_changed_prediction_indices_func( + return_nan_when_prediction_changes: bool, +) -> Callable[[ModelInterface, np.ndarray, np.ndarray], List[int]]: + """A utility function to improve static analysis.""" + return functools.partial( + changed_prediction_indices, + return_nan_when_prediction_changes=return_nan_when_prediction_changes, + ) + + +def changed_prediction_indices( + model: ModelInterface, + x_batch: np.ndarray, + x_perturbed: np.ndarray, + return_nan_when_prediction_changes: bool, +) -> List[int]: + """ + Find indices in batch, for which predicted label has changed after applying perturbation. + If metric `return_nan_when_prediction_changes` is False, will return empty list. + + Parameters + ---------- + return_nan_when_prediction_changes: + Instance attribute of perturbation metrics. + model: + x_batch: + Batch of original inputs provided by user. + x_perturbed: + Batch of inputs after applying perturbation. + + Returns + ------- + + changed_idx: + List of indices in batch, for which predicted label has changed afer. + + """ + + if not return_nan_when_prediction_changes: + return [] + + labels_before = model.predict(x_batch).argmax(axis=-1) + labels_after = model.predict(x_perturbed).argmax(axis=-1) + changed_idx = np.reshape(np.argwhere(labels_before != labels_after), -1) + return changed_idx.tolist() diff --git a/quantus/helpers/plotting.py b/quantus/helpers/plotting.py index 4fa305713..fd504556a 100644 --- a/quantus/helpers/plotting.py +++ b/quantus/helpers/plotting.py @@ -245,7 +245,6 @@ def plot_model_parameter_randomisation_experiment( plt.plot(layers, [np.mean(v) for k, v in scores.items()], label=method) else: - layers = list(results.keys()) scores = {k: [] for k in layers} # samples = len(results) diff --git a/quantus/helpers/utils.py b/quantus/helpers/utils.py index 196bdb33f..0ec29da91 100644 --- a/quantus/helpers/utils.py +++ b/quantus/helpers/utils.py @@ -764,7 +764,6 @@ def infer_attribution_axes(a_batch: np.ndarray, x_batch: np.ndarray) -> Sequence for start in range(0, len(x_shape) - len(a_shape) + 1) ] if x_subshapes.count(a_shape) < 1: - # Check that attribution dimensions are (consecutive) subdimensions of inputs raise ValueError( "Attribution dimensions are not (consecutive) subdimensions of inputs: " @@ -773,7 +772,6 @@ def infer_attribution_axes(a_batch: np.ndarray, x_batch: np.ndarray) -> Sequence ) ) elif x_subshapes.count(a_shape) > 1: - # Check that attribution dimensions are (unique) subdimensions of inputs. # Consider potentially expanded dims in attributions. @@ -783,7 +781,6 @@ def infer_attribution_axes(a_batch: np.ndarray, x_batch: np.ndarray) -> Sequence for start in range(0, len(np.shape(a_batch)[1:]) - len(a_shape) + 1) ] if a_subshapes.count(a_shape) == 1: - # Inferring channel shape. for dim in range(len(np.shape(a_batch)[1:]) + 1): if a_shape == np.shape(a_batch)[1:][dim:]: diff --git a/quantus/metrics/axiomatic/completeness.py b/quantus/metrics/axiomatic/completeness.py index d069200c9..0bf17301c 100644 --- a/quantus/metrics/axiomatic/completeness.py +++ b/quantus/metrics/axiomatic/completeness.py @@ -13,16 +13,17 @@ from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices -from quantus.metrics.base_perturbed import PerturbationMetric +from quantus.metrics.base import Metric from quantus.helpers.enums import ( ModelType, DataType, ScoreDirection, EvaluationCategory, ) +from quantus.helpers.perturbation_utils import make_perturb_func -class Completeness(PerturbationMetric): +class Completeness(Metric): """ Implementation of Completeness test by Sundararajan et al., 2017, also referred to as Summation to Delta by Shrikumar et al., 2017 and Conservation by @@ -64,7 +65,7 @@ def __init__( normalise_func_kwargs: Optional[Dict[str, Any]] = None, output_func: Optional[Callable] = lambda x: x, perturb_baseline: str = "black", - perturb_func: Callable = None, + perturb_func: Callable = baseline_replacement_by_indices, perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, aggregate_func: Callable = np.mean, @@ -113,21 +114,11 @@ def __init__( """ if normalise_func is None: normalise_func = normalise_by_max - - if perturb_func is None: - perturb_func = baseline_replacement_by_indices - - if perturb_func_kwargs is None: - perturb_func_kwargs = {} - perturb_func_kwargs["perturb_baseline"] = perturb_baseline - super().__init__( abs=abs, normalise=normalise, normalise_func=normalise_func, normalise_func_kwargs=normalise_func_kwargs, - perturb_func=perturb_func, - perturb_func_kwargs=perturb_func_kwargs, return_aggregate=return_aggregate, aggregate_func=aggregate_func, default_plot_func=default_plot_func, @@ -140,6 +131,9 @@ def __init__( if output_func is None: output_func = lambda x: x self.output_func = output_func + self.perturb_func = make_perturb_func( + perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline + ) # Asserts and warnings. if not self.disable_warnings: @@ -288,10 +282,7 @@ def evaluate_instance( The evaluation results. """ x_baseline = self.perturb_func( - arr=x, - indices=np.arange(0, x.size), - indexed_axes=np.arange(0, x.ndim), - **self.perturb_func_kwargs, + arr=x, indices=np.arange(0, x.size), indexed_axes=np.arange(0, x.ndim) ) # Predict on input. @@ -314,12 +305,11 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - **_, + **kwargs, ) -> List[bool]: """ - Checks if sum of attributions is equal to the difference between original prediction and - prediction on baseline value. - + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. Parameters ---------- @@ -331,7 +321,7 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - _: + kwargs: Unused kwargs. Returns @@ -339,12 +329,9 @@ def evaluate_batch( scores_batch: List of booleans. - - """ - # TODO: For performance gains, replace the for loop below with vectorisation. - # https://github.com/understandable-machine-intelligence-lab/Quantus/issues/299 + return [ - self.evaluate_instance(model, x, y, a) + self.evaluate_instance(model=model, x=x, y=y, a=a) for x, y, a in zip(x_batch, y_batch, a_batch) ] diff --git a/quantus/metrics/axiomatic/input_invariance.py b/quantus/metrics/axiomatic/input_invariance.py index b022dd621..06ce7f8bc 100644 --- a/quantus/metrics/axiomatic/input_invariance.py +++ b/quantus/metrics/axiomatic/input_invariance.py @@ -14,16 +14,17 @@ from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_shift, perturb_batch -from quantus.metrics.base_perturbed import PerturbationMetric +from quantus.metrics.base import Metric from quantus.helpers.enums import ( ModelType, DataType, ScoreDirection, EvaluationCategory, ) +from quantus.helpers.perturbation_utils import make_perturb_func -class InputInvariance(PerturbationMetric): +class InputInvariance(Metric): """ Implementation of Completeness test by Kindermans et al., 2017. @@ -55,6 +56,7 @@ def __init__( normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, input_shift: Union[int, float] = -1, + perturb_func=baseline_replacement_by_shift, perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, aggregate_func: Callable = np.mean, @@ -101,19 +103,11 @@ def __init__( if normalise_func is None: normalise_func = normalise_by_max - perturb_func = baseline_replacement_by_shift - - if perturb_func_kwargs is None: - perturb_func_kwargs = {} - perturb_func_kwargs["input_shift"] = input_shift - super().__init__( abs=abs, normalise=normalise, normalise_func=normalise_func, normalise_func_kwargs=normalise_func_kwargs, - perturb_func=perturb_func, - perturb_func_kwargs=perturb_func_kwargs, return_aggregate=return_aggregate, aggregate_func=aggregate_func, default_plot_func=default_plot_func, @@ -122,11 +116,15 @@ def __init__( **kwargs, ) + self.perturb_func = make_perturb_func( + perturb_func, perturb_func_kwargs, input_shift=input_shift + ) + # Asserts and warnings. if not self.disable_warnings: warn.warn_parameterisation( metric_name=self.__class__.__name__, - sensitive_params=("input shift 'input_shift'"), + sensitive_params="input shift 'input_shift'", citation=( "Kindermans Pieter-Jan, Hooker Sarah, Adebayo Julius, Alber Maximilian, Schütt Kristof T., " "Dähne Sven, Erhan Dumitru and Kim Been. 'THE (UN)RELIABILITY OF SALIENCY METHODS' Article (2017)." @@ -279,11 +277,10 @@ def evaluate_batch( indices=np.tile(np.arange(0, x_batch[0].size), (batch_size, 1)), indexed_axes=np.arange(0, x_batch[0].ndim), arr=x_batch, - **self.perturb_func_kwargs, ) # Get input shift. - input_shift = self.perturb_func_kwargs["input_shift"] + input_shift = self.perturb_func.keywords["input_shift"] x_shifted = model.shape_input( x=x_shifted, shape=x_shifted.shape, @@ -307,37 +304,19 @@ def evaluate_batch( return score - def custom_preprocess( - self, - model: ModelInterface, - x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: np.ndarray, - custom_batch: Optional[np.ndarray] = None, - ) -> None: + def custom_preprocess(self, *args, **kwargs) -> None: """ - Implementation of custom_preprocess_batch. + Additional explain_func assert, as the one in prepare() won't be executed when a_batch != None. Parameters ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. - x_batch: np.ndarray - A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - custom_batch: any - Gives flexibility to the inheriting metric to use for evaluation, can hold any variable. + args: + Unused. + kwargs: + Unused. Returns ------- None """ - # Additional explain_func assert, as the one in prepare() won't be - # executed when a_batch != None. asserts.assert_explain_func(explain_func=self.explain_func) diff --git a/quantus/metrics/axiomatic/non_sensitivity.py b/quantus/metrics/axiomatic/non_sensitivity.py index 0b2a20d43..6436df4e7 100644 --- a/quantus/metrics/axiomatic/non_sensitivity.py +++ b/quantus/metrics/axiomatic/non_sensitivity.py @@ -14,16 +14,17 @@ from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices -from quantus.metrics.base_perturbed import PerturbationMetric +from quantus.metrics.base import Metric from quantus.helpers.enums import ( ModelType, DataType, ScoreDirection, EvaluationCategory, ) +from quantus.helpers.perturbation_utils import make_perturb_func -class NonSensitivity(PerturbationMetric): +class NonSensitivity(Metric): """ Implementation of NonSensitivity by Nguyen at el., 2020. @@ -62,7 +63,7 @@ def __init__( normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, perturb_baseline: str = "black", - perturb_func: Callable = None, + perturb_func: Callable = baseline_replacement_by_indices, perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, aggregate_func: Callable = np.mean, @@ -114,16 +115,6 @@ def __init__( if normalise_func is None: normalise_func = normalise_by_max - if perturb_func is None: - perturb_func = baseline_replacement_by_indices - - if perturb_func_kwargs is None: - perturb_func_kwargs = {} - perturb_func_kwargs["perturb_baseline"] = perturb_baseline - - perturb_func = perturb_func - perturb_func_kwargs = perturb_func_kwargs - super().__init__( abs=abs, normalise=normalise, @@ -143,6 +134,9 @@ def __init__( self.eps = eps self.n_samples = n_samples self.features_in_step = features_in_step + self.perturb_func = make_perturb_func( + perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline + ) # Asserts and warnings. if not self.disable_warnings: @@ -310,7 +304,6 @@ def evaluate_instance( arr=x, indices=a_ix, indexed_axes=self.a_axes, - **self.perturb_func_kwargs, ) # Predict on perturbed input x. @@ -371,7 +364,8 @@ def evaluate_batch( **kwargs, ) -> List[int]: """ - Count the number of features in each explanation, for which model is not sensitive. + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. Parameters ---------- @@ -396,6 +390,6 @@ def evaluate_batch( """ return [ - self.evaluate_instance(model, x, y, a) + self.evaluate_instance(model=model, x=x, y=y, a=a) for x, y, a in zip(x_batch, y_batch, a_batch) ] diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index 75582a8c1..97dd663c6 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -8,7 +8,18 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any, Callable, Dict, Sequence, ClassVar, Generator, Set +from typing import ( + Any, + Callable, + Dict, + Sequence, + ClassVar, + Generator, + Set, + TypedDict, + TypeVar, +) +import logging import matplotlib.pyplot as plt import numpy as np @@ -26,6 +37,9 @@ ) from quantus.helpers.model.model_interface import ModelInterface +D = TypeVar("D", bound=Dict[str, Any]) +log = logging.getLogger(__name__) + class Metric: """ @@ -250,9 +264,9 @@ def __call__( self.evaluation_scores = [ self.aggregate_func(self.evaluation_scores) ] - except: - print( - "The aggregation of evaluation scores failed. Check that " + except Exception as ex: + log.error( + f"The aggregation of evaluation scores failed with {ex}. Check that " "'aggregate_func' supplied is appropriate for the data " "in 'evaluation_scores'." ) @@ -614,9 +628,9 @@ def custom_postprocess( def generate_batches( self, - data: Dict[str, ...], + data: D, batch_size: int, - ) -> Generator[Dict[str, ...], None, None]: + ) -> Generator[D, None, None]: """ Creates iterator to iterate over all batched instances in data dictionary. Each iterator output element is a keyword argument dictionary with @@ -764,7 +778,7 @@ def get_params(self) -> Dict[str, Any]: @property def last_results(self): - print( + log.warning( "Warning: 'last_results' has been renamed to 'evaluation_scores'. " "'last_results' is removed in current version." ) @@ -772,18 +786,15 @@ def last_results(self): @property def all_results(self): - print( + log.warning( "Warning: 'all_results' has been renamed to 'all_evaluation_scores'. " "'all_results' is removed in current version." ) return self.all_evaluation_scores def batch_preprocess(self, data_batch: Dict[str, ...]) -> Dict[str, ...]: - """ - Does computationally heavy pre-processing on batch level to avoid OOM. - By default, will only generate explanations if missing, in case metric requires custom - data, it must be overriden in respective metric. - """ + """If `data_batch` has no `a_batch`, will compute explanations. This needs to be done on batch level to avoid OOM.""" + x_batch = data_batch["x_batch"] a_batch = data_batch.get("a_batch") @@ -799,8 +810,26 @@ def batch_preprocess(self, data_batch: Dict[str, ...]) -> Dict[str, ...]: # TODO: we must not modify global state during evaluation. self.a_axes = utils.infer_attribution_axes(a_batch, x_batch) + custom_batch = self.custom_batch_preprocess(data_batch) + data_batch.update(custom_batch) return data_batch + def custom_batch_preprocess(self, data_batch: Dict[str, ...]) -> Dict[str, ...]: + """ + Implement this method if you need custom preprocessing of data + or simply for creating/initialising additional attributes or assertions + before a `data_batch` can be evaluated. + + Parameters + ---------- + data_batch + + Returns + ------- + + """ + return {} + def explain_batch( self, model: ModelInterface, diff --git a/quantus/metrics/base_batched.py b/quantus/metrics/base_batched.py index 7ce4fda67..a4c692aa3 100644 --- a/quantus/metrics/base_batched.py +++ b/quantus/metrics/base_batched.py @@ -8,7 +8,6 @@ import warnings from quantus.metrics.base import Metric -from quantus.metrics.base_perturbed import PerturbationMetric """Aliases to smoothen transition to uniform metric API.""" @@ -23,14 +22,3 @@ def __new__(cls, *args, **kwargs): "BatchedMetric was deprecated, since it is just an alias to Metric. Please subclass Metric directly." ) super().__new__(*args, **kwargs) - - -class BatchedPerturbationMetric(PerturbationMetric, abc.ABC): - """Alias to quantus.PerturbationMetric, will be removed in next major release.""" - - def __new__(cls, *args, **kwargs): - warnings.warn( - "BatchedPerturbationMetric was deprecated, " - "since it is just an alias to Metric. Please subclass PerturbationMetric directly." - ) - super().__new__(*args, **kwargs) diff --git a/quantus/metrics/base_perturbed.py b/quantus/metrics/base_perturbed.py deleted file mode 100644 index 70db9f732..000000000 --- a/quantus/metrics/base_perturbed.py +++ /dev/null @@ -1,143 +0,0 @@ -"""This module implements the base class for creating evaluation metrics.""" -# This file is part of Quantus. -# Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. -# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. -# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . -# Quantus project URL: . - -from abc import ABC -from typing import Any, Callable, Dict, Optional, List -import warnings - -import numpy as np - -from quantus.helpers.model.model_interface import ModelInterface -from quantus.metrics.base import Metric - - -class PerturbationMetric(Metric, ABC): - """ - Implementation base PertubationMetric class. - - Metric categories such as Faithfulness and Robustness share certain characteristics when it comes to perturbations. - As follows, this metric class is created which has additional attributes for perturbations. - - Attributes: - - name: The name of the metric. - - data_applicability: The data types that the metric implementation currently supports. - - model_applicability: The model types that this metric can work with. - - score_direction: How to interpret the scores, whether higher/ lower values are considered better. - - evaluation_category: What property/ explanation quality that this metric measures. - """ - - def __init__( - self, - abs: bool, - normalise: bool, - normalise_func: Callable, - normalise_func_kwargs: Optional[Dict[str, Any]], - perturb_func: Callable, - perturb_func_kwargs: Optional[Dict[str, Any]], - return_aggregate: bool, - aggregate_func: Callable, - default_plot_func: Optional[Callable], - disable_warnings: bool, - display_progressbar: bool, - **kwargs, - ): - """ - Initialise the PerturbationMetric base class. - - Parameters - ---------- - Parameters - ---------- - abs: boolean - Indicates whether absolute operation is applied on the attribution. - normalise: boolean - Indicates whether normalise operation is applied on the attribution. - normalise_func: callable - Attribution normalisation function applied in case normalise=True. - normalise_func_kwargs: dict - Keyword arguments to be passed to normalise_func on call. - perturb_func: callable - Input perturbation function. - perturb_func_kwargs: dict, optional - Keyword arguments to be passed to perturb_func. - return_aggregate: boolean - Indicates if an aggregated score should be computed over all instances. - aggregate_func: callable - Callable that aggregates the scores given an evaluation call. - default_plot_func: callable - Callable that plots the metrics result. - disable_warnings: boolean - Indicates whether the warnings are printed. - display_progressbar: boolean - Indicates whether a tqdm-progress-bar is printed. - kwargs: optional - Keyword arguments. - """ - - # Initialize super-class with passed parameters - super().__init__( - abs=abs, - normalise=normalise, - normalise_func=normalise_func, - normalise_func_kwargs=normalise_func_kwargs, - return_aggregate=return_aggregate, - aggregate_func=aggregate_func, - default_plot_func=default_plot_func, - display_progressbar=display_progressbar, - disable_warnings=disable_warnings, - **kwargs, - ) - - # TODO: do we really need separate 150+ lines long class just to reuse 4 lines of code? - # Save perturbation metric attributes. - self.perturb_func = perturb_func - - if perturb_func_kwargs is None: - perturb_func_kwargs = {} - self.perturb_func_kwargs = perturb_func_kwargs - - def changed_prediction_indices( - self, model: ModelInterface, x_batch: np.ndarray, x_perturbed: np.ndarray - ) -> List[int]: - """ - Find indices in batch, for which predicted label has changed after applying perturbation. - If metric has no `return_nan_when_prediction_changes` attribute, or it is False, will return empty list. - - Parameters - ---------- - model: - x_batch: - Batch of original inputs provided by user. - x_perturbed: - Batch of inputs after applying perturbation. - - Returns - ------- - - changed_idx: - List of indices in batch, for which predicted label has changed afer. - - """ - - if hasattr(self, "return_nan_when_prediction_changes"): - attr_name = "return_nan_when_prediction_changes" - elif hasattr(self, "_return_nan_when_prediction_changes"): - attr_name = "_return_nan_when_prediction_changes" - else: - warnings.warn( - "Called changed_prediction_indices(), from a metric, " - "without `return_nan_when_prediction_changes` instance attribute, this is unexpected." - ) - return [] - - if not getattr(self, attr_name): - return [] - - labels_before = model.predict(x_batch).argmax(axis=-1) - labels_after = model.predict(x_perturbed).argmax(axis=-1) - changed_idx = np.reshape(np.argwhere(labels_before != labels_after), -1) - return changed_idx.tolist() diff --git a/quantus/metrics/complexity/complexity.py b/quantus/metrics/complexity/complexity.py index ea154972a..e24feae6e 100644 --- a/quantus/metrics/complexity/complexity.py +++ b/quantus/metrics/complexity/complexity.py @@ -252,10 +252,11 @@ def evaluate_instance(x: np.ndarray, a: np.ndarray) -> float: return scipy.stats.entropy(pk=a) def evaluate_batch( - self, *, x_batch: np.ndarray, a_batch: np.ndarray, **_ + self, *args, x_batch: np.ndarray, a_batch: np.ndarray, **kwargs ) -> List[float]: """ - TODO: write meaningful docstring about what does it compute. + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. Parameters ---------- @@ -263,7 +264,9 @@ def evaluate_batch( The input to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - _: + args: + Unused. + kwargs: Unused. Returns @@ -273,6 +276,4 @@ def evaluate_batch( List of floats. """ - # TODO: For performance gains, replace the for loop below with vectorisation. - # https://github.com/understandable-machine-intelligence-lab/Quantus/issues/299 - return [self.evaluate_instance(x, a) for x, a in zip(x_batch, a_batch)] + return [self.evaluate_instance(x=x, a=a) for x, a in zip(x_batch, a_batch)] diff --git a/quantus/metrics/complexity/effective_complexity.py b/quantus/metrics/complexity/effective_complexity.py index ed26473af..834a6a5c0 100644 --- a/quantus/metrics/complexity/effective_complexity.py +++ b/quantus/metrics/complexity/effective_complexity.py @@ -228,10 +228,7 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - a: np.ndarray, - ) -> int: + def evaluate_instance(self, a: np.ndarray) -> int: """ Evaluate instance gets model and data for a single instance as input and returns the evaluation result. @@ -249,26 +246,25 @@ def evaluate_instance( a = a.flatten() return int(np.sum(a > self.eps)) - def evaluate_batch(self, *, a_batch: np.ndarray, **_) -> List[int]: + def evaluate_batch(self, *args, a_batch: np.ndarray, **kwargs) -> List[int]: """ - Count how many attributions exceed the threshold `eps` + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. Parameters ---------- a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - _: - Unused + args: + Unused. + kwargs: + Unused. Returns ------- scores_batch: List of integers. - """ - # TODO: For performance gains, replace the for loop below with vectorisation. - # https://github.com/understandable-machine-intelligence-lab/Quantus/issues/299 - return [self.evaluate_instance(a) for a in a_batch] diff --git a/quantus/metrics/complexity/sparseness.py b/quantus/metrics/complexity/sparseness.py index 7712368d9..d939c1839 100644 --- a/quantus/metrics/complexity/sparseness.py +++ b/quantus/metrics/complexity/sparseness.py @@ -261,10 +261,11 @@ def evaluate_instance(x: np.ndarray, a: np.ndarray) -> float: return score def evaluate_batch( - self, *, x_batch: np.ndarray, a_batch: np.ndarray, **_ + self, *args, x_batch: np.ndarray, a_batch: np.ndarray, **kwargs ) -> List[float]: """ - TODO: write meaningful docstring about what does it compute. + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. Parameters ---------- @@ -272,7 +273,9 @@ def evaluate_batch( The input to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - _: + args: + Unused. + kwargs: Unused. Returns @@ -280,10 +283,5 @@ def evaluate_batch( scores_batch: List of floats. - """ - - # TODO: For performance gains, replace the for loop below with vectorisation. - # https://github.com/understandable-machine-intelligence-lab/Quantus/issues/299 - - return [self.evaluate_instance(x, a) for x, a in zip(x_batch, a_batch)] + return [self.evaluate_instance(x=x, a=a) for x, a in zip(x_batch, a_batch)] diff --git a/quantus/metrics/faithfulness/faithfulness_correlation.py b/quantus/metrics/faithfulness/faithfulness_correlation.py index a2061667f..6e95db4f7 100644 --- a/quantus/metrics/faithfulness/faithfulness_correlation.py +++ b/quantus/metrics/faithfulness/faithfulness_correlation.py @@ -16,16 +16,17 @@ from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices from quantus.functions.similarity_func import correlation_pearson -from quantus.metrics.base_perturbed import PerturbationMetric +from quantus.metrics.base import Metric from quantus.helpers.enums import ( ModelType, DataType, ScoreDirection, EvaluationCategory, ) +from quantus.helpers.perturbation_utils import make_perturb_func -class FaithfulnessCorrelation(PerturbationMetric): +class FaithfulnessCorrelation(Metric): """ Implementation of faithfulness correlation by Bhatt et al., 2020. @@ -69,7 +70,7 @@ def __init__( normalise: bool = True, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = None, + perturb_func: Callable = baseline_replacement_by_indices, perturb_baseline: str = "black", perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = True, @@ -122,21 +123,11 @@ def __init__( if normalise_func is None: normalise_func = normalise_by_max - if perturb_func is None: - perturb_func = baseline_replacement_by_indices - perturb_func = perturb_func - - if perturb_func_kwargs is None: - perturb_func_kwargs = {} - perturb_func_kwargs["perturb_baseline"] = perturb_baseline - super().__init__( abs=abs, normalise=normalise, normalise_func=normalise_func, normalise_func_kwargs=normalise_func_kwargs, - perturb_func=perturb_func, - perturb_func_kwargs=perturb_func_kwargs, return_aggregate=return_aggregate, aggregate_func=aggregate_func, default_plot_func=default_plot_func, @@ -151,6 +142,9 @@ def __init__( self.similarity_func = similarity_func self.nr_runs = nr_runs self.subset_size = subset_size + self.perturb_func = make_perturb_func( + perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline + ) # Asserts and warnings. if not self.disable_warnings: @@ -320,7 +314,6 @@ def evaluate_instance( arr=x, indices=a_ix, indexed_axes=self.a_axes, - **self.perturb_func_kwargs, ) warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) @@ -376,15 +369,16 @@ def custom_preprocess( def evaluate_batch( self, - *, + *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - **_, + **kwargs, ) -> List[float]: """ - TODO: write meaningful docstring about what does it compute. + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. Parameters ---------- @@ -396,7 +390,9 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - _: + args: + Unused. + kwargs: Unused. Returns @@ -404,9 +400,8 @@ def evaluate_batch( scores_batch: List of floats. - """ return [ - self.evaluate_instance(model, x, y, a) + self.evaluate_instance(model=model, x=x, y=y, a=a) for x, y, a in zip(x_batch, y_batch, a_batch) ] diff --git a/quantus/metrics/faithfulness/faithfulness_estimate.py b/quantus/metrics/faithfulness/faithfulness_estimate.py index 8d54ba1e0..50f1d1e2d 100644 --- a/quantus/metrics/faithfulness/faithfulness_estimate.py +++ b/quantus/metrics/faithfulness/faithfulness_estimate.py @@ -16,16 +16,17 @@ from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices from quantus.functions.similarity_func import correlation_pearson -from quantus.metrics.base_perturbed import PerturbationMetric +from quantus.metrics.base import Metric from quantus.helpers.enums import ( ModelType, DataType, ScoreDirection, EvaluationCategory, ) +from quantus.helpers.perturbation_utils import make_perturb_func -class FaithfulnessEstimate(PerturbationMetric): +class FaithfulnessEstimate(Metric): """ Implementation of Faithfulness Estimate by Alvares-Melis at el., 2018a and 2018b. @@ -58,7 +59,7 @@ def __init__( normalise: bool = True, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = None, + perturb_func: Callable = baseline_replacement_by_indices, perturb_baseline: str = "black", perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, @@ -109,20 +110,10 @@ def __init__( if normalise_func is None: normalise_func = normalise_by_max - if perturb_func is None: - perturb_func = baseline_replacement_by_indices - perturb_func = perturb_func - - if perturb_func_kwargs is None: - perturb_func_kwargs = {} - perturb_func_kwargs["perturb_baseline"] = perturb_baseline - super().__init__( abs=abs, normalise=normalise, normalise_func=normalise_func, - perturb_func=perturb_func, - perturb_func_kwargs=perturb_func_kwargs, normalise_func_kwargs=normalise_func_kwargs, return_aggregate=return_aggregate, aggregate_func=aggregate_func, @@ -137,6 +128,9 @@ def __init__( similarity_func = correlation_pearson self.similarity_func = similarity_func self.features_in_step = features_in_step + self.perturb_func = make_perturb_func( + perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline + ) # Asserts and warnings. if not self.disable_warnings: @@ -311,7 +305,6 @@ def evaluate_instance( arr=x, indices=a_ix, indexed_axes=self.a_axes, - **self.perturb_func_kwargs, ) warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) @@ -367,15 +360,16 @@ def custom_preprocess( def evaluate_batch( self, - *, + *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - **_, + **kwargs, ) -> List[float]: """ - TODO: write meaningful docstring about what does it compute. + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. Parameters ---------- @@ -387,7 +381,9 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - _: + args: + Unused. + kwargs: Unused. Returns @@ -395,10 +391,9 @@ def evaluate_batch( scores_batch: List of floats. - """ return [ - self.evaluate_instance(model, x, y, a) + self.evaluate_instance(model=model, x=x, y=y, a=a) for x, y, a in zip(x_batch, y_batch, a_batch) ] diff --git a/quantus/metrics/faithfulness/infidelity.py b/quantus/metrics/faithfulness/infidelity.py index b801544ee..8dee28faf 100644 --- a/quantus/metrics/faithfulness/infidelity.py +++ b/quantus/metrics/faithfulness/infidelity.py @@ -15,16 +15,17 @@ from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices -from quantus.metrics.base_perturbed import PerturbationMetric +from quantus.metrics.base import Metric from quantus.helpers.enums import ( ModelType, DataType, ScoreDirection, EvaluationCategory, ) +from quantus.helpers.perturbation_utils import make_perturb_func -class Infidelity(PerturbationMetric): +class Infidelity(Metric): """ Implementation of Infidelity by Yeh et al., 2019. @@ -66,7 +67,9 @@ def __init__( n_perturb_samples: int = 10, abs: bool = False, normalise: bool = False, - normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, + normalise_func: Optional[ + Callable[[np.ndarray], np.ndarray] + ] = baseline_replacement_by_indices, normalise_func_kwargs: Optional[Dict[str, Any]] = None, perturb_func: Callable = None, perturb_baseline: str = "black", @@ -122,13 +125,6 @@ def __init__( if normalise_func is None: normalise_func = normalise_by_max - if perturb_func is None: - perturb_func = baseline_replacement_by_indices - - if perturb_func_kwargs is None: - perturb_func_kwargs = {} - perturb_func_kwargs["perturb_baseline"] = perturb_baseline - super().__init__( abs=abs, normalise=normalise, @@ -157,6 +153,9 @@ def __init__( self.perturb_patch_sizes = perturb_patch_sizes self.n_perturb_samples = n_perturb_samples self.nr_channels = None + self.perturb_func = make_perturb_func( + perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline + ) # Asserts and warnings. if not self.disable_warnings: @@ -343,7 +342,6 @@ def evaluate_instance( arr=x_perturbed_pad, indices=patch_slice, indexed_axes=self.a_axes, - **self.perturb_func_kwargs, ) # Remove padding. @@ -413,15 +411,16 @@ def custom_preprocess( def evaluate_batch( self, - *, + *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - **_, + **kwargs, ) -> List[float]: """ - TODO: write meaningful docstring about what does it compute. + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. Parameters ---------- @@ -433,7 +432,9 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - _: + args: + Unused. + kwargs: Unused. Returns @@ -441,10 +442,9 @@ def evaluate_batch( scores_batch: List of floats. - """ return [ - self.evaluate_instance(model, x, y, a) + self.evaluate_instance(model=model, x=x, y=y, a=a) for x, y, a in zip(x_batch, y_batch, a_batch) ] diff --git a/quantus/metrics/faithfulness/irof.py b/quantus/metrics/faithfulness/irof.py index 4a8d7f843..93e403d92 100644 --- a/quantus/metrics/faithfulness/irof.py +++ b/quantus/metrics/faithfulness/irof.py @@ -16,16 +16,17 @@ from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices -from quantus.metrics.base_perturbed import PerturbationMetric +from quantus.metrics.base import Metric from quantus.helpers.enums import ( ModelType, DataType, ScoreDirection, EvaluationCategory, ) +from quantus.helpers.perturbation_utils import make_perturb_func -class IROF(PerturbationMetric): +class IROF(Metric): """ Implementation of IROF (Iterative Removal of Features) by Rieger at el., 2020. @@ -63,7 +64,7 @@ def __init__( normalise: bool = True, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = None, + perturb_func: Callable = baseline_replacement_by_indices, perturb_baseline: str = "mean", perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = True, @@ -111,21 +112,11 @@ def __init__( if normalise_func is None: normalise_func = normalise_by_max - if perturb_func is None: - perturb_func = baseline_replacement_by_indices - perturb_func = perturb_func - - if perturb_func_kwargs is None: - perturb_func_kwargs = {} - perturb_func_kwargs["perturb_baseline"] = perturb_baseline - super().__init__( abs=abs, normalise=normalise, normalise_func=normalise_func, normalise_func_kwargs=normalise_func_kwargs, - perturb_func=perturb_func, - perturb_func_kwargs=perturb_func_kwargs, return_aggregate=return_aggregate, aggregate_func=aggregate_func, default_plot_func=default_plot_func, @@ -137,6 +128,9 @@ def __init__( # Save metric-specific attributes. self.segmentation_method = segmentation_method self.nr_channels = None + self.perturb_func = make_perturb_func( + perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline + ) # Asserts and warnings. if not self.disable_warnings: @@ -318,7 +312,6 @@ def evaluate_instance( arr=x_prev_perturbed, indices=a_ix, indexed_axes=self.a_axes, - **self.perturb_func_kwargs, ) warn.warn_perturbation_caused_no_change( x=x_prev_perturbed, x_perturbed=x_perturbed @@ -377,15 +370,16 @@ def get_aoc_score(self): def evaluate_batch( self, - *, + *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - **_, + **kwargs, ) -> List[float]: """ - TODO: write meaningful docstring about what does it compute. + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. Parameters ---------- @@ -397,7 +391,9 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - _: + kwargs: + Unused. + args: Unused. Returns @@ -408,6 +404,6 @@ def evaluate_batch( """ return [ - self.evaluate_instance(model, x, y, a) + self.evaluate_instance(model=model, x=x, y=y, a=a) for x, y, a in zip(x_batch, y_batch, a_batch) ] diff --git a/quantus/metrics/faithfulness/monotonicity.py b/quantus/metrics/faithfulness/monotonicity.py index 7dec97bce..08492e7f1 100644 --- a/quantus/metrics/faithfulness/monotonicity.py +++ b/quantus/metrics/faithfulness/monotonicity.py @@ -16,16 +16,17 @@ from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices -from quantus.metrics.base_perturbed import PerturbationMetric +from quantus.metrics.base import Metric from quantus.helpers.enums import ( ModelType, DataType, ScoreDirection, EvaluationCategory, ) +from quantus.helpers.perturbation_utils import make_perturb_func -class Monotonicity(PerturbationMetric): +class Monotonicity(Metric): """ Implementation of Monotonicity metric by Arya at el., 2019. @@ -64,7 +65,7 @@ def __init__( normalise: bool = True, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = None, + perturb_func: Callable = baseline_replacement_by_indices, perturb_baseline: str = "black", perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, @@ -112,21 +113,11 @@ def __init__( if normalise_func is None: normalise_func = normalise_by_max - if perturb_func is None: - perturb_func = baseline_replacement_by_indices - perturb_func = perturb_func - - if perturb_func_kwargs is None: - perturb_func_kwargs = {} - perturb_func_kwargs["perturb_baseline"] = perturb_baseline - super().__init__( abs=abs, normalise=normalise, normalise_func=normalise_func, normalise_func_kwargs=normalise_func_kwargs, - perturb_func=perturb_func, - perturb_func_kwargs=perturb_func_kwargs, return_aggregate=return_aggregate, aggregate_func=aggregate_func, default_plot_func=default_plot_func, @@ -137,6 +128,9 @@ def __init__( # Save metric-specific attributes. self.features_in_step = features_in_step + self.perturb_func = make_perturb_func( + perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline + ) # Asserts and warnings. if not self.disable_warnings: @@ -295,7 +289,7 @@ def evaluate_instance( # Copy the input x but fill with baseline values. baseline_value = utils.get_baseline_value( - value=self.perturb_func_kwargs["perturb_baseline"], + value=self.perturb_func.keywords["perturb_baseline"], arr=x, return_shape=x.shape, # TODO. Double-check this over using = (1,). ) @@ -310,7 +304,6 @@ def evaluate_instance( arr=x_baseline, indices=a_ix, indexed_axes=self.a_axes, - **self.perturb_func_kwargs, ) # Predict on perturbed input x (that was initially filled with a constant 'perturb_baseline' value). @@ -359,15 +352,16 @@ def custom_preprocess( def evaluate_batch( self, - *, + *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - **_, + **kwargs, ) -> List[bool]: """ - TODO: write meaningful docstring about what does it compute. + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. Parameters ---------- @@ -379,7 +373,9 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - _: + args: + Unused. + kwargs: Unused. Returns @@ -387,10 +383,8 @@ def evaluate_batch( scores_batch: List of floats. - """ - return [ - self.evaluate_instance(model, x, y, a) + self.evaluate_instance(model=model, x=x, y=y, a=a) for x, y, a in zip(x_batch, y_batch, a_batch) ] diff --git a/quantus/metrics/faithfulness/monotonicity_correlation.py b/quantus/metrics/faithfulness/monotonicity_correlation.py index 4be495761..4a219d8bd 100644 --- a/quantus/metrics/faithfulness/monotonicity_correlation.py +++ b/quantus/metrics/faithfulness/monotonicity_correlation.py @@ -16,16 +16,17 @@ from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices from quantus.functions.similarity_func import correlation_spearman -from quantus.metrics.base_perturbed import PerturbationMetric +from quantus.metrics.base import Metric from quantus.helpers.enums import ( ModelType, DataType, ScoreDirection, EvaluationCategory, ) +from quantus.helpers.perturbation_utils import make_perturb_func -class MonotonicityCorrelation(PerturbationMetric): +class MonotonicityCorrelation(Metric): """ Implementation of Monotonicity Correlation metric by Nguyen at el., 2020. @@ -65,7 +66,7 @@ def __init__( normalise: bool = True, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = None, + perturb_func: Callable = baseline_replacement_by_indices, perturb_baseline: str = "uniform", perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, @@ -120,14 +121,6 @@ def __init__( if normalise_func is None: normalise_func = normalise_by_max - if perturb_func is None: - perturb_func = baseline_replacement_by_indices - perturb_func = perturb_func - - if perturb_func_kwargs is None: - perturb_func_kwargs = {} - perturb_func_kwargs["perturb_baseline"] = perturb_baseline - super().__init__( abs=abs, normalise=normalise, @@ -151,6 +144,9 @@ def __init__( self.eps = eps self.nr_samples = nr_samples self.features_in_step = features_in_step + self.perturb_func = make_perturb_func( + perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline + ) # Asserts and warnings. if not self.disable_warnings: @@ -331,7 +327,6 @@ def evaluate_instance( arr=x, indices=a_ix, indexed_axes=self.a_axes, - **self.perturb_func_kwargs, ) warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) @@ -386,15 +381,16 @@ def custom_preprocess( def evaluate_batch( self, - *, + *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - **_, + **kwargs, ) -> List[float]: """ - TODO: write meaningful docstring about what does it compute. + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. Parameters ---------- @@ -406,6 +402,10 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. + args: + Unused. + kwargs: + Unused. Returns ------- @@ -416,10 +416,10 @@ def evaluate_batch( # Evaluate explanations. return [ self.evaluate_instance( - model, - x, - y, - a, + model=model, + x=x, + y=y, + a=a, ) for x, y, a in zip(x_batch, y_batch, a_batch) ] diff --git a/quantus/metrics/faithfulness/pixel_flipping.py b/quantus/metrics/faithfulness/pixel_flipping.py index 8a962fb97..4157d8e3a 100644 --- a/quantus/metrics/faithfulness/pixel_flipping.py +++ b/quantus/metrics/faithfulness/pixel_flipping.py @@ -18,16 +18,17 @@ from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices -from quantus.metrics.base_perturbed import PerturbationMetric +from quantus.metrics.base import Metric from quantus.helpers.enums import ( ModelType, DataType, ScoreDirection, EvaluationCategory, ) +from quantus.helpers.perturbation_utils import make_perturb_func -class PixelFlipping(PerturbationMetric): +class PixelFlipping(Metric): """ Implementation of Pixel-Flipping experiment by Bach et al., 2015. @@ -61,7 +62,7 @@ def __init__( normalise: bool = True, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = None, + perturb_func: Callable = baseline_replacement_by_indices, perturb_baseline: str = "black", perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, @@ -112,14 +113,6 @@ def __init__( if normalise_func is None: normalise_func = normalise_by_max - if perturb_func is None: - perturb_func = baseline_replacement_by_indices - perturb_func = perturb_func - - if perturb_func_kwargs is None: - perturb_func_kwargs = {} - perturb_func_kwargs["perturb_baseline"] = perturb_baseline - if default_plot_func is None: default_plot_func = plotting.plot_pixel_flipping_experiment @@ -128,8 +121,6 @@ def __init__( normalise=normalise, normalise_func=normalise_func, normalise_func_kwargs=normalise_func_kwargs, - perturb_func=perturb_func, - perturb_func_kwargs=perturb_func_kwargs, return_aggregate=return_aggregate, aggregate_func=aggregate_func, default_plot_func=default_plot_func, @@ -141,12 +132,15 @@ def __init__( # Save metric-specific attributes. self.features_in_step = features_in_step self.return_auc_per_sample = return_auc_per_sample + self.perturb_func = make_perturb_func( + perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline + ) # Asserts and warnings. if not self.disable_warnings: warn.warn_parameterisation( metric_name=self.__class__.__name__, - sensitive_params=("baseline value 'perturb_baseline'"), + sensitive_params="baseline value 'perturb_baseline'", citation=( "Bach, Sebastian, et al. 'On pixel-wise explanations for non-linear classifier" " decisions by layer - wise relevance propagation.' PloS one 10.7 (2015) " @@ -307,7 +301,6 @@ def evaluate_instance( arr=x_perturbed, indices=a_ix, indexed_axes=self.a_axes, - **self.perturb_func_kwargs, ) warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) @@ -367,15 +360,16 @@ def get_auc_score(self): def evaluate_batch( self, - *, + *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - **_, + **kwargs, ) -> List[float | np.ndarray]: """ - TODO: write meaningful docstring about what does it compute. + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. Parameters ---------- @@ -394,6 +388,6 @@ def evaluate_batch( The evaluation results. """ return [ - self.evaluate_instance(model, x, y, a) + self.evaluate_instance(model=model, x=x, y=y, a=a) for x, y, a in zip(x_batch, y_batch, a_batch) ] diff --git a/quantus/metrics/faithfulness/region_perturbation.py b/quantus/metrics/faithfulness/region_perturbation.py index 4d4d01aeb..9fe653640 100644 --- a/quantus/metrics/faithfulness/region_perturbation.py +++ b/quantus/metrics/faithfulness/region_perturbation.py @@ -18,16 +18,17 @@ from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices -from quantus.metrics.base_perturbed import PerturbationMetric +from quantus.metrics.base import Metric from quantus.helpers.enums import ( ModelType, DataType, ScoreDirection, EvaluationCategory, ) +from quantus.helpers.perturbation_utils import make_perturb_func -class RegionPerturbation(PerturbationMetric): +class RegionPerturbation(Metric): """ Implementation of Region Perturbation by Samek et al., 2015. @@ -69,7 +70,7 @@ def __init__( normalise: bool = True, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = None, + perturb_func: Callable = baseline_replacement_by_indices, perturb_baseline: str = "black", perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, @@ -122,14 +123,6 @@ def __init__( if normalise_func is None: normalise_func = normalise_by_max - if perturb_func is None: - perturb_func = baseline_replacement_by_indices - perturb_func = perturb_func - - if perturb_func_kwargs is None: - perturb_func_kwargs = {} - perturb_func_kwargs["perturb_baseline"] = perturb_baseline - if default_plot_func is None: default_plot_func = plotting.plot_region_perturbation_experiment @@ -152,6 +145,9 @@ def __init__( self.patch_size = patch_size self.order = order.lower() self.regions_evaluation = regions_evaluation + self.perturb_func = make_perturb_func( + perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline + ) # Asserts and warnings. asserts.assert_attributions_order(order=self.order) @@ -396,7 +392,6 @@ def evaluate_instance( arr=x_perturbed_pad, indices=patch_slice, indexed_axes=self.a_axes, - **self.perturb_func_kwargs, ) # Remove padding. @@ -423,15 +418,16 @@ def get_auc_score(self): def evaluate_batch( self, - *, + *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - **_, + **kwargs, ) -> List[List[float]]: """ - TODO: write meaningful docstring about what does it compute. + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. Parameters ---------- @@ -443,6 +439,10 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. + args: + Unused. + kwargs: + Unused. Returns ------- @@ -450,6 +450,6 @@ def evaluate_batch( The evaluation results. """ return [ - self.evaluate_instance(model, x, y, a) + self.evaluate_instance(model=model, x=x, y=y, a=a) for x, y, a in zip(x_batch, y_batch, a_batch) ] diff --git a/quantus/metrics/faithfulness/road.py b/quantus/metrics/faithfulness/road.py index d4d6d78f2..1f860fda3 100644 --- a/quantus/metrics/faithfulness/road.py +++ b/quantus/metrics/faithfulness/road.py @@ -14,16 +14,17 @@ from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import noisy_linear_imputation -from quantus.metrics.base_perturbed import PerturbationMetric +from quantus.metrics.base import Metric from quantus.helpers.enums import ( ModelType, DataType, ScoreDirection, EvaluationCategory, ) +from quantus.helpers.perturbation_utils import make_perturb_func -class ROAD(PerturbationMetric): +class ROAD(Metric): """ Implementation of ROAD evaluation strategy by Rong et al., 2022. @@ -62,8 +63,7 @@ def __init__( normalise: bool = True, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = None, - perturb_baseline: str = "black", + perturb_func: Callable = noisy_linear_imputation, perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, aggregate_func: Callable = np.mean, @@ -89,9 +89,6 @@ def __init__( perturb_func: callable Input perturbation function. If None, the default value is used, default=baseline_replacement_by_indices. - perturb_baseline: string - Indicates the type of baseline: "mean", "random", "uniform", "black" or "white", - default="black". perturb_func_kwargs: dict Keyword arguments to be passed to perturb_func, default={}. return_aggregate: boolean @@ -110,20 +107,11 @@ def __init__( if normalise_func is None: normalise_func = normalise_by_max - if perturb_func is None: - perturb_func = noisy_linear_imputation - - if perturb_func_kwargs is None: - perturb_func_kwargs = {} - perturb_func_kwargs["noise"] = noise - super().__init__( abs=abs, normalise=normalise, normalise_func=normalise_func, normalise_func_kwargs=normalise_func_kwargs, - perturb_func=perturb_func, - perturb_func_kwargs=perturb_func_kwargs, return_aggregate=return_aggregate, aggregate_func=aggregate_func, default_plot_func=default_plot_func, @@ -137,15 +125,15 @@ def __init__( percentages = list(range(1, 100, 2)) self.percentages = percentages self.a_size = None + self.perturb_func = make_perturb_func( + perturb_func, perturb_func_kwargs, noise=noise + ) # Asserts and warnings. if not self.disable_warnings: warn.warn_parameterisation( metric_name=self.__class__.__name__, - sensitive_params=( - "baseline value 'perturb_baseline', perturbation function 'perturb_func', " - "percentage of pixels k removed per iteration 'percentage_in_step'" - ), + sensitive_params="perturbation function 'perturb_func'", data_domain_applicability=( f"Also, the current implementation only works for 3-dimensional (image) data." ), @@ -281,12 +269,10 @@ def evaluate_instance( The output to be evaluated on an instance-basis. a: np.ndarray The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. Returns ------- - : list + list: The evaluation results. """ # Order indices. @@ -300,7 +286,6 @@ def evaluate_instance( x_perturbed = self.perturb_func( arr=x, indices=top_k_indices, - **self.perturb_func_kwargs, ) warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) @@ -360,7 +345,7 @@ def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: def evaluate_batch( self, - *, + *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, @@ -368,7 +353,8 @@ def evaluate_batch( **kwargs, ): """ - TODO: write meaningful docstring about what does it compute. + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. Parameters ---------- @@ -380,6 +366,10 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. + args: + Unused. + kwargs: + Unused. Returns ------- @@ -387,6 +377,6 @@ def evaluate_batch( The evaluation results. """ return [ - self.evaluate_instance(model, x, y, a) + self.evaluate_instance(model=model, x=x, y=y, a=a) for x, y, a in zip(x_batch, y_batch, a_batch) ] diff --git a/quantus/metrics/faithfulness/selectivity.py b/quantus/metrics/faithfulness/selectivity.py index 8c4f96853..92a5712b8 100644 --- a/quantus/metrics/faithfulness/selectivity.py +++ b/quantus/metrics/faithfulness/selectivity.py @@ -17,16 +17,17 @@ from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices -from quantus.metrics.base_perturbed import PerturbationMetric +from quantus.metrics.base import Metric from quantus.helpers.enums import ( ModelType, DataType, ScoreDirection, EvaluationCategory, ) +from quantus.helpers.perturbation_utils import make_perturb_func -class Selectivity(PerturbationMetric): +class Selectivity(Metric): """ Implementation of Selectivity test by Montavon et al., 2018. @@ -66,7 +67,7 @@ def __init__( normalise: bool = True, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = None, + perturb_func: Callable = baseline_replacement_by_indices, perturb_baseline: str = "black", perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, @@ -114,14 +115,6 @@ def __init__( if normalise_func is None: normalise_func = normalise_by_max - if perturb_func is None: - perturb_func = baseline_replacement_by_indices - perturb_func = perturb_func - - if perturb_func_kwargs is None: - perturb_func_kwargs = {} - perturb_func_kwargs["perturb_baseline"] = perturb_baseline - if default_plot_func is None: default_plot_func = plotting.plot_selectivity_experiment @@ -130,8 +123,6 @@ def __init__( normalise=normalise, normalise_func=normalise_func, normalise_func_kwargs=normalise_func_kwargs, - perturb_func=perturb_func, - perturb_func_kwargs=perturb_func_kwargs, return_aggregate=return_aggregate, aggregate_func=aggregate_func, default_plot_func=default_plot_func, @@ -142,6 +133,9 @@ def __init__( # Save metric-specific attributes. self.patch_size = patch_size + self.perturb_func = make_perturb_func( + perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline + ) # Asserts and warnings. if not self.disable_warnings: @@ -353,7 +347,6 @@ def evaluate_instance( arr=x_perturbed_pad, indices=patch_slice, indexed_axes=self.a_axes, - **self.perturb_func_kwargs, ) # Remove padding. @@ -382,15 +375,16 @@ def get_auc_score(self): def evaluate_batch( self, - *, + *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - **_, + **kwargs, ) -> List[List[float]]: """ - TODO: write meaningful docstring about what does it compute. + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. Parameters ---------- @@ -402,6 +396,10 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. + args: + Unused. + kwargs: + Unused. Returns ------- @@ -409,6 +407,6 @@ def evaluate_batch( The evaluation results. """ return [ - self.evaluate_instance(model, x, y, a) + self.evaluate_instance(model=model, x=x, y=y, a=a) for x, y, a in zip(x_batch, y_batch, a_batch) ] diff --git a/quantus/metrics/faithfulness/sensitivity_n.py b/quantus/metrics/faithfulness/sensitivity_n.py index 2d265ed46..150b807d2 100644 --- a/quantus/metrics/faithfulness/sensitivity_n.py +++ b/quantus/metrics/faithfulness/sensitivity_n.py @@ -17,16 +17,17 @@ from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices from quantus.functions.similarity_func import correlation_pearson -from quantus.metrics.base_perturbed import PerturbationMetric +from quantus.metrics.base import Metric from quantus.helpers.enums import ( ModelType, DataType, ScoreDirection, EvaluationCategory, ) +from quantus.helpers.perturbation_utils import make_perturb_func -class SensitivityN(PerturbationMetric): +class SensitivityN(Metric): """ Implementation of Sensitivity-N test by Ancona et al., 2019. @@ -65,7 +66,7 @@ def __init__( normalise: bool = True, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = None, + perturb_func: Callable = baseline_replacement_by_indices, perturb_baseline: str = "black", perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = True, @@ -118,14 +119,6 @@ def __init__( if normalise_func is None: normalise_func = normalise_by_max - if perturb_func is None: - perturb_func = baseline_replacement_by_indices - perturb_func = perturb_func - - if perturb_func_kwargs is None: - perturb_func_kwargs = {} - perturb_func_kwargs["perturb_baseline"] = perturb_baseline - if default_plot_func is None: default_plot_func = plotting.plot_sensitivity_n_experiment @@ -134,8 +127,6 @@ def __init__( normalise=normalise, normalise_func=normalise_func, normalise_func_kwargs=normalise_func_kwargs, - perturb_func=perturb_func, - perturb_func_kwargs=perturb_func_kwargs, return_aggregate=return_aggregate, aggregate_func=aggregate_func, default_plot_func=default_plot_func, @@ -150,6 +141,9 @@ def __init__( self.similarity_func = similarity_func self.n_max_percentage = n_max_percentage self.features_in_step = features_in_step + self.perturb_func = make_perturb_func( + perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline + ) # Asserts and warnings. if not self.disable_warnings: @@ -322,7 +316,6 @@ def evaluate_instance( arr=x_perturbed, indices=a_ix, indexed_axes=self.a_axes, - **self.perturb_func_kwargs, ) warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) @@ -435,15 +428,16 @@ def custom_postprocess( def evaluate_batch( self, - *, + *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - **_, + **kwargs, ) -> List[Dict[str, List[float]]]: """ - TODO: write meaningful docstring about what does it compute. + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. Parameters ---------- @@ -455,6 +449,10 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. + args: + Unused. + kwargs: + Unused. Returns ------- @@ -463,6 +461,6 @@ def evaluate_batch( """ return [ - self.evaluate_instance(model, x, y, a) + self.evaluate_instance(model=model, x=x, y=y, a=a) for x, y, a in zip(x_batch, y_batch, a_batch) ] diff --git a/quantus/metrics/faithfulness/sufficiency.py b/quantus/metrics/faithfulness/sufficiency.py index a45ee33a4..c784f1ba2 100644 --- a/quantus/metrics/faithfulness/sufficiency.py +++ b/quantus/metrics/faithfulness/sufficiency.py @@ -279,6 +279,17 @@ def evaluate_instance( return np.sum(pred_low_dist_a == pred_a) / len(low_dist_a) def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: + """ + + + Parameters + ---------- + data_batch + + Returns + ------- + + """ data_batch = super().batch_preprocess(data_batch) model = data_batch["model"] x_batch = data_batch["x_batch"] @@ -304,13 +315,13 @@ def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: data_batch.update(custom_batch) return data_batch - @no_type_check def evaluate_batch( - self, *, i_batch, a_sim_vector_batch, y_pred_classes, **_ + self, *args, i_batch, a_sim_vector_batch, y_pred_classes, **kwargs ) -> List[float]: """ + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. - TODO: write meaningful docstring about what does it compute. Parameters ---------- i_batch: @@ -319,15 +330,21 @@ def evaluate_batch( The custom input to be evaluated on an instance-basis. y_pred_classes: The class predictions of the complete input dataset. - _: - unused. + args: + Unused. + kwargs: + Unused. Returns ------- + evaluation_scores: + List of measured sufficiency for each entry in the batch. """ return [ - self.evaluate_instance(i, a_sim_vector, y_pred_classes) + self.evaluate_instance( + i=i, a_sim_vector=a_sim_vector, y_pred_classes=y_pred_classes + ) for i, a_sim_vector in zip(i_batch, a_sim_vector_batch) ] diff --git a/quantus/metrics/localisation/attribution_localisation.py b/quantus/metrics/localisation/attribution_localisation.py index edb3de75a..3ebfbe4fe 100644 --- a/quantus/metrics/localisation/attribution_localisation.py +++ b/quantus/metrics/localisation/attribution_localisation.py @@ -327,10 +327,16 @@ def custom_preprocess( asserts.assert_segmentations(x_batch=x_batch, s_batch=s_batch) def evaluate_batch( - self, *, x_batch: np.ndarray, a_batch: np.ndarray, s_batch: np.ndarray, **_ + self, + *args, + x_batch: np.ndarray, + a_batch: np.ndarray, + s_batch: np.ndarray, + **kwargs, ) -> List[float]: """ - TODO: write meaningful docstring about what does it compute. + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. Parameters ---------- @@ -340,13 +346,17 @@ def evaluate_batch( A np.ndarray which contains pre-computed attributions i.e., explanations. s_batch: np.ndarray A np.ndarray which contains segmentation masks that matches the input. + args: + Unused. + kwargs: + Unused. Returns ------- - evaluation_scores: list - a list of floats. + retval: + Evaluation result for batch. """ return [ - self.evaluate_instance(x, a, s) + self.evaluate_instance(x=x, a=a, s=s) for x, a, s in zip(x_batch, a_batch, s_batch) ] diff --git a/quantus/metrics/localisation/auc.py b/quantus/metrics/localisation/auc.py index 2457d71e6..b565e8fae 100644 --- a/quantus/metrics/localisation/auc.py +++ b/quantus/metrics/localisation/auc.py @@ -284,9 +284,11 @@ def custom_preprocess( asserts.assert_segmentations(x_batch=x_batch, s_batch=s_batch) def evaluate_batch( - self, *, a_batch: np.ndarray, s_batch: np.ndarray, **_ + self, *args, a_batch: np.ndarray, s_batch: np.ndarray, **kwargs ) -> List[float]: """ + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. Parameters ---------- @@ -294,13 +296,14 @@ def evaluate_batch( A np.ndarray which contains pre-computed attributions i.e., explanations. s_batch: A np.ndarray which contains segmentation masks that matches the input. - _: - unused. + args: + Unused. + kwargs: + Unused. Returns ------- - + retval: + Evaluation result for batch. """ - # TODO: For performance gains, replace the for loop below with vectorisation. - # https://github.com/understandable-machine-intelligence-lab/Quantus/issues/299 - return [self.evaluate_instance(a, s) for a, s in zip(a_batch, s_batch)] + return [self.evaluate_instance(a=a, s=s) for a, s in zip(a_batch, s_batch)] diff --git a/quantus/metrics/localisation/focus.py b/quantus/metrics/localisation/focus.py index e78d86f69..f2ba5a4db 100644 --- a/quantus/metrics/localisation/focus.py +++ b/quantus/metrics/localisation/focus.py @@ -382,22 +382,27 @@ def quadrant_bottom_right(self, a: np.ndarray) -> np.ndarray: ] return quandrant_a - @no_type_check def evaluate_batch( - self, *, a_batch: np.ndarray, c_batch: np.ndarray, **_ + self, *args, a_batch: np.ndarray, c_batch: np.ndarray, **kwargs ) -> List[float]: """ + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. + Parameters ---------- a_batch: A np.ndarray which contains pre-computed attributions i.e., explanations. c_batch: The custom input to be evaluated on an batch-basis. - _: - unused. + args: + Unused. + kwargs: + Unused. Returns ------- - + retval: + Evaluation result for batch. """ - return [self.evaluate_instance(a, c) for a, c in zip(a_batch, c_batch)] + return [self.evaluate_instance(a=a, c=c) for a, c in zip(a_batch, c_batch)] diff --git a/quantus/metrics/localisation/pointing_game.py b/quantus/metrics/localisation/pointing_game.py index cd9c3685e..0fc204ace 100644 --- a/quantus/metrics/localisation/pointing_game.py +++ b/quantus/metrics/localisation/pointing_game.py @@ -307,9 +307,11 @@ def custom_preprocess( asserts.assert_segmentations(x_batch=x_batch, s_batch=s_batch) def evaluate_batch( - self, *, a_batch: np.ndarray, s_batch: np.ndarray, **_ + self, *args, a_batch: np.ndarray, s_batch: np.ndarray, **kwargs ) -> List[float]: """ + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. Parameters ---------- @@ -317,11 +319,14 @@ def evaluate_batch( A np.ndarray which contains pre-computed attributions i.e., explanations. s_batch: A np.ndarray which contains segmentation masks that matches the input. - _: - unused. + args: + Unused. + kwargs: + Unused. Returns ------- - + retval: + Evaluation result for batch. """ - return [self.evaluate_instance(a, s) for a, s in zip(a_batch, s_batch)] + return [self.evaluate_instance(a=a, s=s) for a, s in zip(a_batch, s_batch)] diff --git a/quantus/metrics/localisation/relevance_mass_accuracy.py b/quantus/metrics/localisation/relevance_mass_accuracy.py index 6cc7a3a67..5a309f77c 100644 --- a/quantus/metrics/localisation/relevance_mass_accuracy.py +++ b/quantus/metrics/localisation/relevance_mass_accuracy.py @@ -297,20 +297,26 @@ def custom_preprocess( asserts.assert_segmentations(x_batch=x_batch, s_batch=s_batch) def evaluate_batch( - self, *, a_batch: np.ndarray, s_batch: np.ndarray, **_ + self, *args, a_batch: np.ndarray, s_batch: np.ndarray, **kwargs ) -> List[float]: """ + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. + Parameters ---------- a_batch: A np.ndarray which contains pre-computed attributions i.e., explanations. s_batch: A np.ndarray which contains segmentation masks that matches the input. - _: - unused. + args: + Unused. + kwargs: + Unused. Returns ------- - + retval: + A list of Any with the evaluation scores for the batch. """ - return [self.evaluate_instance(a, s) for a, s in zip(a_batch, s_batch)] + return [self.evaluate_instance(a=a, s=s) for a, s in zip(a_batch, s_batch)] diff --git a/quantus/metrics/localisation/relevance_rank_accuracy.py b/quantus/metrics/localisation/relevance_rank_accuracy.py index 812ce6af7..b89f7bb61 100644 --- a/quantus/metrics/localisation/relevance_rank_accuracy.py +++ b/quantus/metrics/localisation/relevance_rank_accuracy.py @@ -306,9 +306,11 @@ def custom_preprocess( asserts.assert_segmentations(x_batch=x_batch, s_batch=s_batch) def evaluate_batch( - self, *, a_batch: np.ndarray, s_batch: np.ndarray, **_ + self, *args, a_batch: np.ndarray, s_batch: np.ndarray, **kwargs ) -> List[float]: """ + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. Parameters ---------- @@ -316,13 +318,14 @@ def evaluate_batch( A np.ndarray which contains pre-computed attributions i.e., explanations. s_batch: A np.ndarray which contains segmentation masks that matches the input. - _: - unused. + args: + Unused. + kwargs: + Unused Returns ------- - + retval: + Evaluation result for batch. """ - # TODO: For performance gains, replace the for loop below with vectorisation. - # https://github.com/understandable-machine-intelligence-lab/Quantus/issues/299 - return [self.evaluate_instance(a, s) for a, s in zip(a_batch, s_batch)] + return [self.evaluate_instance(a=a, s=s) for a, s in zip(a_batch, s_batch)] diff --git a/quantus/metrics/localisation/top_k_intersection.py b/quantus/metrics/localisation/top_k_intersection.py index 1b184b5f0..daa66cc8c 100644 --- a/quantus/metrics/localisation/top_k_intersection.py +++ b/quantus/metrics/localisation/top_k_intersection.py @@ -316,9 +316,11 @@ def custom_preprocess( ) def evaluate_batch( - self, *, a_batch: np.ndarray, s_batch: np.ndarray, **_ + self, *args, a_batch: np.ndarray, s_batch: np.ndarray, **kwargs ) -> List[float]: """ + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. Parameters ---------- @@ -326,11 +328,14 @@ def evaluate_batch( A np.ndarray which contains pre-computed attributions i.e., explanations. s_batch: A np.ndarray which contains segmentation masks that matches the input. - _: - unused. + args: + Unused. + kwargs: + Unused. Returns ------- - + retval: + Evaluation result for batch. """ return [self.evaluate_instance(a, s) for a, s in zip(a_batch, s_batch)] diff --git a/quantus/metrics/randomisation/model_parameter_randomisation.py b/quantus/metrics/randomisation/model_parameter_randomisation.py index ca598d269..d04211a15 100644 --- a/quantus/metrics/randomisation/model_parameter_randomisation.py +++ b/quantus/metrics/randomisation/model_parameter_randomisation.py @@ -391,4 +391,6 @@ def compute_correlation_per_sample( return corr_coeffs def evaluate_batch(self, *args, **kwargs): - raise RuntimeError("`evaluate_batch` must never be called for `ModelParameterRandomisation`.") + raise RuntimeError( + "`evaluate_batch` must never be called for `ModelParameterRandomisation`." + ) diff --git a/quantus/metrics/robustness/avg_sensitivity.py b/quantus/metrics/robustness/avg_sensitivity.py index 675e98f81..ac3cfc138 100644 --- a/quantus/metrics/robustness/avg_sensitivity.py +++ b/quantus/metrics/robustness/avg_sensitivity.py @@ -1,4 +1,5 @@ """This module contains the implementation of the Avg-Sensitivity metric.""" +import functools # This file is part of Quantus. # Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. @@ -16,16 +17,20 @@ from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import uniform_noise, perturb_batch from quantus.functions.similarity_func import difference -from quantus.metrics.base_perturbed import PerturbationMetric +from quantus.metrics.base import Metric from quantus.helpers.enums import ( ModelType, DataType, ScoreDirection, EvaluationCategory, ) +from quantus.helpers.perturbation_utils import ( + make_perturb_func, + make_changed_prediction_indices_func, +) -class AvgSensitivity(PerturbationMetric): +class AvgSensitivity(Metric): """ Implementation of Avg-Sensitivity by Yeh at el., 2019. @@ -62,7 +67,7 @@ def __init__( normalise: bool = False, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = None, + perturb_func: Callable = uniform_noise, lower_bound: float = 0.2, upper_bound: Optional[float] = None, perturb_func_kwargs: Optional[Dict[str, Any]] = None, @@ -122,14 +127,6 @@ def __init__( if normalise_func is None: normalise_func = normalise_by_max - if perturb_func is None: - perturb_func = uniform_noise - - if perturb_func_kwargs is None: - perturb_func_kwargs = {} - perturb_func_kwargs["lower_bound"] = lower_bound - perturb_func_kwargs["upper_bound"] = upper_bound - super().__init__( abs=abs, normalise=normalise, @@ -159,7 +156,16 @@ def __init__( if norm_denominator is None: norm_denominator = norm_func.fro_norm self.norm_denominator = norm_denominator - self.return_nan_when_prediction_changes = return_nan_when_prediction_changes + self.changed_prediction_indices = make_changed_prediction_indices_func( + return_nan_when_prediction_changes + ) + self.mean_func = np.mean if return_nan_when_prediction_changes else np.nanmean + self.perturb_func = make_perturb_func( + perturb_func, + perturb_func_kwargs, + lower_bound=lower_bound, + upper_bound=upper_bound, + ) # Asserts and warnings. if not self.disable_warnings: @@ -307,8 +313,8 @@ def evaluate_batch( The output to be evaluated on an instance-basis. a_batch: np.ndarray The explanation to be evaluated on an instance-basis. - s_batch: np.ndarray - The segmentation to be evaluated on an instance-basis. + kwargs: + Unused. Returns ------- @@ -325,16 +331,10 @@ def evaluate_batch( indices=np.tile(np.arange(0, x_batch[0].size), (batch_size, 1)), indexed_axes=np.arange(0, x_batch[0].ndim), arr=x_batch, - **self.perturb_func_kwargs, ) - changed_prediction_indices = ( - np.argwhere( - model.predict(x_batch).argmax(axis=-1) - != model.predict(x_perturbed).argmax(axis=-1) - ).reshape(-1) - if self.return_nan_when_prediction_changes - else [] + changed_prediction_indices = self.changed_prediction_indices( + model, x_batch, x_perturbed ) for x_instance, x_instance_perturbed in zip(x_batch, x_perturbed): @@ -348,10 +348,7 @@ def evaluate_batch( # Measure similarity for each instance separately. for instance_id in range(batch_size): - if ( - self.return_nan_when_prediction_changes - and instance_id in changed_prediction_indices - ): + if instance_id in changed_prediction_indices: similarities[instance_id, step_id] = np.nan continue @@ -363,8 +360,8 @@ def evaluate_batch( denominator = self.norm_denominator(a=a_batch[instance_id].flatten()) sensitivities_norm = numerator / denominator similarities[instance_id, step_id] = sensitivities_norm - mean_func = np.mean if self.return_nan_when_prediction_changes else np.nanmean - return mean_func(similarities, axis=1) + + return self.mean_func(similarities, axis=1) def custom_preprocess( self, diff --git a/quantus/metrics/robustness/consistency.py b/quantus/metrics/robustness/consistency.py index bb2cad597..22283b2ed 100644 --- a/quantus/metrics/robustness/consistency.py +++ b/quantus/metrics/robustness/consistency.py @@ -269,6 +269,16 @@ def evaluate_instance( return np.sum(pred_same_a == pred_a) / len(diff_a) def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: + """ + + Parameters + ---------- + data_batch + + Returns + ------- + + """ data_batch = super().batch_preprocess(data_batch) model = data_batch["model"] @@ -296,15 +306,32 @@ def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: @no_type_check def evaluate_batch( self, - *, + *args, a_batch: np.ndarray, i_batch: np.ndarray, a_label_batch: np.ndarray, y_pred_classes, - **_, + **kwargs, ) -> List[float]: - # TODO: For performance gains, replace the for loop below with vectorisation. - # https://github.com/understandable-machine-intelligence-lab/Quantus/issues/299 + """ + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. + + Parameters + ---------- + args: + Unused. + a_batch: + + i_batch + a_label_batch + y_pred_classes + kwargs + + Returns + ------- + + """ return [ self.evaluate_instance(a, i, a_label, y_pred_classes) diff --git a/quantus/metrics/robustness/continuity.py b/quantus/metrics/robustness/continuity.py index e1b237e4e..88d25db71 100644 --- a/quantus/metrics/robustness/continuity.py +++ b/quantus/metrics/robustness/continuity.py @@ -16,16 +16,17 @@ from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import translation_x_direction from quantus.functions.similarity_func import lipschitz_constant -from quantus.metrics.base_perturbed import PerturbationMetric +from quantus.metrics.base import Metric from quantus.helpers.enums import ( ModelType, DataType, ScoreDirection, EvaluationCategory, ) +from quantus.helpers.perturbation_utils import make_perturb_func -class Continuity(PerturbationMetric): +class Continuity(Metric): """ Implementation of the Continuity test by Montavon et al., 2018. @@ -65,7 +66,7 @@ def __init__( normalise: bool = True, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = None, + perturb_func: Callable = translation_x_direction, perturb_baseline: str = "black", perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, @@ -123,20 +124,11 @@ def __init__( if normalise_func is None: normalise_func = normalise_by_max - if perturb_func is None: - perturb_func = translation_x_direction - - if perturb_func_kwargs is None: - perturb_func_kwargs = {} - perturb_func_kwargs["perturb_baseline"] = perturb_baseline - super().__init__( abs=abs, normalise=normalise, normalise_func=normalise_func, normalise_func_kwargs=normalise_func_kwargs, - perturb_func=perturb_func, - perturb_func_kwargs=perturb_func_kwargs, return_aggregate=return_aggregate, aggregate_func=aggregate_func, default_plot_func=default_plot_func, @@ -154,6 +146,9 @@ def __init__( self.nr_patches: Optional[int] = None self.dx = None self.return_nan_when_prediction_changes = return_nan_when_prediction_changes + self.perturb_func = make_perturb_func( + perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline + ) # Asserts and warnings. if not self.disable_warnings: @@ -311,7 +306,6 @@ def evaluate_instance( indices=np.arange(0, x.size), indexed_axes=np.arange(0, x.ndim), perturb_dx=dx_step, - **self.perturb_func_kwargs, ) x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) @@ -438,13 +432,15 @@ def aggregated_score(self): def evaluate_batch( self, - *, + *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, - **_, + **kwargs, ) -> List[Dict[str, int]]: """ + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. Parameters ---------- @@ -454,11 +450,16 @@ def evaluate_batch( A np.ndarray which contains the input data that are explained. y_batch: A np.ndarray which contains the output labels that are explained. - _: - unused. + kwargs: + Unused. + args: + Unused. Returns ------- """ - return [self.evaluate_instance(model, x, y) for x, y in zip(x_batch, y_batch)] + return [ + self.evaluate_instance(model=model, x=x, y=y) + for x, y in zip(x_batch, y_batch) + ] diff --git a/quantus/metrics/robustness/local_lipschitz_estimate.py b/quantus/metrics/robustness/local_lipschitz_estimate.py index 6c722fb09..d6d0d749d 100644 --- a/quantus/metrics/robustness/local_lipschitz_estimate.py +++ b/quantus/metrics/robustness/local_lipschitz_estimate.py @@ -15,7 +15,7 @@ from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import gaussian_noise, perturb_batch from quantus.functions.similarity_func import lipschitz_constant, distance_euclidean -from quantus.metrics.base_perturbed import PerturbationMetric +from quantus.metrics.base import Metric from quantus.helpers.enums import ( ModelType, DataType, @@ -23,8 +23,13 @@ EvaluationCategory, ) +from quantus.helpers.perturbation_utils import ( + make_perturb_func, + make_changed_prediction_indices_func, +) + -class LocalLipschitzEstimate(PerturbationMetric): +class LocalLipschitzEstimate(Metric): """ Implementation of the Local Lipschitz Estimate (or Stability) test by Alvarez-Melis et al., 2018a, 2018b. @@ -62,7 +67,7 @@ def __init__( nr_samples: int = 200, abs: bool = False, normalise: bool = True, - normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, + normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = gaussian_noise, normalise_func_kwargs: Optional[Dict[str, Any]] = None, perturb_func: Callable = None, perturb_mean: float = 0.0, @@ -126,21 +131,11 @@ def __init__( if normalise_func is None: normalise_func = normalise_by_max - if perturb_func is None: - perturb_func = gaussian_noise - - if perturb_func_kwargs is None: - perturb_func_kwargs = {} - perturb_func_kwargs["perturb_mean"] = perturb_mean - perturb_func_kwargs["perturb_std"] = perturb_std - super().__init__( abs=abs, normalise=normalise, normalise_func=normalise_func, normalise_func_kwargs=normalise_func_kwargs, - perturb_func=perturb_func, - perturb_func_kwargs=perturb_func_kwargs, return_aggregate=return_aggregate, aggregate_func=aggregate_func, default_plot_func=default_plot_func, @@ -163,7 +158,16 @@ def __init__( if norm_denominator is None: norm_denominator = distance_euclidean self.norm_denominator = norm_denominator - self.return_nan_when_prediction_changes = return_nan_when_prediction_changes + self.perturb_func = make_perturb_func( + perturb_func, + perturb_func_kwargs, + perturb_mean=perturb_mean, + perturb_std=perturb_std, + ) + self.changed_prediction_indices_func = make_changed_prediction_indices_func( + return_nan_when_prediction_changes + ) + self.max_func = np.max if return_nan_when_prediction_changes else np.nanmax # Asserts and warnings. if not self.disable_warnings: @@ -331,10 +335,9 @@ def evaluate_batch( indices=np.tile(np.arange(0, x_batch[0].size), (batch_size, 1)), indexed_axes=np.arange(0, x_batch[0].ndim), arr=x_batch, - **self.perturb_func_kwargs, ) - changed_prediction_indices = self.changed_prediction_indices( + changed_prediction_indices = self.changed_prediction_indices_func( model, x_batch, x_perturbed ) @@ -348,10 +351,7 @@ def evaluate_batch( a_perturbed = self.explain_batch(model, x_perturbed, y_batch) # Measure similarity for each instance separately. for instance_id in range(batch_size): - if ( - self.return_nan_when_prediction_changes - and instance_id in changed_prediction_indices - ): + if instance_id in changed_prediction_indices: similarities[instance_id, step_id] = np.nan continue @@ -364,8 +364,8 @@ def evaluate_batch( norm_denominator=self.norm_denominator, ) similarities[instance_id, step_id] = similarity - max_func = np.max if self.return_nan_when_prediction_changes else np.nanmax - return max_func(similarities, axis=1) + + return self.max_func(similarities, axis=1) def custom_preprocess( self, diff --git a/quantus/metrics/robustness/max_sensitivity.py b/quantus/metrics/robustness/max_sensitivity.py index 9fa8935b9..721f3b58b 100644 --- a/quantus/metrics/robustness/max_sensitivity.py +++ b/quantus/metrics/robustness/max_sensitivity.py @@ -16,16 +16,20 @@ from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import uniform_noise, perturb_batch from quantus.functions.similarity_func import difference -from quantus.metrics.base_perturbed import PerturbationMetric +from quantus.metrics.base import Metric from quantus.helpers.enums import ( ModelType, DataType, ScoreDirection, EvaluationCategory, ) +from quantus.helpers.perturbation_utils import ( + make_changed_prediction_indices_func, + make_perturb_func, +) -class MaxSensitivity(PerturbationMetric): +class MaxSensitivity(Metric): """ Implementation of Max-Sensitivity by Yeh at el., 2019. @@ -62,7 +66,7 @@ def __init__( normalise: bool = False, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = None, + perturb_func: Callable = uniform_noise, lower_bound: float = 0.2, upper_bound: Optional[float] = None, perturb_func_kwargs: Optional[Dict[str, Any]] = None, @@ -122,14 +126,6 @@ def __init__( if normalise_func is None: normalise_func = normalise_by_max - if perturb_func is None: - perturb_func = uniform_noise - - if perturb_func_kwargs is None: - perturb_func_kwargs = {} - perturb_func_kwargs["lower_bound"] = lower_bound - perturb_func_kwargs["upper_bound"] = upper_bound - super().__init__( abs=abs, normalise=normalise, @@ -158,8 +154,18 @@ def __init__( if norm_denominator is None: norm_denominator = norm_func.fro_norm + self.norm_denominator = norm_denominator - self.return_nan_when_prediction_changes = return_nan_when_prediction_changes + self.perturb_func = make_perturb_func( + perturb_func, + perturb_func_kwargs, + lower_bound=lower_bound, + upper_bound=upper_bound, + ) + self.changed_prediction_indices_func = make_changed_prediction_indices_func( + return_nan_when_prediction_changes + ) + self.max_func = np.max if return_nan_when_prediction_changes else np.nanmax # Asserts and warnings. if not self.disable_warnings: @@ -323,10 +329,9 @@ def evaluate_batch( indices=np.tile(np.arange(0, x_batch[0].size), (batch_size, 1)), indexed_axes=np.arange(0, x_batch[0].ndim), arr=x_batch, - **self.perturb_func_kwargs, ) - changed_prediction_indices = self.changed_prediction_indices( + changed_prediction_indices = self.changed_prediction_indices_func( model, x_batch, x_perturbed ) @@ -341,10 +346,7 @@ def evaluate_batch( # Measure similarity for each instance separately. for instance_id in range(batch_size): - if ( - self.return_nan_when_prediction_changes - and instance_id in changed_prediction_indices - ): + if instance_id in changed_prediction_indices: similarities[instance_id, step_id] = np.nan continue @@ -357,8 +359,7 @@ def evaluate_batch( sensitivities_norm = numerator / denominator similarities[instance_id, step_id] = sensitivities_norm - max_func = np.max if self.return_nan_when_prediction_changes else np.nanmax - return max_func(similarities, axis=1) + return self.max_func(similarities, axis=1) def custom_preprocess( self, diff --git a/quantus/metrics/robustness/relative_input_stability.py b/quantus/metrics/robustness/relative_input_stability.py index 4cb73cff6..70faf507d 100644 --- a/quantus/metrics/robustness/relative_input_stability.py +++ b/quantus/metrics/robustness/relative_input_stability.py @@ -15,7 +15,7 @@ from quantus.helpers.model.model_interface import ModelInterface -from quantus.metrics.base_perturbed import PerturbationMetric +from quantus.metrics.base import Metric from quantus.helpers.warn import warn_parameterisation from quantus.functions.normalise_func import normalise_by_average_second_moment_estimate from quantus.functions.perturb_func import uniform_noise, perturb_batch @@ -25,9 +25,13 @@ ScoreDirection, EvaluationCategory, ) +from quantus.helpers.perturbation_utils import ( + make_perturb_func, + make_changed_prediction_indices_func, +) -class RelativeInputStability(PerturbationMetric): +class RelativeInputStability(Metric): """ Relative Input Stability leverages the stability of an explanation with respect to the change in the input data. @@ -59,7 +63,7 @@ def __init__( normalise: bool = False, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, ...]] = None, - perturb_func: Callable = None, + perturb_func: Callable = uniform_noise, perturb_func_kwargs: Optional[Dict[str, ...]] = None, return_aggregate: bool = False, aggregate_func: Optional[Callable[[np.ndarray], np.float]] = np.mean, @@ -106,12 +110,6 @@ def __init__( if normalise_func is None: normalise_func = normalise_by_average_second_moment_estimate - if perturb_func is None: - perturb_func = uniform_noise - - if perturb_func_kwargs is None: - perturb_func_kwargs = {"upper_bound": 0.2} - super().__init__( abs=abs, normalise=normalise, @@ -128,7 +126,12 @@ def __init__( ) self._nr_samples = nr_samples self._eps_min = eps_min - self._return_nan_when_prediction_changes = return_nan_when_prediction_changes + self.perturb_func = make_perturb_func( + perturb_func, perturb_func_kwargs, upper_bound=0.2 + ) + self.changed_prediction_indices_func = make_changed_prediction_indices_func( + return_nan_when_prediction_changes + ) if not self.disable_warnings: warn_parameterisation( @@ -303,7 +306,6 @@ def evaluate_batch( indices=np.tile(np.arange(0, x_batch[0].size), (batch_size, 1)), indexed_axes=np.arange(0, x_batch[0].ndim), arr=x_batch, - **self.perturb_func_kwargs, ) # Generate explanations for perturbed input. @@ -315,19 +317,13 @@ def evaluate_batch( ) ris_batch[index] = ris - # We're done with this sample if `return_nan_when_prediction_changes`==False. - if not self._return_nan_when_prediction_changes: - continue - # If perturbed input caused change in prediction, then it's RIS=nan. - - changed_prediction_indices = self.changed_prediction_indices( + changed_prediction_indices = self.changed_prediction_indices_func( model, x_batch, x_perturbed ) - if len(changed_prediction_indices) == 0: - continue - ris_batch[index, changed_prediction_indices] = np.nan + if len(changed_prediction_indices) != 0: + ris_batch[index, changed_prediction_indices] = np.nan # Compute RIS. result = np.max(ris_batch, axis=0) diff --git a/quantus/metrics/robustness/relative_output_stability.py b/quantus/metrics/robustness/relative_output_stability.py index 5cb2913db..2f3729b61 100644 --- a/quantus/metrics/robustness/relative_output_stability.py +++ b/quantus/metrics/robustness/relative_output_stability.py @@ -8,27 +8,29 @@ from typing import TYPE_CHECKING, Optional, Callable, Dict, List import numpy as np -from functools import partial if TYPE_CHECKING: import tensorflow as tf import torch from quantus.helpers.model.model_interface import ModelInterface -from quantus.metrics.base_perturbed import PerturbationMetric +from quantus.metrics.base import Metric from quantus.helpers.warn import warn_parameterisation from quantus.functions.normalise_func import normalise_by_average_second_moment_estimate from quantus.functions.perturb_func import uniform_noise, perturb_batch -from quantus.helpers.utils import expand_attribution_channel from quantus.helpers.enums import ( ModelType, DataType, ScoreDirection, EvaluationCategory, ) +from quantus.helpers.perturbation_utils import ( + make_perturb_func, + make_changed_prediction_indices_func, +) -class RelativeOutputStability(PerturbationMetric): +class RelativeOutputStability(Metric): """ Relative Output Stability leverages the stability of an explanation with respect to the change in the output logits. @@ -62,7 +64,7 @@ def __init__( normalise: bool = False, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, ...]] = None, - perturb_func: Callable = None, + perturb_func: Callable = uniform_noise, perturb_func_kwargs: Optional[Dict[str, ...]] = None, return_aggregate: bool = False, aggregate_func: Optional[Callable[[np.ndarray], np.float]] = np.mean, @@ -109,12 +111,6 @@ def __init__( if normalise_func is None: normalise_func = normalise_by_average_second_moment_estimate - if perturb_func is None: - perturb_func = uniform_noise - - if perturb_func_kwargs is None: - perturb_func_kwargs = {"upper_bound": 0.2} - super().__init__( abs=abs, normalise=normalise, @@ -131,7 +127,12 @@ def __init__( ) self._nr_samples = nr_samples self._eps_min = eps_min - self._return_nan_when_prediction_changes = return_nan_when_prediction_changes + self.perturb_func = make_perturb_func( + perturb_func, perturb_func_kwargs, upper_bound=0.2 + ) + self.changed_prediction_indices_func = make_changed_prediction_indices_func( + return_nan_when_prediction_changes + ) if not self.disable_warnings: warn_parameterisation( @@ -300,9 +301,6 @@ def evaluate_batch( """ batch_size = x_batch.shape[0] - _explain_func = partial( - self.explain_func, model=model.get_model(), **self.explain_func_kwargs - ) # Execute forward pass on provided inputs. logits = model.predict(x_batch) @@ -316,7 +314,6 @@ def evaluate_batch( indices=np.tile(np.arange(0, x_batch[0].size), (batch_size, 1)), indexed_axes=np.arange(0, x_batch[0].ndim), arr=x_batch, - **self.perturb_func_kwargs, ) # Generate explanations for perturbed input. @@ -331,7 +328,7 @@ def evaluate_batch( ros_batch[index] = ros # If perturbed input caused change in prediction, then it's ROS=nan. - changed_prediction_indices = self.changed_prediction_indices( + changed_prediction_indices = self.changed_prediction_indices_func( model, x_batch, x_perturbed ) diff --git a/quantus/metrics/robustness/relative_representation_stability.py b/quantus/metrics/robustness/relative_representation_stability.py index b94ee5ba0..3a02a2a59 100644 --- a/quantus/metrics/robustness/relative_representation_stability.py +++ b/quantus/metrics/robustness/relative_representation_stability.py @@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Optional, Callable, Dict, List import numpy as np -from functools import partial if TYPE_CHECKING: import tensorflow as tf @@ -16,20 +15,23 @@ from quantus.helpers.model.model_interface import ModelInterface -from quantus.metrics.base_perturbed import PerturbationMetric +from quantus.metrics.base import Metric from quantus.helpers.warn import warn_parameterisation from quantus.functions.normalise_func import normalise_by_average_second_moment_estimate from quantus.functions.perturb_func import uniform_noise, perturb_batch -from quantus.helpers.utils import expand_attribution_channel from quantus.helpers.enums import ( ModelType, DataType, ScoreDirection, EvaluationCategory, ) +from quantus.helpers.perturbation_utils import ( + make_perturb_func, + make_changed_prediction_indices_func, +) -class RelativeRepresentationStability(PerturbationMetric): +class RelativeRepresentationStability(Metric): """ Relative Representation Stability leverages the stability of an explanation with respect to the change in the output logits. @@ -64,7 +66,7 @@ def __init__( normalise: bool = False, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, ...]] = None, - perturb_func: Callable = None, + perturb_func: Callable = uniform_noise, perturb_func_kwargs: Optional[Dict[str, ...]] = None, return_aggregate: bool = False, aggregate_func: Optional[Callable[[np.ndarray], np.ndarray]] = np.mean, @@ -75,7 +77,7 @@ def __init__( layer_names: Optional[List[str]] = None, layer_indices: Optional[List[int]] = None, return_nan_when_prediction_changes: bool = True, - **kwargs: Dict[str, ...], + **kwargs, ): """ Parameters @@ -117,19 +119,11 @@ def __init__( if normalise_func is None: normalise_func = normalise_by_average_second_moment_estimate - if perturb_func is None: - perturb_func = uniform_noise - - if perturb_func_kwargs is None: - perturb_func_kwargs = {"upper_bound": 0.2} - super().__init__( abs=abs, normalise=normalise, normalise_func=normalise_func, normalise_func_kwargs=normalise_func_kwargs, - perturb_func=perturb_func, - perturb_func_kwargs=perturb_func_kwargs, return_aggregate=return_aggregate, aggregate_func=aggregate_func, default_plot_func=default_plot_func, @@ -146,7 +140,12 @@ def __init__( self._layer_names = layer_names self._layer_indices = layer_indices - self._return_nan_when_prediction_changes = return_nan_when_prediction_changes + self.perturb_func = make_perturb_func( + perturb_func, perturb_func_kwargs, upper_bound=0.2 + ) + self.changed_prediction_indices_func = make_changed_prediction_indices_func( + return_nan_when_prediction_changes + ) if not self.disable_warnings: warn_parameterisation( @@ -329,7 +328,6 @@ def evaluate_batch( indices=np.tile(np.arange(0, x_batch[0].size), (batch_size, 1)), indexed_axes=np.arange(0, x_batch[0].ndim), arr=x_batch, - **self.perturb_func_kwargs, ) # Generate explanations for perturbed input. @@ -349,7 +347,7 @@ def evaluate_batch( ) rrs_batch[index] = rrs # If perturbed input caused change in prediction, then it's RRS=nan. - changed_prediction_indices = self.changed_prediction_indices( + changed_prediction_indices = self.changed_prediction_indices_func( model, x_batch, x_perturbed ) From c44e2f1849e53f8fa91ea32bb5b594958ffe2b55 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Wed, 11 Oct 2023 10:42:38 +0200 Subject: [PATCH 29/58] * code review comments --- quantus/functions/explanation_func.py | 1 + quantus/functions/loss_func.py | 2 +- quantus/functions/normalise_func.py | 4 +- quantus/functions/similarity_func.py | 2 +- quantus/helpers/__init__.py | 2 +- quantus/helpers/model/models.py | 1 + quantus/helpers/model/pytorch_model.py | 28 +++----- quantus/helpers/model/tf_model.py | 1 + quantus/helpers/plotting.py | 1 + quantus/helpers/utils.py | 3 + quantus/metrics/__init__.py | 2 +- quantus/metrics/axiomatic/completeness.py | 24 +++++-- quantus/metrics/axiomatic/input_invariance.py | 19 ++++-- quantus/metrics/axiomatic/non_sensitivity.py | 24 ++++--- quantus/metrics/base.py | 37 +++++----- quantus/metrics/base_batched.py | 1 - quantus/metrics/complexity/complexity.py | 17 +++-- .../complexity/effective_complexity.py | 16 +++-- quantus/metrics/complexity/sparseness.py | 16 +++-- .../faithfulness/faithfulness_correlation.py | 19 ++++-- .../faithfulness/faithfulness_estimate.py | 19 ++++-- quantus/metrics/faithfulness/infidelity.py | 20 ++++-- quantus/metrics/faithfulness/irof.py | 20 +++--- quantus/metrics/faithfulness/monotonicity.py | 20 +++--- .../faithfulness/monotonicity_correlation.py | 18 +++-- .../metrics/faithfulness/pixel_flipping.py | 21 +++--- .../faithfulness/region_perturbation.py | 20 +++--- quantus/metrics/faithfulness/road.py | 67 +++++-------------- quantus/metrics/faithfulness/selectivity.py | 19 ++++-- quantus/metrics/faithfulness/sensitivity_n.py | 19 ++++-- quantus/metrics/faithfulness/sufficiency.py | 15 +++-- .../localisation/attribution_localisation.py | 19 ++++-- quantus/metrics/localisation/auc.py | 21 ++++-- quantus/metrics/localisation/focus.py | 19 ++++-- quantus/metrics/localisation/pointing_game.py | 19 ++++-- .../localisation/relevance_mass_accuracy.py | 19 ++++-- .../localisation/relevance_rank_accuracy.py | 19 ++++-- .../localisation/top_k_intersection.py | 19 ++++-- .../model_parameter_randomisation.py | 24 +++---- quantus/metrics/randomisation/random_logit.py | 31 ++++++--- quantus/metrics/robustness/avg_sensitivity.py | 37 +++++----- quantus/metrics/robustness/consistency.py | 16 +++-- quantus/metrics/robustness/continuity.py | 20 ++++-- .../robustness/local_lipschitz_estimate.py | 28 +++++--- quantus/metrics/robustness/max_sensitivity.py | 21 ++++-- .../robustness/relative_input_stability.py | 24 ++++--- .../robustness/relative_output_stability.py | 25 ++++--- .../relative_representation_stability.py | 24 ++++--- 48 files changed, 521 insertions(+), 342 deletions(-) diff --git a/quantus/functions/explanation_func.py b/quantus/functions/explanation_func.py index fbc4de9ba..d34260f3f 100644 --- a/quantus/functions/explanation_func.py +++ b/quantus/functions/explanation_func.py @@ -385,6 +385,7 @@ def generate_tf_explanation( ) elif method == "SmoothGrad": + num_samples = kwargs.get("num_samples", 5) noise = kwargs.get("noise", 0.1) explainer = tf_explain.core.smoothgrad.SmoothGrad() diff --git a/quantus/functions/loss_func.py b/quantus/functions/loss_func.py index bd0723645..69181e1f3 100644 --- a/quantus/functions/loss_func.py +++ b/quantus/functions/loss_func.py @@ -34,7 +34,7 @@ def mse(a: np.array, b: np.array, **kwargs) -> float: if normalise: # Calculate MSE in its polynomial expansion (a-b)^2 = a^2 - 2ab + b^2. - return np.average(((a**2) - (2 * (a * b)) + (b**2)), axis=0) + return np.average(((a ** 2) - (2 * (a * b)) + (b ** 2)), axis=0) # If no need to normalise, return (a-b)^2. return np.average(((a - b) ** 2), axis=0) diff --git a/quantus/functions/normalise_func.py b/quantus/functions/normalise_func.py index 5ad7f8145..2d518aad4 100644 --- a/quantus/functions/normalise_func.py +++ b/quantus/functions/normalise_func.py @@ -231,13 +231,13 @@ def normalise_by_average_second_moment_estimate( # Check that square root of the second momment estimatte is nonzero. second_moment_sqrt = np.sqrt( - np.sum(a**2, axis=normalise_axes, keepdims=True) + np.sum(a ** 2, axis=normalise_axes, keepdims=True) / np.prod([a.shape[n] for n in normalise_axes]) ) if all(second_moment_sqrt != 0): a /= np.sqrt( - np.sum(a**2, axis=normalise_axes, keepdims=True) + np.sum(a ** 2, axis=normalise_axes, keepdims=True) / np.prod([a.shape[n] for n in normalise_axes]) ) else: diff --git a/quantus/functions/similarity_func.py b/quantus/functions/similarity_func.py index 93ffe8832..88d19a9a7 100644 --- a/quantus/functions/similarity_func.py +++ b/quantus/functions/similarity_func.py @@ -145,7 +145,7 @@ def lipschitz_constant( b: np.array, c: Union[np.array, None], d: Union[np.array, None], - **kwargs, + **kwargs ) -> float: """ Calculate non-negative local Lipschitz abs(||a-b||/||c-d||), where a,b can be f(x) or a(x) and c,d is x. diff --git a/quantus/helpers/__init__.py b/quantus/helpers/__init__.py index e4d00f57e..eafc68d63 100644 --- a/quantus/helpers/__init__.py +++ b/quantus/helpers/__init__.py @@ -8,4 +8,4 @@ # Import files dependent on package installations. __EXTRAS__ = util.find_spec("captum") or util.find_spec("tf_explain") -__MODELS__ = util.find_spec("torch") or util.find_spec("tensorflow") +__MODELS__ = util.find_spec("torch") or util.find_spec("tensorflow") \ No newline at end of file diff --git a/quantus/helpers/model/models.py b/quantus/helpers/model/models.py index 2feb32c55..38d0ca004 100644 --- a/quantus/helpers/model/models.py +++ b/quantus/helpers/model/models.py @@ -12,6 +12,7 @@ # Import different models depending on which deep learning framework is installed. if util.find_spec("torch"): + import torch import torch.nn as nn diff --git a/quantus/helpers/model/pytorch_model.py b/quantus/helpers/model/pytorch_model.py index b79442cb8..af383074a 100644 --- a/quantus/helpers/model/pytorch_model.py +++ b/quantus/helpers/model/pytorch_model.py @@ -86,11 +86,8 @@ def _get_model_with_linear_top(self) -> torch.nn: if isinstance(named_module[1], torch.nn.Softmax): setattr(linear_model, named_module[0], torch.nn.Identity()) - logging.info( - "Argument softmax=False passed, but the passed model contains a module of type " - "torch.nn.Softmax. Module {} has been replaced with torch.nn.Identity().", - named_module[0], - ) + logging.info("Argument softmax=False passed, but the passed model contains a module of type " + "torch.nn.Softmax. Module {} has been replaced with torch.nn.Identity().", named_module[0]) break return linear_model @@ -121,10 +118,8 @@ def get_softmax_arg_model(self) -> torch.nn: return self.model # Case 1 if self.softmax and not last_softmax: - logging.info( - "Argument softmax=True passed, but the passed model contains no module of type " - "torch.nn.Softmax. torch.nn.Softmax module is added as the output layer." - ) + logging.info("Argument softmax=True passed, but the passed model contains no module of type " + "torch.nn.Softmax. torch.nn.Softmax module is added as the output layer.") return torch.nn.Sequential(self.model, torch.nn.Softmax(dim=-1)) # Case 3 if not self.softmax and not last_softmax: @@ -138,14 +133,12 @@ def get_softmax_arg_model(self) -> torch.nn: ) # Warning for cases 2, 4, 5 if self.softmax and last_softmax != -1: - logging.info( - "Argument softmax=True passed. The passed model contains a module of type " - "torch.nn.Softmax, but it is not the last in the list of model's children (" - "self.model.modules()). torch.nn.Softmax module is added as the output layer." - "Make sure that the torch.nn.Softmax layer is the last module in the list " - "of model's children (self.model.modules()) if and only if it is the actual last module " - "applied before output." - ) + logging.info("Argument softmax=True passed. The passed model contains a module of type " + "torch.nn.Softmax, but it is not the last in the list of model's children (" + "self.model.modules()). torch.nn.Softmax module is added as the output layer." + "Make sure that the torch.nn.Softmax layer is the last module in the list " + "of model's children (self.model.modules()) if and only if it is the actual last module " + "applied before output.") return torch.nn.Sequential(self.model, torch.nn.Softmax(dim=-1)) # Case 2 @@ -370,6 +363,7 @@ def get_hidden_representations( layer_names: Optional[List[str]] = None, layer_indices: Optional[List[int]] = None, ) -> np.ndarray: + """ Compute the model's internal representation of input x. In practice, this means, executing a forward pass and then, capturing the output of layers (of interest). diff --git a/quantus/helpers/model/tf_model.py b/quantus/helpers/model/tf_model.py index c72b06860..7a1769147 100644 --- a/quantus/helpers/model/tf_model.py +++ b/quantus/helpers/model/tf_model.py @@ -360,6 +360,7 @@ def get_hidden_representations( layer_indices: Optional[List[int]] = None, **kwargs, ) -> np.ndarray: + """ Compute the model's internal representation of input x. In practice, this means, executing a forward pass and then, capturing the output of layers (of interest). diff --git a/quantus/helpers/plotting.py b/quantus/helpers/plotting.py index fd504556a..4fa305713 100644 --- a/quantus/helpers/plotting.py +++ b/quantus/helpers/plotting.py @@ -245,6 +245,7 @@ def plot_model_parameter_randomisation_experiment( plt.plot(layers, [np.mean(v) for k, v in scores.items()], label=method) else: + layers = list(results.keys()) scores = {k: [] for k in layers} # samples = len(results) diff --git a/quantus/helpers/utils.py b/quantus/helpers/utils.py index 0ec29da91..196bdb33f 100644 --- a/quantus/helpers/utils.py +++ b/quantus/helpers/utils.py @@ -764,6 +764,7 @@ def infer_attribution_axes(a_batch: np.ndarray, x_batch: np.ndarray) -> Sequence for start in range(0, len(x_shape) - len(a_shape) + 1) ] if x_subshapes.count(a_shape) < 1: + # Check that attribution dimensions are (consecutive) subdimensions of inputs raise ValueError( "Attribution dimensions are not (consecutive) subdimensions of inputs: " @@ -772,6 +773,7 @@ def infer_attribution_axes(a_batch: np.ndarray, x_batch: np.ndarray) -> Sequence ) ) elif x_subshapes.count(a_shape) > 1: + # Check that attribution dimensions are (unique) subdimensions of inputs. # Consider potentially expanded dims in attributions. @@ -781,6 +783,7 @@ def infer_attribution_axes(a_batch: np.ndarray, x_batch: np.ndarray) -> Sequence for start in range(0, len(np.shape(a_batch)[1:]) - len(a_shape) + 1) ] if a_subshapes.count(a_shape) == 1: + # Inferring channel shape. for dim in range(len(np.shape(a_batch)[1:]) + 1): if a_shape == np.shape(a_batch)[1:][dim:]: diff --git a/quantus/metrics/__init__.py b/quantus/metrics/__init__.py index e6fafecfd..4e885eecd 100644 --- a/quantus/metrics/__init__.py +++ b/quantus/metrics/__init__.py @@ -4,9 +4,9 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . +from quantus.metrics.axiomatic import * from quantus.metrics.base import Metric from quantus.metrics.base_perturbed import PerturbationMetric -from quantus.metrics.axiomatic import * from quantus.metrics.complexity import * from quantus.metrics.faithfulness import * from quantus.metrics.localisation import * diff --git a/quantus/metrics/axiomatic/completeness.py b/quantus/metrics/axiomatic/completeness.py index 0bf17301c..0147c48e5 100644 --- a/quantus/metrics/axiomatic/completeness.py +++ b/quantus/metrics/axiomatic/completeness.py @@ -6,23 +6,31 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . +import sys from typing import Any, Callable, Dict, List, Optional + import numpy as np -from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices -from quantus.metrics.base import Metric +from quantus.helpers import warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface from quantus.helpers.perturbation_utils import make_perturb_func +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class Completeness(Metric): """ Implementation of Completeness test by Sundararajan et al., 2017, also referred @@ -300,7 +308,7 @@ def evaluate_instance( def evaluate_batch( self, - *, + *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, @@ -321,8 +329,10 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. + args: + Unused. kwargs: - Unused kwargs. + Unused. Returns ------- diff --git a/quantus/metrics/axiomatic/input_invariance.py b/quantus/metrics/axiomatic/input_invariance.py index 06ce7f8bc..819e50de4 100644 --- a/quantus/metrics/axiomatic/input_invariance.py +++ b/quantus/metrics/axiomatic/input_invariance.py @@ -6,24 +6,31 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . +import sys from typing import Any, Callable, Dict, List, Optional, Union + import numpy as np -from quantus.helpers import warn -from quantus.helpers import asserts -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_shift, perturb_batch -from quantus.metrics.base import Metric +from quantus.helpers import asserts, warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface from quantus.helpers.perturbation_utils import make_perturb_func +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class InputInvariance(Metric): """ Implementation of Completeness test by Kindermans et al., 2017. diff --git a/quantus/metrics/axiomatic/non_sensitivity.py b/quantus/metrics/axiomatic/non_sensitivity.py index 6436df4e7..699278699 100644 --- a/quantus/metrics/axiomatic/non_sensitivity.py +++ b/quantus/metrics/axiomatic/non_sensitivity.py @@ -6,24 +6,31 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . +import sys from typing import Any, Callable, Dict, List, Optional + import numpy as np -from quantus.helpers import warn -from quantus.helpers import asserts -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices -from quantus.metrics.base import Metric +from quantus.helpers import asserts, warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface from quantus.helpers.perturbation_utils import make_perturb_func +from quantus.metrics.base import Metric +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final + +@final class NonSensitivity(Metric): """ Implementation of NonSensitivity by Nguyen at el., 2020. @@ -356,7 +363,7 @@ def custom_preprocess( def evaluate_batch( self, - *, + *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, @@ -377,7 +384,8 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - + args: + Unused. kwargs: Unused. diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index 97dd663c6..c54bf4b88 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -7,36 +7,30 @@ from __future__ import annotations -from abc import abstractmethod -from typing import ( - Any, - Callable, - Dict, - Sequence, - ClassVar, - Generator, - Set, - TypedDict, - TypeVar, -) import logging +import sys +from abc import abstractmethod +from typing import Any, Callable, ClassVar, Dict, Generator, Sequence, Set, TypeVar import matplotlib.pyplot as plt import numpy as np from sklearn.utils import gen_batches from tqdm.auto import tqdm -from quantus.helpers import asserts -from quantus.helpers import utils -from quantus.helpers import warn +from quantus.helpers import asserts, utils, warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) from quantus.helpers.model.model_interface import ModelInterface +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final + D = TypeVar("D", bound=Dict[str, Any]) log = logging.getLogger(__name__) @@ -316,6 +310,7 @@ def evaluate_batch( """ raise NotImplementedError() + @final def general_preprocess( self, model, @@ -626,6 +621,7 @@ def custom_postprocess( """ pass + @final def generate_batches( self, data: D, @@ -792,6 +788,7 @@ def all_results(self): ) return self.all_evaluation_scores + @final def batch_preprocess(self, data_batch: Dict[str, ...]) -> Dict[str, ...]: """If `data_batch` has no `a_batch`, will compute explanations. This needs to be done on batch level to avoid OOM.""" @@ -806,10 +803,6 @@ def batch_preprocess(self, data_batch: Dict[str, ...]) -> Dict[str, ...]: a_batch = self.explain_batch(model, x_batch, y_batch) data_batch["a_batch"] = a_batch - if hasattr(self, "a_axes") and self.a_axes is None: - # TODO: we must not modify global state during evaluation. - self.a_axes = utils.infer_attribution_axes(a_batch, x_batch) - custom_batch = self.custom_batch_preprocess(data_batch) data_batch.update(custom_batch) return data_batch @@ -830,6 +823,7 @@ def custom_batch_preprocess(self, data_batch: Dict[str, ...]) -> Dict[str, ...]: """ return {} + @final def explain_batch( self, model: ModelInterface, @@ -837,7 +831,6 @@ def explain_batch( y_batch: np.ndarray, ) -> np.ndarray: """ - Compute explanations, normalize and take absolute (if was configured so during metric initialization.) This method should primarily be used if you need to generate additional explanation in metrics body. It encapsulates typical for Quantus pre- and postprocessing approach. diff --git a/quantus/metrics/base_batched.py b/quantus/metrics/base_batched.py index a4c692aa3..a4aca9634 100644 --- a/quantus/metrics/base_batched.py +++ b/quantus/metrics/base_batched.py @@ -9,7 +9,6 @@ from quantus.metrics.base import Metric - """Aliases to smoothen transition to uniform metric API.""" diff --git a/quantus/metrics/complexity/complexity.py b/quantus/metrics/complexity/complexity.py index e24feae6e..b48a9755a 100644 --- a/quantus/metrics/complexity/complexity.py +++ b/quantus/metrics/complexity/complexity.py @@ -5,22 +5,29 @@ # Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . - +import sys from typing import Any, Callable, Dict, List, Optional + import numpy as np import scipy -from quantus.helpers import warn from quantus.functions.normalise_func import normalise_by_max -from quantus.metrics.base import Metric +from quantus.helpers import warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class Complexity(Metric): """ Implementation of Complexity metric by Bhatt et al., 2020. diff --git a/quantus/metrics/complexity/effective_complexity.py b/quantus/metrics/complexity/effective_complexity.py index 834a6a5c0..8fba4a71f 100644 --- a/quantus/metrics/complexity/effective_complexity.py +++ b/quantus/metrics/complexity/effective_complexity.py @@ -5,19 +5,25 @@ # Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . - +import sys from typing import Any, Callable, Dict, List, Optional + import numpy as np -from quantus.helpers import warn from quantus.functions.normalise_func import normalise_by_max -from quantus.metrics.base import Metric +from quantus.helpers import warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final class EffectiveComplexity(Metric): diff --git a/quantus/metrics/complexity/sparseness.py b/quantus/metrics/complexity/sparseness.py index d939c1839..2107b1604 100644 --- a/quantus/metrics/complexity/sparseness.py +++ b/quantus/metrics/complexity/sparseness.py @@ -5,19 +5,25 @@ # Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . - +import sys from typing import Any, Callable, Dict, List, Optional + import numpy as np -from quantus.helpers import warn from quantus.functions.normalise_func import normalise_by_max -from quantus.metrics.base import Metric +from quantus.helpers import warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final class Sparseness(Metric): diff --git a/quantus/metrics/faithfulness/faithfulness_correlation.py b/quantus/metrics/faithfulness/faithfulness_correlation.py index 6e95db4f7..5fa79a349 100644 --- a/quantus/metrics/faithfulness/faithfulness_correlation.py +++ b/quantus/metrics/faithfulness/faithfulness_correlation.py @@ -5,27 +5,32 @@ # Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . - +import sys from typing import Any, Callable, Dict, List, Optional import numpy as np -from quantus.helpers import warn -from quantus.helpers import asserts -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices from quantus.functions.similarity_func import correlation_pearson -from quantus.metrics.base import Metric +from quantus.helpers import asserts, warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface from quantus.helpers.perturbation_utils import make_perturb_func +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class FaithfulnessCorrelation(Metric): """ Implementation of faithfulness correlation by Bhatt et al., 2020. diff --git a/quantus/metrics/faithfulness/faithfulness_estimate.py b/quantus/metrics/faithfulness/faithfulness_estimate.py index 50f1d1e2d..f2b527d6e 100644 --- a/quantus/metrics/faithfulness/faithfulness_estimate.py +++ b/quantus/metrics/faithfulness/faithfulness_estimate.py @@ -5,27 +5,32 @@ # Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . - +import sys from typing import Any, Callable, Dict, List, Optional import numpy as np -from quantus.helpers import warn -from quantus.helpers import asserts -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices from quantus.functions.similarity_func import correlation_pearson -from quantus.metrics.base import Metric +from quantus.helpers import asserts, warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface from quantus.helpers.perturbation_utils import make_perturb_func +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class FaithfulnessEstimate(Metric): """ Implementation of Faithfulness Estimate by Alvares-Melis at el., 2018a and 2018b. diff --git a/quantus/metrics/faithfulness/infidelity.py b/quantus/metrics/faithfulness/infidelity.py index 8dee28faf..0c1492d58 100644 --- a/quantus/metrics/faithfulness/infidelity.py +++ b/quantus/metrics/faithfulness/infidelity.py @@ -5,26 +5,32 @@ # Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . - +import sys from typing import Any, Callable, Dict, List, Optional, Union + import numpy as np -from quantus.helpers import utils -from quantus.helpers import warn from quantus.functions.loss_func import mse -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices -from quantus.metrics.base import Metric +from quantus.helpers import utils, warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface from quantus.helpers.perturbation_utils import make_perturb_func +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class Infidelity(Metric): """ Implementation of Infidelity by Yeh et al., 2019. diff --git a/quantus/metrics/faithfulness/irof.py b/quantus/metrics/faithfulness/irof.py index 93e403d92..4928875c6 100644 --- a/quantus/metrics/faithfulness/irof.py +++ b/quantus/metrics/faithfulness/irof.py @@ -5,27 +5,31 @@ # Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . - +import sys from typing import Any, Callable, Dict, List, Optional import numpy as np -from quantus.helpers import asserts -from quantus.helpers import utils -from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices -from quantus.metrics.base import Metric +from quantus.helpers import asserts, utils, warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface from quantus.helpers.perturbation_utils import make_perturb_func +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class IROF(Metric): """ Implementation of IROF (Iterative Removal of Features) by Rieger at el., 2020. diff --git a/quantus/metrics/faithfulness/monotonicity.py b/quantus/metrics/faithfulness/monotonicity.py index 08492e7f1..43623b291 100644 --- a/quantus/metrics/faithfulness/monotonicity.py +++ b/quantus/metrics/faithfulness/monotonicity.py @@ -5,27 +5,31 @@ # Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . - +import sys from typing import Any, Callable, Dict, List, Optional import numpy as np -from quantus.helpers import asserts -from quantus.helpers import utils -from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices -from quantus.metrics.base import Metric +from quantus.helpers import asserts, utils, warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface from quantus.helpers.perturbation_utils import make_perturb_func +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class Monotonicity(Metric): """ Implementation of Monotonicity metric by Arya at el., 2019. diff --git a/quantus/metrics/faithfulness/monotonicity_correlation.py b/quantus/metrics/faithfulness/monotonicity_correlation.py index 4a219d8bd..d51b6a334 100644 --- a/quantus/metrics/faithfulness/monotonicity_correlation.py +++ b/quantus/metrics/faithfulness/monotonicity_correlation.py @@ -6,26 +6,32 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . +import sys from typing import Any, Callable, Dict, List, Optional import numpy as np -from quantus.helpers import warn -from quantus.helpers import asserts -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices from quantus.functions.similarity_func import correlation_spearman -from quantus.metrics.base import Metric +from quantus.helpers import asserts, warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface from quantus.helpers.perturbation_utils import make_perturb_func +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class MonotonicityCorrelation(Metric): """ Implementation of Monotonicity Correlation metric by Nguyen at el., 2020. diff --git a/quantus/metrics/faithfulness/pixel_flipping.py b/quantus/metrics/faithfulness/pixel_flipping.py index 4157d8e3a..b3f20d3d2 100644 --- a/quantus/metrics/faithfulness/pixel_flipping.py +++ b/quantus/metrics/faithfulness/pixel_flipping.py @@ -7,27 +7,32 @@ # Quantus project URL: . from __future__ import annotations + +import sys from typing import Any, Callable, Dict, List, Optional import numpy as np -from quantus.helpers import asserts -from quantus.helpers import plotting -from quantus.helpers import utils -from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices -from quantus.metrics.base import Metric +from quantus.helpers import asserts, plotting, utils, warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface from quantus.helpers.perturbation_utils import make_perturb_func +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class PixelFlipping(Metric): """ Implementation of Pixel-Flipping experiment by Bach et al., 2015. diff --git a/quantus/metrics/faithfulness/region_perturbation.py b/quantus/metrics/faithfulness/region_perturbation.py index 9fe653640..3429a9386 100644 --- a/quantus/metrics/faithfulness/region_perturbation.py +++ b/quantus/metrics/faithfulness/region_perturbation.py @@ -7,27 +7,31 @@ # Quantus project URL: . import itertools +import sys from typing import Any, Callable, Dict, List, Optional import numpy as np -from quantus.helpers import asserts -from quantus.helpers import plotting -from quantus.helpers import utils -from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices -from quantus.metrics.base import Metric +from quantus.helpers import asserts, plotting, utils, warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface from quantus.helpers.perturbation_utils import make_perturb_func +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class RegionPerturbation(Metric): """ Implementation of Region Perturbation by Samek et al., 2015. diff --git a/quantus/metrics/faithfulness/road.py b/quantus/metrics/faithfulness/road.py index 1f860fda3..0f62e6774 100644 --- a/quantus/metrics/faithfulness/road.py +++ b/quantus/metrics/faithfulness/road.py @@ -6,24 +6,31 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . +import sys from typing import Any, Callable, Dict, List, Optional import numpy as np -from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import noisy_linear_imputation -from quantus.metrics.base import Metric +from quantus.helpers import warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface from quantus.helpers.perturbation_utils import make_perturb_func +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class ROAD(Metric): """ Implementation of ROAD evaluation strategy by Rong et al., 2022. @@ -300,6 +307,11 @@ def evaluate_instance( # Return list of booleans for each percentage. return results_instance + def custom_batch_preprocess(self, data_batch: Dict[str, ...]) -> Dict[str, ...]: + """ROAD requires `a_size` property to be set to `image_height` * `image_width` of an explanation.""" + if self.a_size is None: + self.a_size = data_batch["a_batch"][0, :, :].size + def custom_postprocess( self, model: ModelInterface, @@ -335,48 +347,3 @@ def custom_postprocess( percentage: np.mean(np.array(self.evaluation_scores)[:, p_ix]) for p_ix, percentage in enumerate(self.percentages) } - - def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: - """ROAD requires `a_size` property to be set to `image_height` * `image_width` of an explanation.""" - data_batch = super().batch_preprocess(data_batch) - # Infer the size of attributions. - self.a_size = data_batch["a_batch"][0, :, :].size - return data_batch - - def evaluate_batch( - self, - *args, - model: ModelInterface, - x_batch: np.ndarray, - y_batch: np.ndarray, - a_batch: np.ndarray, - **kwargs, - ): - """ - This method performs XAI evaluation on a single batch of explanations. - For more information on the specific logic, we refer the metric’s initialisation docstring. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x_batch: np.ndarray - The input to be evaluated on a batch-basis. - y_batch: np.ndarray - The output to be evaluated on a batch-basis. - a_batch: np.ndarray - The explanation to be evaluated on a batch-basis. - args: - Unused. - kwargs: - Unused. - - Returns - ------- - list - The evaluation results. - """ - return [ - self.evaluate_instance(model=model, x=x, y=y, a=a) - for x, y, a in zip(x_batch, y_batch, a_batch) - ] diff --git a/quantus/metrics/faithfulness/selectivity.py b/quantus/metrics/faithfulness/selectivity.py index 92a5712b8..150422cfc 100644 --- a/quantus/metrics/faithfulness/selectivity.py +++ b/quantus/metrics/faithfulness/selectivity.py @@ -7,26 +7,31 @@ # Quantus project URL: . import itertools +import sys from typing import Any, Callable, Dict, List, Optional import numpy as np -from quantus.helpers import plotting -from quantus.helpers import utils -from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices -from quantus.metrics.base import Metric +from quantus.helpers import plotting, utils, warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface from quantus.helpers.perturbation_utils import make_perturb_func +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class Selectivity(Metric): """ Implementation of Selectivity test by Montavon et al., 2018. diff --git a/quantus/metrics/faithfulness/sensitivity_n.py b/quantus/metrics/faithfulness/sensitivity_n.py index 150b807d2..0fabb5798 100644 --- a/quantus/metrics/faithfulness/sensitivity_n.py +++ b/quantus/metrics/faithfulness/sensitivity_n.py @@ -6,27 +6,32 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . +import sys from typing import Any, Callable, Dict, List, Optional import numpy as np -from quantus.helpers import asserts -from quantus.helpers import plotting -from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices from quantus.functions.similarity_func import correlation_pearson -from quantus.metrics.base import Metric +from quantus.helpers import asserts, plotting, warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface from quantus.helpers.perturbation_utils import make_perturb_func +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class SensitivityN(Metric): """ Implementation of Sensitivity-N test by Ancona et al., 2019. diff --git a/quantus/metrics/faithfulness/sufficiency.py b/quantus/metrics/faithfulness/sufficiency.py index c784f1ba2..a250a83f1 100644 --- a/quantus/metrics/faithfulness/sufficiency.py +++ b/quantus/metrics/faithfulness/sufficiency.py @@ -6,22 +6,29 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . +import sys from typing import Any, Callable, Dict, List, Optional, no_type_check import numpy as np from scipy.spatial.distance import cdist -from quantus.helpers import warn from quantus.functions.normalise_func import normalise_by_max -from quantus.metrics.base import Metric +from quantus.helpers import warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class Sufficiency(Metric): """ Implementation of Sufficiency test by Dasgupta et al., 2022. diff --git a/quantus/metrics/localisation/attribution_localisation.py b/quantus/metrics/localisation/attribution_localisation.py index 3ebfbe4fe..4fa256688 100644 --- a/quantus/metrics/localisation/attribution_localisation.py +++ b/quantus/metrics/localisation/attribution_localisation.py @@ -6,22 +6,29 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . +import sys from typing import Any, Callable, Dict, List, Optional + import numpy as np -from quantus.helpers import asserts -from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max -from quantus.metrics.base import Metric +from quantus.helpers import asserts, warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class AttributionLocalisation(Metric): """ Implementation of the Attribution Localization by Kohlbrenner et al., 2020. diff --git a/quantus/metrics/localisation/auc.py b/quantus/metrics/localisation/auc.py index b565e8fae..99875cc06 100644 --- a/quantus/metrics/localisation/auc.py +++ b/quantus/metrics/localisation/auc.py @@ -6,23 +6,30 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . +import sys from typing import Any, Callable, Dict, List, Optional + import numpy as np -from sklearn.metrics import roc_curve, auc +from sklearn.metrics import auc, roc_curve -from quantus.helpers import asserts -from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max -from quantus.metrics.base import Metric +from quantus.helpers import asserts, warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class AUC(Metric): """ Implementation of AUC metric by Fawcett et al., 2006. diff --git a/quantus/metrics/localisation/focus.py b/quantus/metrics/localisation/focus.py index f2ba5a4db..62abbe356 100644 --- a/quantus/metrics/localisation/focus.py +++ b/quantus/metrics/localisation/focus.py @@ -6,22 +6,29 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . +import sys from typing import Any, Callable, Dict, List, Optional, no_type_check + import numpy as np -from quantus.helpers import plotting -from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max -from quantus.metrics.base import Metric +from quantus.helpers import plotting, warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class Focus(Metric): """ Implementation of Focus evaluation strategy by Arias et. al. 2022 diff --git a/quantus/metrics/localisation/pointing_game.py b/quantus/metrics/localisation/pointing_game.py index 0fc204ace..71c437757 100644 --- a/quantus/metrics/localisation/pointing_game.py +++ b/quantus/metrics/localisation/pointing_game.py @@ -6,22 +6,29 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . +import sys from typing import Any, Callable, Dict, List, Optional + import numpy as np -from quantus.helpers import asserts -from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max -from quantus.metrics.base import Metric +from quantus.helpers import asserts, warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class PointingGame(Metric): """ Implementation of the Pointing Game by Zhang et al., 2018. diff --git a/quantus/metrics/localisation/relevance_mass_accuracy.py b/quantus/metrics/localisation/relevance_mass_accuracy.py index 5a309f77c..86d475051 100644 --- a/quantus/metrics/localisation/relevance_mass_accuracy.py +++ b/quantus/metrics/localisation/relevance_mass_accuracy.py @@ -6,22 +6,29 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . +import sys from typing import Any, Callable, Dict, List, Optional + import numpy as np -from quantus.helpers import asserts -from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max -from quantus.metrics.base import Metric +from quantus.helpers import asserts, warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class RelevanceMassAccuracy(Metric): """ Implementation of the Relevance Mass Accuracy by Arras et al., 2021. diff --git a/quantus/metrics/localisation/relevance_rank_accuracy.py b/quantus/metrics/localisation/relevance_rank_accuracy.py index b89f7bb61..f05e2b6bc 100644 --- a/quantus/metrics/localisation/relevance_rank_accuracy.py +++ b/quantus/metrics/localisation/relevance_rank_accuracy.py @@ -6,22 +6,29 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . +import sys from typing import Any, Callable, Dict, List, Optional + import numpy as np -from quantus.helpers import asserts -from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max -from quantus.metrics.base import Metric +from quantus.helpers import asserts, warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class RelevanceRankAccuracy(Metric): """ Implementation of the Relevance Rank Accuracy by Arras et al., 2021. diff --git a/quantus/metrics/localisation/top_k_intersection.py b/quantus/metrics/localisation/top_k_intersection.py index daa66cc8c..1a94acb8f 100644 --- a/quantus/metrics/localisation/top_k_intersection.py +++ b/quantus/metrics/localisation/top_k_intersection.py @@ -6,22 +6,29 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . +import sys from typing import Any, Callable, Dict, List, Optional + import numpy as np -from quantus.helpers import asserts -from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max -from quantus.metrics.base import Metric +from quantus.helpers import asserts, warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class TopKIntersection(Metric): """ Implementation of the top-k intersection by Theiner et al., 2021. diff --git a/quantus/metrics/randomisation/model_parameter_randomisation.py b/quantus/metrics/randomisation/model_parameter_randomisation.py index d04211a15..0fb677b5a 100644 --- a/quantus/metrics/randomisation/model_parameter_randomisation.py +++ b/quantus/metrics/randomisation/model_parameter_randomisation.py @@ -6,33 +6,31 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Union, - Collection, -) +import sys +from typing import Any, Callable, Collection, Dict, List, Optional, Union import numpy as np from tqdm.auto import tqdm from quantus.functions.normalise_func import normalise_by_max from quantus.functions.similarity_func import correlation_spearman -from quantus.helpers import asserts -from quantus.helpers import warn +from quantus.helpers import asserts, warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) from quantus.helpers.model.model_interface import ModelInterface from quantus.metrics.base import Metric +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final + +@final class ModelParameterRandomisation(Metric): """ Implementation of the Model Parameter Randomization Method by Adebayo et. al., 2018. diff --git a/quantus/metrics/randomisation/random_logit.py b/quantus/metrics/randomisation/random_logit.py index 79f97e544..ee174addb 100644 --- a/quantus/metrics/randomisation/random_logit.py +++ b/quantus/metrics/randomisation/random_logit.py @@ -6,23 +6,30 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . +import sys from typing import Any, Callable, Dict, List, Optional + import numpy as np -from quantus.helpers import asserts -from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.similarity_func import ssim -from quantus.metrics.base import Metric +from quantus.helpers import asserts, warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface +from quantus.metrics.base import Metric +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final + +@final class RandomLogit(Metric): """ Implementation of the Random Logit Metric by Sixt et al., 2020. @@ -326,14 +333,16 @@ def custom_preprocess( def evaluate_batch( self, - *, + *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - **_, + **kwargs, ) -> List[float]: """ + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. Parameters ---------- @@ -345,8 +354,10 @@ def evaluate_batch( A np.ndarray which contains the output labels that are explained. a_batch: A np.ndarray which contains pre-computed attributions i.e., explanations. - _: - unused. + args: + Unused. + kwargs: + Unused. Returns ------- diff --git a/quantus/metrics/robustness/avg_sensitivity.py b/quantus/metrics/robustness/avg_sensitivity.py index ac3cfc138..0077cddb9 100644 --- a/quantus/metrics/robustness/avg_sensitivity.py +++ b/quantus/metrics/robustness/avg_sensitivity.py @@ -1,35 +1,42 @@ """This module contains the implementation of the Avg-Sensitivity metric.""" import functools - -# This file is part of Quantus. -# Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. -# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. -# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . -# Quantus project URL: . - +import sys from typing import Any, Callable, Dict, List, Optional + import numpy as np -from quantus.helpers import asserts from quantus.functions import norm_func -from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max -from quantus.functions.perturb_func import uniform_noise, perturb_batch +from quantus.functions.perturb_func import perturb_batch, uniform_noise from quantus.functions.similarity_func import difference -from quantus.metrics.base import Metric +from quantus.helpers import asserts, warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface from quantus.helpers.perturbation_utils import ( - make_perturb_func, make_changed_prediction_indices_func, + make_perturb_func, ) +from quantus.metrics.base import Metric + +# This file is part of Quantus. +# Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. +# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. +# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . +# Quantus project URL: . + + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class AvgSensitivity(Metric): """ Implementation of Avg-Sensitivity by Yeh at el., 2019. diff --git a/quantus/metrics/robustness/consistency.py b/quantus/metrics/robustness/consistency.py index 22283b2ed..114711460 100644 --- a/quantus/metrics/robustness/consistency.py +++ b/quantus/metrics/robustness/consistency.py @@ -6,21 +6,29 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . +import sys from typing import Any, Callable, Dict, List, Optional, no_type_check + import numpy as np -from quantus.helpers import warn from quantus.functions.discretise_func import top_n_sign from quantus.functions.normalise_func import normalise_by_max -from quantus.metrics.base import Metric +from quantus.helpers import warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class Consistency(Metric): """ Implementation of the Consistency metric which measures the expected local consistency, i.e., the probability diff --git a/quantus/metrics/robustness/continuity.py b/quantus/metrics/robustness/continuity.py index 88d25db71..9ed59ebbe 100644 --- a/quantus/metrics/robustness/continuity.py +++ b/quantus/metrics/robustness/continuity.py @@ -6,26 +6,32 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . import itertools +import sys from typing import Any, Callable, Dict, List, Optional + import numpy as np -from quantus.helpers import asserts -from quantus.helpers import utils -from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import translation_x_direction from quantus.functions.similarity_func import lipschitz_constant -from quantus.metrics.base import Metric +from quantus.helpers import asserts, utils, warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface from quantus.helpers.perturbation_utils import make_perturb_func +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class Continuity(Metric): """ Implementation of the Continuity test by Montavon et al., 2018. diff --git a/quantus/metrics/robustness/local_lipschitz_estimate.py b/quantus/metrics/robustness/local_lipschitz_estimate.py index d6d0d749d..aa21c1024 100644 --- a/quantus/metrics/robustness/local_lipschitz_estimate.py +++ b/quantus/metrics/robustness/local_lipschitz_estimate.py @@ -6,29 +6,35 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . +import sys from typing import Any, Callable, Dict, List, Optional + import numpy as np -from quantus.helpers import asserts -from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import gaussian_noise, perturb_batch -from quantus.functions.similarity_func import lipschitz_constant, distance_euclidean -from quantus.metrics.base import Metric +from quantus.functions.similarity_func import distance_euclidean, lipschitz_constant +from quantus.helpers import asserts, warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) - +from quantus.helpers.model.model_interface import ModelInterface from quantus.helpers.perturbation_utils import ( - make_perturb_func, make_changed_prediction_indices_func, + make_perturb_func, ) +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class LocalLipschitzEstimate(Metric): """ Implementation of the Local Lipschitz Estimate (or Stability) test by Alvarez-Melis et al., 2018a, 2018b. @@ -303,7 +309,7 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - **_, + **kwargs, ) -> np.ndarray: """ Evaluates model and attributes on a single data batch and returns the batched evaluation result. @@ -318,6 +324,8 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. + kwargs: + Unused. Returns ------- diff --git a/quantus/metrics/robustness/max_sensitivity.py b/quantus/metrics/robustness/max_sensitivity.py index 721f3b58b..61d7bab87 100644 --- a/quantus/metrics/robustness/max_sensitivity.py +++ b/quantus/metrics/robustness/max_sensitivity.py @@ -6,29 +6,36 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . +import sys from typing import Any, Callable, Dict, List, Optional + import numpy as np -from quantus.helpers import asserts from quantus.functions import norm_func -from quantus.helpers import warn -from quantus.helpers.model.model_interface import ModelInterface from quantus.functions.normalise_func import normalise_by_max -from quantus.functions.perturb_func import uniform_noise, perturb_batch +from quantus.functions.perturb_func import perturb_batch, uniform_noise from quantus.functions.similarity_func import difference -from quantus.metrics.base import Metric +from quantus.helpers import asserts, warn from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface from quantus.helpers.perturbation_utils import ( make_changed_prediction_indices_func, make_perturb_func, ) +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class MaxSensitivity(Metric): """ Implementation of Max-Sensitivity by Yeh at el., 2019. diff --git a/quantus/metrics/robustness/relative_input_stability.py b/quantus/metrics/robustness/relative_input_stability.py index 70faf507d..0e5576833 100644 --- a/quantus/metrics/robustness/relative_input_stability.py +++ b/quantus/metrics/robustness/relative_input_stability.py @@ -6,31 +6,39 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Callable, Dict, List +from typing import TYPE_CHECKING, Callable, Dict, List, Optional + import numpy as np if TYPE_CHECKING: import tensorflow as tf import torch +import sys -from quantus.helpers.model.model_interface import ModelInterface -from quantus.metrics.base import Metric -from quantus.helpers.warn import warn_parameterisation from quantus.functions.normalise_func import normalise_by_average_second_moment_estimate -from quantus.functions.perturb_func import uniform_noise, perturb_batch +from quantus.functions.perturb_func import perturb_batch, uniform_noise from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface from quantus.helpers.perturbation_utils import ( - make_perturb_func, make_changed_prediction_indices_func, + make_perturb_func, ) +from quantus.helpers.warn import warn_parameterisation +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class RelativeInputStability(Metric): """ Relative Input Stability leverages the stability of an explanation with respect to the change in the input data. diff --git a/quantus/metrics/robustness/relative_output_stability.py b/quantus/metrics/robustness/relative_output_stability.py index 2f3729b61..b1e390806 100644 --- a/quantus/metrics/robustness/relative_output_stability.py +++ b/quantus/metrics/robustness/relative_output_stability.py @@ -6,30 +6,39 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Callable, Dict, List +from typing import TYPE_CHECKING, Callable, Dict, List, Optional + import numpy as np if TYPE_CHECKING: import tensorflow as tf import torch -from quantus.helpers.model.model_interface import ModelInterface -from quantus.metrics.base import Metric -from quantus.helpers.warn import warn_parameterisation +import sys + from quantus.functions.normalise_func import normalise_by_average_second_moment_estimate -from quantus.functions.perturb_func import uniform_noise, perturb_batch +from quantus.functions.perturb_func import perturb_batch, uniform_noise from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface from quantus.helpers.perturbation_utils import ( - make_perturb_func, make_changed_prediction_indices_func, + make_perturb_func, ) +from quantus.helpers.warn import warn_parameterisation +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class RelativeOutputStability(Metric): """ Relative Output Stability leverages the stability of an explanation with respect to the change in the output logits. diff --git a/quantus/metrics/robustness/relative_representation_stability.py b/quantus/metrics/robustness/relative_representation_stability.py index 3a02a2a59..40d1db44a 100644 --- a/quantus/metrics/robustness/relative_representation_stability.py +++ b/quantus/metrics/robustness/relative_representation_stability.py @@ -6,31 +6,39 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Callable, Dict, List +from typing import TYPE_CHECKING, Callable, Dict, List, Optional + import numpy as np if TYPE_CHECKING: import tensorflow as tf import torch +import sys -from quantus.helpers.model.model_interface import ModelInterface -from quantus.metrics.base import Metric -from quantus.helpers.warn import warn_parameterisation from quantus.functions.normalise_func import normalise_by_average_second_moment_estimate -from quantus.functions.perturb_func import uniform_noise, perturb_batch +from quantus.functions.perturb_func import perturb_batch, uniform_noise from quantus.helpers.enums import ( - ModelType, DataType, - ScoreDirection, EvaluationCategory, + ModelType, + ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface from quantus.helpers.perturbation_utils import ( - make_perturb_func, make_changed_prediction_indices_func, + make_perturb_func, ) +from quantus.helpers.warn import warn_parameterisation +from quantus.metrics.base import Metric + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final +@final class RelativeRepresentationStability(Metric): """ Relative Representation Stability leverages the stability of an explanation with respect From bcb6f8e0db39fabc790f5b0eb4e8d60c27b93858 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Wed, 11 Oct 2023 11:09:06 +0200 Subject: [PATCH 30/58] * test fixes --- quantus/helpers/perturbation_utils.py | 11 +++--- quantus/metrics/__init__.py | 1 - quantus/metrics/axiomatic/non_sensitivity.py | 2 - quantus/metrics/base.py | 20 +++++++--- quantus/metrics/faithfulness/infidelity.py | 8 +--- .../faithfulness/region_perturbation.py | 9 ++++- quantus/metrics/faithfulness/road.py | 38 +++++++++++++++++++ quantus/metrics/robustness/avg_sensitivity.py | 2 - quantus/metrics/robustness/continuity.py | 7 ++++ .../robustness/local_lipschitz_estimate.py | 4 +- quantus/metrics/robustness/max_sensitivity.py | 2 - .../robustness/relative_input_stability.py | 2 - .../robustness/relative_output_stability.py | 2 - 13 files changed, 76 insertions(+), 32 deletions(-) diff --git a/quantus/helpers/perturbation_utils.py b/quantus/helpers/perturbation_utils.py index 841633d21..a17b2c3a2 100644 --- a/quantus/helpers/perturbation_utils.py +++ b/quantus/helpers/perturbation_utils.py @@ -23,12 +23,13 @@ def make_perturb_func( perturb_func: PerturbFunc, perturb_func_kwargs: Mapping[str, ...] | None, **kwargs ) -> PerturbFunc | functools.partial: """A utility function to save few lines of code during perturbation metric initialization.""" - if perturb_func_kwargs is None: - perturb_func_kwargs = {} + if perturb_func_kwargs is not None: + func_kwargs = kwargs.copy() + func_kwargs.update(perturb_func_kwargs) + else: + func_kwargs = kwargs - kwargs.update(perturb_func_kwargs) - - return functools.partial(perturb_func, **perturb_func_kwargs) + return functools.partial(perturb_func, **func_kwargs) def make_changed_prediction_indices_func( diff --git a/quantus/metrics/__init__.py b/quantus/metrics/__init__.py index 4e885eecd..f8003d017 100644 --- a/quantus/metrics/__init__.py +++ b/quantus/metrics/__init__.py @@ -6,7 +6,6 @@ from quantus.metrics.axiomatic import * from quantus.metrics.base import Metric -from quantus.metrics.base_perturbed import PerturbationMetric from quantus.metrics.complexity import * from quantus.metrics.faithfulness import * from quantus.metrics.localisation import * diff --git a/quantus/metrics/axiomatic/non_sensitivity.py b/quantus/metrics/axiomatic/non_sensitivity.py index 699278699..b8f79ccb3 100644 --- a/quantus/metrics/axiomatic/non_sensitivity.py +++ b/quantus/metrics/axiomatic/non_sensitivity.py @@ -127,8 +127,6 @@ def __init__( normalise=normalise, normalise_func=normalise_func, normalise_func_kwargs=normalise_func_kwargs, - perturb_func=perturb_func, - perturb_func_kwargs=perturb_func_kwargs, return_aggregate=return_aggregate, aggregate_func=aggregate_func, default_plot_func=default_plot_func, diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index c54bf4b88..76e9ca66d 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -27,9 +27,14 @@ from quantus.helpers.model.model_interface import ModelInterface if sys.version_info >= (3, 8): - from typing import final + from typing import final, LiteralString else: - from typing_extensions import final + from typing_extensions import final, LiteralString + +if sys.version_info >= (3, 10): + from typing import LiteralString +else: + from typing_extensions import LiteralString D = TypeVar("D", bound=Dict[str, Any]) log = logging.getLogger(__name__) @@ -40,7 +45,7 @@ class Metric: Interface defining Metrics' API. """ - name: ClassVar[str] + name: ClassVar[LiteralString] data_applicability: ClassVar[Set[DataType]] model_applicability: ClassVar[Set[ModelType]] score_direction: ClassVar[ScoreDirection] @@ -804,10 +809,13 @@ def batch_preprocess(self, data_batch: Dict[str, ...]) -> Dict[str, ...]: data_batch["a_batch"] = a_batch custom_batch = self.custom_batch_preprocess(data_batch) - data_batch.update(custom_batch) + if custom_batch is not None: + data_batch.update(custom_batch) return data_batch - def custom_batch_preprocess(self, data_batch: Dict[str, ...]) -> Dict[str, ...]: + def custom_batch_preprocess( + self, data_batch: Dict[str, ...] + ) -> Dict[str, ...] | None: """ Implement this method if you need custom preprocessing of data or simply for creating/initialising additional attributes or assertions @@ -821,7 +829,7 @@ def custom_batch_preprocess(self, data_batch: Dict[str, ...]) -> Dict[str, ...]: ------- """ - return {} + pass @final def explain_batch( diff --git a/quantus/metrics/faithfulness/infidelity.py b/quantus/metrics/faithfulness/infidelity.py index 0c1492d58..e080843d4 100644 --- a/quantus/metrics/faithfulness/infidelity.py +++ b/quantus/metrics/faithfulness/infidelity.py @@ -73,11 +73,9 @@ def __init__( n_perturb_samples: int = 10, abs: bool = False, normalise: bool = False, - normalise_func: Optional[ - Callable[[np.ndarray], np.ndarray] - ] = baseline_replacement_by_indices, + normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = None, + perturb_func: Callable = baseline_replacement_by_indices, perturb_baseline: str = "black", perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, @@ -136,8 +134,6 @@ def __init__( normalise=normalise, normalise_func=normalise_func, normalise_func_kwargs=normalise_func_kwargs, - perturb_func=perturb_func, - perturb_func_kwargs=perturb_func_kwargs, return_aggregate=return_aggregate, aggregate_func=aggregate_func, default_plot_func=default_plot_func, diff --git a/quantus/metrics/faithfulness/region_perturbation.py b/quantus/metrics/faithfulness/region_perturbation.py index 3429a9386..9c8741aaf 100644 --- a/quantus/metrics/faithfulness/region_perturbation.py +++ b/quantus/metrics/faithfulness/region_perturbation.py @@ -135,8 +135,6 @@ def __init__( normalise=normalise, normalise_func=normalise_func, normalise_func_kwargs=normalise_func_kwargs, - perturb_func=perturb_func, - perturb_func_kwargs=perturb_func_kwargs, return_aggregate=return_aggregate, aggregate_func=aggregate_func, default_plot_func=default_plot_func, @@ -457,3 +455,10 @@ def evaluate_batch( self.evaluate_instance(model=model, x=x, y=y, a=a) for x, y, a in zip(x_batch, y_batch, a_batch) ] + + def custom_batch_preprocess(self, data_batch: Dict[str, ...]): + """RegionPerturbation requires `a_axes` property to be set before evaluation.""" + if self.a_axes is None: + self.a_axes = utils.infer_attribution_axes( + data_batch["a_batch"], data_batch["x_batch"] + ) diff --git a/quantus/metrics/faithfulness/road.py b/quantus/metrics/faithfulness/road.py index 0f62e6774..11b9a8a44 100644 --- a/quantus/metrics/faithfulness/road.py +++ b/quantus/metrics/faithfulness/road.py @@ -347,3 +347,41 @@ def custom_postprocess( percentage: np.mean(np.array(self.evaluation_scores)[:, p_ix]) for p_ix, percentage in enumerate(self.percentages) } + + def evaluate_batch( + self, + *args, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: np.ndarray, + **kwargs, + ) -> List[List[float]]: + """ + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. + + Parameters + ---------- + model: ModelInterface + A ModelInteface that is subject to explanation. + x_batch: np.ndarray + The input to be evaluated on a batch-basis. + y_batch: np.ndarray + The output to be evaluated on a batch-basis. + a_batch: np.ndarray + The explanation to be evaluated on a batch-basis. + args: + Unused. + kwargs: + Unused. + + Returns + ------- + list + The evaluation results. + """ + return [ + self.evaluate_instance(model=model, x=x, y=y, a=a) + for x, y, a in zip(x_batch, y_batch, a_batch) + ] diff --git a/quantus/metrics/robustness/avg_sensitivity.py b/quantus/metrics/robustness/avg_sensitivity.py index 0077cddb9..fba0f2231 100644 --- a/quantus/metrics/robustness/avg_sensitivity.py +++ b/quantus/metrics/robustness/avg_sensitivity.py @@ -139,8 +139,6 @@ def __init__( normalise=normalise, normalise_func=normalise_func, normalise_func_kwargs=normalise_func_kwargs, - perturb_func=perturb_func, - perturb_func_kwargs=perturb_func_kwargs, return_aggregate=return_aggregate, aggregate_func=aggregate_func, default_plot_func=default_plot_func, diff --git a/quantus/metrics/robustness/continuity.py b/quantus/metrics/robustness/continuity.py index 9ed59ebbe..253dfa522 100644 --- a/quantus/metrics/robustness/continuity.py +++ b/quantus/metrics/robustness/continuity.py @@ -469,3 +469,10 @@ def evaluate_batch( self.evaluate_instance(model=model, x=x, y=y) for x, y in zip(x_batch, y_batch) ] + + def custom_batch_preprocess(self, data_batch: Dict[str, ...]): + """Continuity requires `a_axes` property to be set before evaluation.""" + if self.a_axes is None: + self.a_axes = utils.infer_attribution_axes( + data_batch["a_batch"], data_batch["x_batch"] + ) diff --git a/quantus/metrics/robustness/local_lipschitz_estimate.py b/quantus/metrics/robustness/local_lipschitz_estimate.py index aa21c1024..adf8627e5 100644 --- a/quantus/metrics/robustness/local_lipschitz_estimate.py +++ b/quantus/metrics/robustness/local_lipschitz_estimate.py @@ -73,9 +73,9 @@ def __init__( nr_samples: int = 200, abs: bool = False, normalise: bool = True, - normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = gaussian_noise, + normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = None, + perturb_func: Callable = gaussian_noise, perturb_mean: float = 0.0, perturb_std: float = 0.1, perturb_func_kwargs: Optional[Dict[str, Any]] = None, diff --git a/quantus/metrics/robustness/max_sensitivity.py b/quantus/metrics/robustness/max_sensitivity.py index 61d7bab87..8bf354e36 100644 --- a/quantus/metrics/robustness/max_sensitivity.py +++ b/quantus/metrics/robustness/max_sensitivity.py @@ -138,8 +138,6 @@ def __init__( normalise=normalise, normalise_func=normalise_func, normalise_func_kwargs=normalise_func_kwargs, - perturb_func=perturb_func, - perturb_func_kwargs=perturb_func_kwargs, return_aggregate=return_aggregate, aggregate_func=aggregate_func, default_plot_func=default_plot_func, diff --git a/quantus/metrics/robustness/relative_input_stability.py b/quantus/metrics/robustness/relative_input_stability.py index 0e5576833..e0e82c842 100644 --- a/quantus/metrics/robustness/relative_input_stability.py +++ b/quantus/metrics/robustness/relative_input_stability.py @@ -123,8 +123,6 @@ def __init__( normalise=normalise, normalise_func=normalise_func, normalise_func_kwargs=normalise_func_kwargs, - perturb_func=perturb_func, - perturb_func_kwargs=perturb_func_kwargs, return_aggregate=return_aggregate, aggregate_func=aggregate_func, default_plot_func=default_plot_func, diff --git a/quantus/metrics/robustness/relative_output_stability.py b/quantus/metrics/robustness/relative_output_stability.py index b1e390806..fcf3d8eb1 100644 --- a/quantus/metrics/robustness/relative_output_stability.py +++ b/quantus/metrics/robustness/relative_output_stability.py @@ -125,8 +125,6 @@ def __init__( normalise=normalise, normalise_func=normalise_func, normalise_func_kwargs=normalise_func_kwargs, - perturb_func=perturb_func, - perturb_func_kwargs=perturb_func_kwargs, return_aggregate=return_aggregate, aggregate_func=aggregate_func, default_plot_func=default_plot_func, From e263a4eb8447da1977e5f9cc6026adda049a006b Mon Sep 17 00:00:00 2001 From: aaarrti Date: Wed, 11 Oct 2023 14:28:25 +0200 Subject: [PATCH 31/58] * test fixes --- pytest.ini | 3 + quantus/metrics/base.py | 111 ++++++++++-------- .../faithfulness/monotonicity_correlation.py | 2 - .../faithfulness/region_perturbation.py | 7 -- quantus/metrics/faithfulness/road.py | 10 +- quantus/metrics/faithfulness/sufficiency.py | 62 ++++------ .../localisation/attribution_localisation.py | 3 - quantus/metrics/robustness/continuity.py | 7 -- tests/conftest.py | 18 ++- tox.ini | 2 +- 10 files changed, 108 insertions(+), 117 deletions(-) diff --git a/pytest.ini b/pytest.ini index 00a92ec87..1a32be04c 100644 --- a/pytest.ini +++ b/pytest.ini @@ -22,3 +22,6 @@ filterwarnings = error ignore::UserWarning ignore::DeprecationWarning + +# Don't suppress logs. +addopts = -s -v diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index 76e9ca66d..2af892561 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -8,6 +8,7 @@ from __future__ import annotations import logging +import os import sys from abc import abstractmethod from typing import Any, Callable, ClassVar, Dict, Generator, Sequence, Set, TypeVar @@ -27,9 +28,9 @@ from quantus.helpers.model.model_interface import ModelInterface if sys.version_info >= (3, 8): - from typing import final, LiteralString + from typing import LiteralString, final else: - from typing_extensions import final, LiteralString + from typing_extensions import LiteralString, final if sys.version_info >= (3, 10): from typing import LiteralString @@ -45,12 +46,19 @@ class Metric: Interface defining Metrics' API. """ + # Class attributes. name: ClassVar[LiteralString] data_applicability: ClassVar[Set[DataType]] model_applicability: ClassVar[Set[ModelType]] score_direction: ClassVar[ScoreDirection] # can one metric fall into multiple categories? evaluation_category: ClassVar[EvaluationCategory] + # Instance attributes. + explain_func: Callable + explain_func_kwargs: Dict[str, ...] + a_axes: Sequence[int] + evaluation_scores: Any + all_evaluation_scores: Any def __init__( self, @@ -113,19 +121,17 @@ def __init__( self.return_aggregate = return_aggregate self.aggregate_func = aggregate_func self.normalise_func = normalise_func - - if normalise_func_kwargs is None: - normalise_func_kwargs = {} - self.normalise_func_kwargs = normalise_func_kwargs + self.normalise_func_kwargs = normalise_func_kwargs or {} self.default_plot_func = default_plot_func - self.disable_warnings = disable_warnings - self.display_progressbar = display_progressbar + # We need underscores here to avoid conflict with @property descriptor. + self._disable_warnings = disable_warnings + self._display_progressbar = display_progressbar - self.a_axes: Sequence[int] = None + self.a_axes = None - self.evaluation_scores: Any = [] - self.all_evaluation_scores: Any = [] + self.evaluation_scores = [] + self.all_evaluation_scores = [] def __call__( self, @@ -388,7 +394,6 @@ def general_preprocess( channel_first = utils.infer_channel_first(x_batch) x_batch = utils.make_channel_first(x_batch, channel_first) - # TODO: can model be None? if model is not None: # Use attribute value if not passed explicitly. model = utils.get_wrapped_model( @@ -401,9 +406,7 @@ def general_preprocess( # Save as attribute, some metrics need it during processing. self.explain_func = explain_func - if explain_func_kwargs is None: - explain_func_kwargs = {} - self.explain_func_kwargs = explain_func_kwargs + self.explain_func_kwargs = explain_func_kwargs or {} # Include device in explain_func_kwargs. if device is not None and "device" not in self.explain_func_kwargs: @@ -688,23 +691,16 @@ def generate_batches( n_batches = np.ceil(n_instances / batch_size) - # Create iterator for batch index. - iterator = tqdm( - total=n_batches, - disable=not self.display_progressbar, - ) - - # Iterate over batch index - for batch_idx in gen_batches(n_instances, batch_size): - # Calculate instance index for start and end of batch. - # Create batch dictionary with all specified batch instance values - batch = { - key: value[batch_idx.start : batch_idx.stop] - for key, value in batched_value_kwargs.items() - } - # Yield batch dictionary including single value keyword arguments. - yield {**batch, **single_value_kwargs} - iterator.update(min(batch_size, batch_idx.stop - batch_idx.start)) + with tqdm(total=n_batches, disable=not self.display_progressbar) as pbar: + for batch_idx in gen_batches(n_instances, batch_size): + batch = { + key: value[batch_idx.start : batch_idx.stop] + for key, value in batched_value_kwargs.items() + } + # Yield batch dictionary including single value keyword arguments. + yield {**batch, **single_value_kwargs} + # Update progressbar by number of samples in this batch. + pbar.update(batch_idx.stop - batch_idx.start) def plot( self, @@ -777,25 +773,13 @@ def get_params(self) -> Dict[str, Any]: ] return {k: v for k, v in self.__dict__.items() if k not in attr_exclude} - @property - def last_results(self): - log.warning( - "Warning: 'last_results' has been renamed to 'evaluation_scores'. " - "'last_results' is removed in current version." - ) - return self.evaluation_scores - - @property - def all_results(self): - log.warning( - "Warning: 'all_results' has been renamed to 'all_evaluation_scores'. " - "'all_results' is removed in current version." - ) - return self.all_evaluation_scores - @final def batch_preprocess(self, data_batch: Dict[str, ...]) -> Dict[str, ...]: - """If `data_batch` has no `a_batch`, will compute explanations. This needs to be done on batch level to avoid OOM.""" + """ + If `data_batch` has no `a_batch`, will compute explanations. + This needs to be done on batch level to avoid OOM. Additionally will set `a_axes` property if it is None, + this can be done earliest after we have first `a_batch`. + """ x_batch = data_batch["x_batch"] @@ -808,6 +792,9 @@ def batch_preprocess(self, data_batch: Dict[str, ...]) -> Dict[str, ...]: a_batch = self.explain_batch(model, x_batch, y_batch) data_batch["a_batch"] = a_batch + if self.a_axes is None: + self.a_axes = utils.infer_attribution_axes(a_batch, x_batch) + custom_batch = self.custom_batch_preprocess(data_batch) if custom_batch is not None: data_batch.update(custom_batch) @@ -884,3 +871,29 @@ def explain_batch( a_batch = np.abs(a_batch) return a_batch + + @property + def display_progressbar(self) -> bool: + """A helper to avoid polluting test outputs with tqdm progress bars.""" + return ( + self._display_progressbar + and + # Don't show progress bar in github actions. + os.environ.get("GITHUB_ACTIONS") != "true" + and + # Don't show progress bar when running unit tests. + "PYTEST" not in os.environ + ) + + @property + def disable_warnings(self) -> bool: + """A helper to avoid polluting test outputs with warnings.""" + return ( + not self._disable_warnings + and + # Don't show progress bar in github actions. + os.environ.get("GITHUB_ACTIONS") != "true" + and + # Don't show progress bar when running unit tests. + "PYTEST" not in os.environ + ) diff --git a/quantus/metrics/faithfulness/monotonicity_correlation.py b/quantus/metrics/faithfulness/monotonicity_correlation.py index d51b6a334..10bf40578 100644 --- a/quantus/metrics/faithfulness/monotonicity_correlation.py +++ b/quantus/metrics/faithfulness/monotonicity_correlation.py @@ -132,8 +132,6 @@ def __init__( normalise=normalise, normalise_func=normalise_func, normalise_func_kwargs=normalise_func_kwargs, - perturb_func=perturb_func, - perturb_func_kwargs=perturb_func_kwargs, return_aggregate=return_aggregate, aggregate_func=aggregate_func, default_plot_func=default_plot_func, diff --git a/quantus/metrics/faithfulness/region_perturbation.py b/quantus/metrics/faithfulness/region_perturbation.py index 9c8741aaf..edf0447da 100644 --- a/quantus/metrics/faithfulness/region_perturbation.py +++ b/quantus/metrics/faithfulness/region_perturbation.py @@ -455,10 +455,3 @@ def evaluate_batch( self.evaluate_instance(model=model, x=x, y=y, a=a) for x, y, a in zip(x_batch, y_batch, a_batch) ] - - def custom_batch_preprocess(self, data_batch: Dict[str, ...]): - """RegionPerturbation requires `a_axes` property to be set before evaluation.""" - if self.a_axes is None: - self.a_axes = utils.infer_attribution_axes( - data_batch["a_batch"], data_batch["x_batch"] - ) diff --git a/quantus/metrics/faithfulness/road.py b/quantus/metrics/faithfulness/road.py index 11b9a8a44..448998eb6 100644 --- a/quantus/metrics/faithfulness/road.py +++ b/quantus/metrics/faithfulness/road.py @@ -307,11 +307,6 @@ def evaluate_instance( # Return list of booleans for each percentage. return results_instance - def custom_batch_preprocess(self, data_batch: Dict[str, ...]) -> Dict[str, ...]: - """ROAD requires `a_size` property to be set to `image_height` * `image_width` of an explanation.""" - if self.a_size is None: - self.a_size = data_batch["a_batch"][0, :, :].size - def custom_postprocess( self, model: ModelInterface, @@ -348,6 +343,11 @@ def custom_postprocess( for p_ix, percentage in enumerate(self.percentages) } + def custom_batch_preprocess(self, data_batch: Dict[str, ...]) -> Dict[str, ...]: + """ROAD requires `a_size` property to be set to `image_height` * `image_width` of an explanation.""" + if self.a_size is None: + self.a_size = data_batch["a_batch"][0, :, :].size + def evaluate_batch( self, *args, diff --git a/quantus/metrics/faithfulness/sufficiency.py b/quantus/metrics/faithfulness/sufficiency.py index a250a83f1..d51a8754e 100644 --- a/quantus/metrics/faithfulness/sufficiency.py +++ b/quantus/metrics/faithfulness/sufficiency.py @@ -285,43 +285,6 @@ def evaluate_instance( return 0 return np.sum(pred_low_dist_a == pred_a) / len(low_dist_a) - def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: - """ - - - Parameters - ---------- - data_batch - - Returns - ------- - - """ - data_batch = super().batch_preprocess(data_batch) - model = data_batch["model"] - x_batch = data_batch["x_batch"] - a_batch = data_batch["a_batch"] - a_batch_flat = a_batch.reshape(a_batch.shape[0], -1) - dist_matrix = cdist(a_batch_flat, a_batch_flat, self.distance_func, V=None) - dist_matrix = self.normalise_func(dist_matrix) - a_sim_matrix = np.zeros_like(dist_matrix) - a_sim_matrix[dist_matrix <= self.threshold] = 1 - - # Predict on input. - x_input = model.shape_input( - x_batch, x_batch[0].shape, channel_first=True, batched=True - ) - y_pred_classes = np.argmax(model.predict(x_input), axis=1).flatten() - - custom_batch = { - "i_batch": np.arange(x_batch.shape[0]), - "a_sim_vector_batch": a_sim_matrix, - "y_pred_classes": y_pred_classes, - } - - data_batch.update(custom_batch) - return data_batch - def evaluate_batch( self, *args, i_batch, a_sim_vector_batch, y_pred_classes, **kwargs ) -> List[float]: @@ -355,3 +318,28 @@ def evaluate_batch( ) for i, a_sim_vector in zip(i_batch, a_sim_vector_batch) ] + + def custom_batch_preprocess( + self, data_batch: Dict[str, ...] + ) -> Dict[str, np.ndarray]: + model = data_batch["model"] + a_batch = data_batch["a_batch"] + x_batch = data_batch["x_batch"] + + a_batch_flat = a_batch.reshape(a_batch.shape[0], -1) + dist_matrix = cdist(a_batch_flat, a_batch_flat, self.distance_func, V=None) + dist_matrix = self.normalise_func(dist_matrix) + a_sim_matrix = np.zeros_like(dist_matrix) + a_sim_matrix[dist_matrix <= self.threshold] = 1 + + # Predict on input. + x_input = model.shape_input( + x_batch, x_batch[0].shape, channel_first=True, batched=True + ) + y_pred_classes = np.argmax(model.predict(x_input), axis=1).flatten() + + return { + "i_batch": np.arange(x_batch.shape[0]), + "a_sim_vector_batch": a_sim_matrix, + "y_pred_classes": y_pred_classes, + } diff --git a/quantus/metrics/localisation/attribution_localisation.py b/quantus/metrics/localisation/attribution_localisation.py index 4fa256688..46360f9b4 100644 --- a/quantus/metrics/localisation/attribution_localisation.py +++ b/quantus/metrics/localisation/attribution_localisation.py @@ -122,9 +122,6 @@ def __init__( # Save metric-specific attributes. self.weighted = weighted self.max_size = max_size - - # Asserts and warnings. - self.disable_warnings = disable_warnings if not self.disable_warnings: warn.warn_parameterisation( metric_name=self.__class__.__name__, diff --git a/quantus/metrics/robustness/continuity.py b/quantus/metrics/robustness/continuity.py index 253dfa522..9ed59ebbe 100644 --- a/quantus/metrics/robustness/continuity.py +++ b/quantus/metrics/robustness/continuity.py @@ -469,10 +469,3 @@ def evaluate_batch( self.evaluate_instance(model=model, x=x, y=y) for x, y in zip(x_batch, y_batch) ] - - def custom_batch_preprocess(self, data_batch: Dict[str, ...]): - """Continuity requires `a_axes` property to be set before evaluation.""" - if self.a_axes is None: - self.a_axes = utils.infer_attribution_axes( - data_batch["a_batch"], data_batch["x_batch"] - ) diff --git a/tests/conftest.py b/tests/conftest.py index 3960fe74e..c81be0121 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,10 +2,10 @@ import pickle import torch import numpy as np -import tensorflow -from tensorflow.keras.datasets import cifar10 +from keras.datasets import cifar10 import pandas as pd from sklearn.model_selection import train_test_split +import os from quantus.helpers.model.models import ( LeNet, @@ -211,10 +211,10 @@ def titanic_dataset(): @pytest.fixture(scope="session", autouse=True) def load_mnist_model_softmax_not_last(): - ''' + """ Model with a softmax layer not last in the list of modules. Used to test the logic of pytorch_model.py, method get_softmax_arg_model (see the method's documentation). - ''' + """ model = torch.nn.Sequential( torch.nn.Flatten(), torch.nn.Softmax(), @@ -225,12 +225,18 @@ def load_mnist_model_softmax_not_last(): @pytest.fixture(scope="session", autouse=True) def load_mnist_model_softmax(): - ''' + """ Model with a softmax layer last in the list of modules. Used to test the logic of pytorch_model.py, method get_softmax_arg_model (see the method's documentation). - ''' + """ model = torch.nn.Sequential( LeNet(), torch.nn.Softmax(), ) return model + + +@pytest.fixture(scope="session", autouse=True) +def set_env(): + """Set ENV var, so test outputs are not polluted by progress bars and warnings.""" + os.environ["PYTEST"] = "1" diff --git a/tox.ini b/tox.ini index 745a8d6f9..c57927bd5 100644 --- a/tox.ini +++ b/tox.ini @@ -13,7 +13,7 @@ deps = pass_env = TF_XLA_FLAGS commands = - pytest -s -v {posargs} + pytest {posargs} [testenv:coverage] description = Run the tests with coverage From ccaab33c5e02f4750680d5c628f011a7ec489757 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Wed, 11 Oct 2023 14:57:38 +0200 Subject: [PATCH 32/58] * test fixes --- quantus/__init__.py | 2 +- quantus/evaluation.py | 12 +++-- quantus/metrics/axiomatic/completeness.py | 2 +- quantus/metrics/axiomatic/input_invariance.py | 17 +----- quantus/metrics/axiomatic/non_sensitivity.py | 2 +- quantus/metrics/base.py | 28 +++++++--- quantus/metrics/base_batched.py | 4 +- quantus/metrics/complexity/complexity.py | 2 +- .../complexity/effective_complexity.py | 3 +- quantus/metrics/complexity/sparseness.py | 3 +- .../faithfulness/faithfulness_correlation.py | 2 +- .../faithfulness/faithfulness_estimate.py | 2 +- quantus/metrics/faithfulness/infidelity.py | 2 +- quantus/metrics/faithfulness/irof.py | 2 +- quantus/metrics/faithfulness/monotonicity.py | 2 +- .../faithfulness/monotonicity_correlation.py | 2 +- .../metrics/faithfulness/pixel_flipping.py | 5 +- .../faithfulness/region_perturbation.py | 2 +- quantus/metrics/faithfulness/road.py | 12 ++--- quantus/metrics/faithfulness/selectivity.py | 2 +- quantus/metrics/faithfulness/sensitivity_n.py | 2 +- quantus/metrics/faithfulness/sufficiency.py | 53 ++++++++++--------- .../localisation/attribution_localisation.py | 2 +- quantus/metrics/localisation/auc.py | 2 +- quantus/metrics/robustness/consistency.py | 22 ++------ quantus/metrics/robustness/max_sensitivity.py | 4 +- 26 files changed, 94 insertions(+), 99 deletions(-) diff --git a/quantus/__init__.py b/quantus/__init__.py index 485b9d851..c29dc0a35 100644 --- a/quantus/__init__.py +++ b/quantus/__init__.py @@ -26,4 +26,4 @@ from quantus.helpers.model import * # Expose the helpers utils. -from quantus.helpers.utils import * +from quantus.helpers.utils import * \ No newline at end of file diff --git a/quantus/evaluation.py b/quantus/evaluation.py index 8c37a24a7..75d628ee3 100644 --- a/quantus/evaluation.py +++ b/quantus/evaluation.py @@ -81,7 +81,7 @@ def evaluate( return None if call_kwargs is None: - call_kwargs = {"call_kwargs_empty": {}} + call_kwargs = {'call_kwargs_empty': {}} elif not isinstance(call_kwargs, Dict): raise TypeError("xai_methods type is not Dict[str, Dict].") @@ -92,9 +92,11 @@ def evaluate( "xai_methods type is not in: Dict[str, Callable], Dict[str, Dict], Dict[str, np.ndarray]." for method, value in xai_methods.items(): + results[method] = {} if callable(value): + explain_funcs[method] = value explain_func = value @@ -114,6 +116,7 @@ def evaluate( asserts.assert_attributions(a_batch=a_batch, x_batch=x_batch) elif isinstance(value, Dict): + if explain_func_kwargs is not None: warnings.warn( "Passed explain_func_kwargs will be ignored when passing type Dict[str, Dict] as xai_methods." @@ -137,6 +140,7 @@ def evaluate( a_batch = value else: + raise TypeError( "xai_methods type is not in: Dict[str, Callable], Dict[str, Dict], Dict[str, np.ndarray]." ) @@ -144,10 +148,12 @@ def evaluate( if explain_func_kwargs is None: explain_func_kwargs = {} - for metric, metric_func in metrics.items(): + for (metric, metric_func) in metrics.items(): + results[method][metric] = {} - for call_kwarg_str, call_kwarg in call_kwargs.items(): + for (call_kwarg_str, call_kwarg) in call_kwargs.items(): + if progress: print( f"Evaluating {method} explanations on {metric} metric on set of call parameters {call_kwarg_str}..." diff --git a/quantus/metrics/axiomatic/completeness.py b/quantus/metrics/axiomatic/completeness.py index 0147c48e5..31c9becf5 100644 --- a/quantus/metrics/axiomatic/completeness.py +++ b/quantus/metrics/axiomatic/completeness.py @@ -31,7 +31,7 @@ @final -class Completeness(Metric): +class Completeness(Metric[List[float]]): """ Implementation of Completeness test by Sundararajan et al., 2017, also referred to as Summation to Delta by Shrikumar et al., 2017 and Conservation by diff --git a/quantus/metrics/axiomatic/input_invariance.py b/quantus/metrics/axiomatic/input_invariance.py index 819e50de4..19ac2b579 100644 --- a/quantus/metrics/axiomatic/input_invariance.py +++ b/quantus/metrics/axiomatic/input_invariance.py @@ -31,7 +31,7 @@ @final -class InputInvariance(Metric): +class InputInvariance(Metric[List[float]]): """ Implementation of Completeness test by Kindermans et al., 2017. @@ -312,18 +312,5 @@ def evaluate_batch( return score def custom_preprocess(self, *args, **kwargs) -> None: - """ - Additional explain_func assert, as the one in prepare() won't be executed when a_batch != None. - - Parameters - ---------- - args: - Unused. - kwargs: - Unused. - - Returns - ------- - None - """ + """Additional explain_func assert, as the one in prepare() won't be executed when `a_batch != None.`""" asserts.assert_explain_func(explain_func=self.explain_func) diff --git a/quantus/metrics/axiomatic/non_sensitivity.py b/quantus/metrics/axiomatic/non_sensitivity.py index b8f79ccb3..ed5918258 100644 --- a/quantus/metrics/axiomatic/non_sensitivity.py +++ b/quantus/metrics/axiomatic/non_sensitivity.py @@ -31,7 +31,7 @@ @final -class NonSensitivity(Metric): +class NonSensitivity(Metric[List[float]]): """ Implementation of NonSensitivity by Nguyen at el., 2020. diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index 64a5560bc..6ce5deac6 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -11,7 +11,17 @@ import os import sys from abc import abstractmethod -from typing import Any, Callable, ClassVar, Dict, Generator, Sequence, Set, TypeVar +from typing import ( + Any, + Callable, + ClassVar, + Dict, + Generator, + Sequence, + Set, + TypeVar, + Generic, +) import matplotlib.pyplot as plt import numpy as np @@ -32,16 +42,18 @@ else: from typing_extensions import LiteralString, final -if sys.version_info >= (3, 10): +if sys.version_info >= (3, 11): from typing import LiteralString else: from typing_extensions import LiteralString D = TypeVar("D", bound=Dict[str, Any]) +# Return value of __call__ +R = TypeVar("R") log = logging.getLogger(__name__) -class Metric: +class Metric(Generic[R]): """ Interface defining Metrics' API. """ @@ -149,7 +161,7 @@ def __call__( batch_size: int = 64, custom_batch: Any = None, **kwargs, - ): + ) -> R: """ This implementation represents the main logic of the metric and makes the class object callable. It completes batch-wise evaluation of explanations (a_batch) with respect to input data (x_batch), @@ -293,7 +305,7 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - s_batch: np.ndarray, + s_batch: np.ndarray | None, **kwargs, ): """ @@ -457,8 +469,8 @@ def custom_preprocess( x_batch: np.ndarray, y_batch: np.ndarray | None, a_batch: np.ndarray | None, - s_batch: np.ndarray, - custom_batch: np.ndarray | None, + s_batch: np.ndarray | None, + custom_batch: Any, ) -> Dict[str, ...] | None: """ Implement this method if you need custom preprocessing of data, @@ -600,7 +612,7 @@ def custom_postprocess( x_batch: np.ndarray, y_batch: np.ndarray | None, a_batch: np.ndarray | None, - s_batch: np.ndarray, + s_batch: np.ndarray | None, **kwargs, ): """ diff --git a/quantus/metrics/base_batched.py b/quantus/metrics/base_batched.py index a4aca9634..0aaa24f75 100644 --- a/quantus/metrics/base_batched.py +++ b/quantus/metrics/base_batched.py @@ -5,9 +5,9 @@ # Quantus project URL: . import abc -import warnings from quantus.metrics.base import Metric +import logging """Aliases to smoothen transition to uniform metric API.""" @@ -17,7 +17,7 @@ class BatchedMetric(Metric, abc.ABC): """Alias to quantus.Metric, will be removed in next major release.""" def __new__(cls, *args, **kwargs): - warnings.warn( + logging.warning( "BatchedMetric was deprecated, since it is just an alias to Metric. Please subclass Metric directly." ) super().__new__(*args, **kwargs) diff --git a/quantus/metrics/complexity/complexity.py b/quantus/metrics/complexity/complexity.py index b48a9755a..841e2027e 100644 --- a/quantus/metrics/complexity/complexity.py +++ b/quantus/metrics/complexity/complexity.py @@ -28,7 +28,7 @@ @final -class Complexity(Metric): +class Complexity(Metric[List[float]]): """ Implementation of Complexity metric by Bhatt et al., 2020. diff --git a/quantus/metrics/complexity/effective_complexity.py b/quantus/metrics/complexity/effective_complexity.py index 8fba4a71f..be3378a3f 100644 --- a/quantus/metrics/complexity/effective_complexity.py +++ b/quantus/metrics/complexity/effective_complexity.py @@ -26,7 +26,8 @@ from typing_extensions import final -class EffectiveComplexity(Metric): +@final +class EffectiveComplexity(Metric[List[float]]): """ Implementation of Effective complexity metric by Nguyen at el., 2020. diff --git a/quantus/metrics/complexity/sparseness.py b/quantus/metrics/complexity/sparseness.py index 2107b1604..f63d33699 100644 --- a/quantus/metrics/complexity/sparseness.py +++ b/quantus/metrics/complexity/sparseness.py @@ -26,7 +26,8 @@ from typing_extensions import final -class Sparseness(Metric): +@final +class Sparseness(Metric[List[float]]): """ Implementation of Sparseness metric by Chalasani et al., 2020. diff --git a/quantus/metrics/faithfulness/faithfulness_correlation.py b/quantus/metrics/faithfulness/faithfulness_correlation.py index 5fa79a349..0bc3dd1c2 100644 --- a/quantus/metrics/faithfulness/faithfulness_correlation.py +++ b/quantus/metrics/faithfulness/faithfulness_correlation.py @@ -31,7 +31,7 @@ @final -class FaithfulnessCorrelation(Metric): +class FaithfulnessCorrelation(Metric[List[float]]): """ Implementation of faithfulness correlation by Bhatt et al., 2020. diff --git a/quantus/metrics/faithfulness/faithfulness_estimate.py b/quantus/metrics/faithfulness/faithfulness_estimate.py index f2b527d6e..2dd317f81 100644 --- a/quantus/metrics/faithfulness/faithfulness_estimate.py +++ b/quantus/metrics/faithfulness/faithfulness_estimate.py @@ -31,7 +31,7 @@ @final -class FaithfulnessEstimate(Metric): +class FaithfulnessEstimate(Metric[List[float]]): """ Implementation of Faithfulness Estimate by Alvares-Melis at el., 2018a and 2018b. diff --git a/quantus/metrics/faithfulness/infidelity.py b/quantus/metrics/faithfulness/infidelity.py index e080843d4..660778f21 100644 --- a/quantus/metrics/faithfulness/infidelity.py +++ b/quantus/metrics/faithfulness/infidelity.py @@ -31,7 +31,7 @@ @final -class Infidelity(Metric): +class Infidelity(Metric[List[float]]): """ Implementation of Infidelity by Yeh et al., 2019. diff --git a/quantus/metrics/faithfulness/irof.py b/quantus/metrics/faithfulness/irof.py index 4928875c6..4ebdee2de 100644 --- a/quantus/metrics/faithfulness/irof.py +++ b/quantus/metrics/faithfulness/irof.py @@ -30,7 +30,7 @@ @final -class IROF(Metric): +class IROF(Metric[List[float]]): """ Implementation of IROF (Iterative Removal of Features) by Rieger at el., 2020. diff --git a/quantus/metrics/faithfulness/monotonicity.py b/quantus/metrics/faithfulness/monotonicity.py index 43623b291..dc4d61a7a 100644 --- a/quantus/metrics/faithfulness/monotonicity.py +++ b/quantus/metrics/faithfulness/monotonicity.py @@ -30,7 +30,7 @@ @final -class Monotonicity(Metric): +class Monotonicity(Metric[List[float]]): """ Implementation of Monotonicity metric by Arya at el., 2019. diff --git a/quantus/metrics/faithfulness/monotonicity_correlation.py b/quantus/metrics/faithfulness/monotonicity_correlation.py index 10bf40578..69dd9bf13 100644 --- a/quantus/metrics/faithfulness/monotonicity_correlation.py +++ b/quantus/metrics/faithfulness/monotonicity_correlation.py @@ -32,7 +32,7 @@ @final -class MonotonicityCorrelation(Metric): +class MonotonicityCorrelation(Metric[List[float]]): """ Implementation of Monotonicity Correlation metric by Nguyen at el., 2020. diff --git a/quantus/metrics/faithfulness/pixel_flipping.py b/quantus/metrics/faithfulness/pixel_flipping.py index b3f20d3d2..04eaf4ff2 100644 --- a/quantus/metrics/faithfulness/pixel_flipping.py +++ b/quantus/metrics/faithfulness/pixel_flipping.py @@ -5,9 +5,6 @@ # Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . - -from __future__ import annotations - import sys from typing import Any, Callable, Dict, List, Optional @@ -33,7 +30,7 @@ @final -class PixelFlipping(Metric): +class PixelFlipping(Metric[List[float]]): """ Implementation of Pixel-Flipping experiment by Bach et al., 2015. diff --git a/quantus/metrics/faithfulness/region_perturbation.py b/quantus/metrics/faithfulness/region_perturbation.py index edf0447da..e69d977f4 100644 --- a/quantus/metrics/faithfulness/region_perturbation.py +++ b/quantus/metrics/faithfulness/region_perturbation.py @@ -32,7 +32,7 @@ @final -class RegionPerturbation(Metric): +class RegionPerturbation(Metric[List[float]]): """ Implementation of Region Perturbation by Samek et al., 2015. diff --git a/quantus/metrics/faithfulness/road.py b/quantus/metrics/faithfulness/road.py index 448998eb6..b0077e190 100644 --- a/quantus/metrics/faithfulness/road.py +++ b/quantus/metrics/faithfulness/road.py @@ -31,7 +31,7 @@ @final -class ROAD(Metric): +class ROAD(Metric[List[float]]): """ Implementation of ROAD evaluation strategy by Rong et al., 2022. @@ -307,6 +307,11 @@ def evaluate_instance( # Return list of booleans for each percentage. return results_instance + def custom_batch_preprocess(self, data_batch: Dict[str, ...]) -> Dict[str, ...]: + """ROAD requires `a_size` property to be set to `image_height` * `image_width` of an explanation.""" + if self.a_size is None: + self.a_size = data_batch["a_batch"][0, :, :].size + def custom_postprocess( self, model: ModelInterface, @@ -343,11 +348,6 @@ def custom_postprocess( for p_ix, percentage in enumerate(self.percentages) } - def custom_batch_preprocess(self, data_batch: Dict[str, ...]) -> Dict[str, ...]: - """ROAD requires `a_size` property to be set to `image_height` * `image_width` of an explanation.""" - if self.a_size is None: - self.a_size = data_batch["a_batch"][0, :, :].size - def evaluate_batch( self, *args, diff --git a/quantus/metrics/faithfulness/selectivity.py b/quantus/metrics/faithfulness/selectivity.py index 150422cfc..99077518e 100644 --- a/quantus/metrics/faithfulness/selectivity.py +++ b/quantus/metrics/faithfulness/selectivity.py @@ -32,7 +32,7 @@ @final -class Selectivity(Metric): +class Selectivity(Metric[List[float]]): """ Implementation of Selectivity test by Montavon et al., 2018. diff --git a/quantus/metrics/faithfulness/sensitivity_n.py b/quantus/metrics/faithfulness/sensitivity_n.py index 0fabb5798..fbfcfe1f5 100644 --- a/quantus/metrics/faithfulness/sensitivity_n.py +++ b/quantus/metrics/faithfulness/sensitivity_n.py @@ -32,7 +32,7 @@ @final -class SensitivityN(Metric): +class SensitivityN(Metric[List[float]]): """ Implementation of Sensitivity-N test by Ancona et al., 2019. diff --git a/quantus/metrics/faithfulness/sufficiency.py b/quantus/metrics/faithfulness/sufficiency.py index d51a8754e..badcea75b 100644 --- a/quantus/metrics/faithfulness/sufficiency.py +++ b/quantus/metrics/faithfulness/sufficiency.py @@ -29,7 +29,7 @@ @final -class Sufficiency(Metric): +class Sufficiency(Metric[List[float]]): """ Implementation of Sufficiency test by Dasgupta et al., 2022. @@ -285,6 +285,32 @@ def evaluate_instance( return 0 return np.sum(pred_low_dist_a == pred_a) / len(low_dist_a) + def custom_batch_preprocess( + self, data_batch: Dict[str, ...] + ) -> Dict[str, np.ndarray]: + """Compute additional arguments required for Sufficiency evaluation on batch-level.""" + model = data_batch["model"] + a_batch = data_batch["a_batch"] + x_batch = data_batch["x_batch"] + + a_batch_flat = a_batch.reshape(a_batch.shape[0], -1) + dist_matrix = cdist(a_batch_flat, a_batch_flat, self.distance_func, V=None) + dist_matrix = self.normalise_func(dist_matrix) + a_sim_matrix = np.zeros_like(dist_matrix) + a_sim_matrix[dist_matrix <= self.threshold] = 1 + + # Predict on input. + x_input = model.shape_input( + x_batch, x_batch[0].shape, channel_first=True, batched=True + ) + y_pred_classes = np.argmax(model.predict(x_input), axis=1).flatten() + + return { + "i_batch": np.arange(x_batch.shape[0]), + "a_sim_vector_batch": a_sim_matrix, + "y_pred_classes": y_pred_classes, + } + def evaluate_batch( self, *args, i_batch, a_sim_vector_batch, y_pred_classes, **kwargs ) -> List[float]: @@ -318,28 +344,3 @@ def evaluate_batch( ) for i, a_sim_vector in zip(i_batch, a_sim_vector_batch) ] - - def custom_batch_preprocess( - self, data_batch: Dict[str, ...] - ) -> Dict[str, np.ndarray]: - model = data_batch["model"] - a_batch = data_batch["a_batch"] - x_batch = data_batch["x_batch"] - - a_batch_flat = a_batch.reshape(a_batch.shape[0], -1) - dist_matrix = cdist(a_batch_flat, a_batch_flat, self.distance_func, V=None) - dist_matrix = self.normalise_func(dist_matrix) - a_sim_matrix = np.zeros_like(dist_matrix) - a_sim_matrix[dist_matrix <= self.threshold] = 1 - - # Predict on input. - x_input = model.shape_input( - x_batch, x_batch[0].shape, channel_first=True, batched=True - ) - y_pred_classes = np.argmax(model.predict(x_input), axis=1).flatten() - - return { - "i_batch": np.arange(x_batch.shape[0]), - "a_sim_vector_batch": a_sim_matrix, - "y_pred_classes": y_pred_classes, - } diff --git a/quantus/metrics/localisation/attribution_localisation.py b/quantus/metrics/localisation/attribution_localisation.py index 46360f9b4..61f650dac 100644 --- a/quantus/metrics/localisation/attribution_localisation.py +++ b/quantus/metrics/localisation/attribution_localisation.py @@ -29,7 +29,7 @@ @final -class AttributionLocalisation(Metric): +class AttributionLocalisation(Metric[List[float]]): """ Implementation of the Attribution Localization by Kohlbrenner et al., 2020. diff --git a/quantus/metrics/localisation/auc.py b/quantus/metrics/localisation/auc.py index 99875cc06..5901f1aed 100644 --- a/quantus/metrics/localisation/auc.py +++ b/quantus/metrics/localisation/auc.py @@ -30,7 +30,7 @@ @final -class AUC(Metric): +class AUC(Metric[List[float]]): """ Implementation of AUC metric by Fawcett et al., 2006. diff --git a/quantus/metrics/robustness/consistency.py b/quantus/metrics/robustness/consistency.py index 114711460..f4d4aaa4f 100644 --- a/quantus/metrics/robustness/consistency.py +++ b/quantus/metrics/robustness/consistency.py @@ -276,19 +276,10 @@ def evaluate_instance( return 0 return np.sum(pred_same_a == pred_a) / len(diff_a) - def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: - """ - - Parameters - ---------- - data_batch - - Returns - ------- - - """ - data_batch = super().batch_preprocess(data_batch) - + def custom_batch_preprocess( + self, data_batch: Dict[str, ...] + ) -> Dict[str, ...] | None: + """Compute additional arguments required for Consistency on batch-level.""" model = data_batch["model"] x_batch = data_batch["x_batch"] x_input = model.shape_input( @@ -302,15 +293,12 @@ def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: y_pred_classes = np.argmax(model.predict(x_input), axis=1).flatten() - custom_batch = { + return { "i_batch": np.arange(x_batch.shape[0]), "a_label_batch": a_labels, "y_pred_classes": y_pred_classes, } - data_batch.update(custom_batch) - return data_batch - @no_type_check def evaluate_batch( self, diff --git a/quantus/metrics/robustness/max_sensitivity.py b/quantus/metrics/robustness/max_sensitivity.py index 8bf354e36..7d7311045 100644 --- a/quantus/metrics/robustness/max_sensitivity.py +++ b/quantus/metrics/robustness/max_sensitivity.py @@ -303,7 +303,7 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - **_, + **kwargs, ) -> np.ndarray: """ Evaluates model and attributes on a single data batch and returns the batched evaluation result. @@ -318,6 +318,8 @@ def evaluate_batch( The output to be evaluated on an instance-basis. a_batch: np.ndarray The explanation to be evaluated on an instance-basis. + kwargs: + Unused. Returns ------- From c2f368ddb863580b7ec74ed6affbe9c5d972375c Mon Sep 17 00:00:00 2001 From: aaarrti Date: Wed, 11 Oct 2023 15:04:46 +0200 Subject: [PATCH 33/58] * cleanup --- quantus/metrics/base.py | 3 ++- quantus/metrics/localisation/focus.py | 2 +- quantus/metrics/localisation/pointing_game.py | 2 +- quantus/metrics/localisation/relevance_mass_accuracy.py | 2 +- quantus/metrics/localisation/relevance_rank_accuracy.py | 2 +- quantus/metrics/localisation/top_k_intersection.py | 2 +- quantus/metrics/randomisation/random_logit.py | 2 +- quantus/metrics/robustness/avg_sensitivity.py | 2 +- quantus/metrics/robustness/consistency.py | 2 +- quantus/metrics/robustness/continuity.py | 2 +- quantus/metrics/robustness/local_lipschitz_estimate.py | 2 +- quantus/metrics/robustness/max_sensitivity.py | 2 +- quantus/metrics/robustness/relative_input_stability.py | 2 +- quantus/metrics/robustness/relative_output_stability.py | 2 +- .../metrics/robustness/relative_representation_stability.py | 2 +- 15 files changed, 16 insertions(+), 15 deletions(-) diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index 6ce5deac6..259eff944 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -48,9 +48,10 @@ from typing_extensions import LiteralString D = TypeVar("D", bound=Dict[str, Any]) +log = logging.getLogger(__name__) + # Return value of __call__ R = TypeVar("R") -log = logging.getLogger(__name__) class Metric(Generic[R]): diff --git a/quantus/metrics/localisation/focus.py b/quantus/metrics/localisation/focus.py index 62abbe356..974ae1af5 100644 --- a/quantus/metrics/localisation/focus.py +++ b/quantus/metrics/localisation/focus.py @@ -29,7 +29,7 @@ @final -class Focus(Metric): +class Focus(Metric[List[float]]): """ Implementation of Focus evaluation strategy by Arias et. al. 2022 diff --git a/quantus/metrics/localisation/pointing_game.py b/quantus/metrics/localisation/pointing_game.py index 71c437757..eb6f6b90b 100644 --- a/quantus/metrics/localisation/pointing_game.py +++ b/quantus/metrics/localisation/pointing_game.py @@ -29,7 +29,7 @@ @final -class PointingGame(Metric): +class PointingGame(Metric[List[float]]): """ Implementation of the Pointing Game by Zhang et al., 2018. diff --git a/quantus/metrics/localisation/relevance_mass_accuracy.py b/quantus/metrics/localisation/relevance_mass_accuracy.py index 86d475051..a367423bd 100644 --- a/quantus/metrics/localisation/relevance_mass_accuracy.py +++ b/quantus/metrics/localisation/relevance_mass_accuracy.py @@ -29,7 +29,7 @@ @final -class RelevanceMassAccuracy(Metric): +class RelevanceMassAccuracy(Metric[List[float]]): """ Implementation of the Relevance Mass Accuracy by Arras et al., 2021. diff --git a/quantus/metrics/localisation/relevance_rank_accuracy.py b/quantus/metrics/localisation/relevance_rank_accuracy.py index f05e2b6bc..41ab75037 100644 --- a/quantus/metrics/localisation/relevance_rank_accuracy.py +++ b/quantus/metrics/localisation/relevance_rank_accuracy.py @@ -29,7 +29,7 @@ @final -class RelevanceRankAccuracy(Metric): +class RelevanceRankAccuracy(Metric[List[float]]): """ Implementation of the Relevance Rank Accuracy by Arras et al., 2021. diff --git a/quantus/metrics/localisation/top_k_intersection.py b/quantus/metrics/localisation/top_k_intersection.py index 1a94acb8f..bd7dc4612 100644 --- a/quantus/metrics/localisation/top_k_intersection.py +++ b/quantus/metrics/localisation/top_k_intersection.py @@ -29,7 +29,7 @@ @final -class TopKIntersection(Metric): +class TopKIntersection(Metric[List[float]]): """ Implementation of the top-k intersection by Theiner et al., 2021. diff --git a/quantus/metrics/randomisation/random_logit.py b/quantus/metrics/randomisation/random_logit.py index ee174addb..ab22ea022 100644 --- a/quantus/metrics/randomisation/random_logit.py +++ b/quantus/metrics/randomisation/random_logit.py @@ -30,7 +30,7 @@ @final -class RandomLogit(Metric): +class RandomLogit(Metric[List[float]]): """ Implementation of the Random Logit Metric by Sixt et al., 2020. diff --git a/quantus/metrics/robustness/avg_sensitivity.py b/quantus/metrics/robustness/avg_sensitivity.py index fba0f2231..48ab9d79b 100644 --- a/quantus/metrics/robustness/avg_sensitivity.py +++ b/quantus/metrics/robustness/avg_sensitivity.py @@ -37,7 +37,7 @@ @final -class AvgSensitivity(Metric): +class AvgSensitivity(Metric[List[float]]): """ Implementation of Avg-Sensitivity by Yeh at el., 2019. diff --git a/quantus/metrics/robustness/consistency.py b/quantus/metrics/robustness/consistency.py index f4d4aaa4f..0766e91b7 100644 --- a/quantus/metrics/robustness/consistency.py +++ b/quantus/metrics/robustness/consistency.py @@ -29,7 +29,7 @@ @final -class Consistency(Metric): +class Consistency(Metric[List[float]]): """ Implementation of the Consistency metric which measures the expected local consistency, i.e., the probability of the prediction label for a given datapoint coinciding with the prediction labels of other data points that diff --git a/quantus/metrics/robustness/continuity.py b/quantus/metrics/robustness/continuity.py index 9ed59ebbe..10d8af9e2 100644 --- a/quantus/metrics/robustness/continuity.py +++ b/quantus/metrics/robustness/continuity.py @@ -32,7 +32,7 @@ @final -class Continuity(Metric): +class Continuity(Metric[List[float]]): """ Implementation of the Continuity test by Montavon et al., 2018. diff --git a/quantus/metrics/robustness/local_lipschitz_estimate.py b/quantus/metrics/robustness/local_lipschitz_estimate.py index adf8627e5..17e24bbb4 100644 --- a/quantus/metrics/robustness/local_lipschitz_estimate.py +++ b/quantus/metrics/robustness/local_lipschitz_estimate.py @@ -35,7 +35,7 @@ @final -class LocalLipschitzEstimate(Metric): +class LocalLipschitzEstimate(Metric[List[float]]): """ Implementation of the Local Lipschitz Estimate (or Stability) test by Alvarez-Melis et al., 2018a, 2018b. diff --git a/quantus/metrics/robustness/max_sensitivity.py b/quantus/metrics/robustness/max_sensitivity.py index 7d7311045..0b2514e53 100644 --- a/quantus/metrics/robustness/max_sensitivity.py +++ b/quantus/metrics/robustness/max_sensitivity.py @@ -36,7 +36,7 @@ @final -class MaxSensitivity(Metric): +class MaxSensitivity(Metric[List[float]]): """ Implementation of Max-Sensitivity by Yeh at el., 2019. diff --git a/quantus/metrics/robustness/relative_input_stability.py b/quantus/metrics/robustness/relative_input_stability.py index e0e82c842..533c3545d 100644 --- a/quantus/metrics/robustness/relative_input_stability.py +++ b/quantus/metrics/robustness/relative_input_stability.py @@ -39,7 +39,7 @@ @final -class RelativeInputStability(Metric): +class RelativeInputStability(Metric[List[float]]): """ Relative Input Stability leverages the stability of an explanation with respect to the change in the input data. diff --git a/quantus/metrics/robustness/relative_output_stability.py b/quantus/metrics/robustness/relative_output_stability.py index fcf3d8eb1..784edfdbc 100644 --- a/quantus/metrics/robustness/relative_output_stability.py +++ b/quantus/metrics/robustness/relative_output_stability.py @@ -39,7 +39,7 @@ @final -class RelativeOutputStability(Metric): +class RelativeOutputStability(Metric[List[float]]): """ Relative Output Stability leverages the stability of an explanation with respect to the change in the output logits. diff --git a/quantus/metrics/robustness/relative_representation_stability.py b/quantus/metrics/robustness/relative_representation_stability.py index 40d1db44a..ff49aa2ed 100644 --- a/quantus/metrics/robustness/relative_representation_stability.py +++ b/quantus/metrics/robustness/relative_representation_stability.py @@ -39,7 +39,7 @@ @final -class RelativeRepresentationStability(Metric): +class RelativeRepresentationStability(Metric[List[float]]): """ Relative Representation Stability leverages the stability of an explanation with respect to the change in the output logits. From 367ad25aa649baf12d76f02fbb1a6d0d3b8e6521 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Wed, 11 Oct 2023 15:13:52 +0200 Subject: [PATCH 34/58] * test fix --- quantus/metrics/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index 259eff944..2199a01c1 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -38,9 +38,9 @@ from quantus.helpers.model.model_interface import ModelInterface if sys.version_info >= (3, 8): - from typing import LiteralString, final + from typing import final else: - from typing_extensions import LiteralString, final + from typing_extensions import final if sys.version_info >= (3, 11): from typing import LiteralString From 22c3b0be316ce8c15ef377010cf5f56067cba969 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Wed, 11 Oct 2023 15:22:03 +0200 Subject: [PATCH 35/58] * --- quantus/metrics/faithfulness/road.py | 2 +- quantus/metrics/faithfulness/sufficiency.py | 2 +- quantus/metrics/robustness/consistency.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/quantus/metrics/faithfulness/road.py b/quantus/metrics/faithfulness/road.py index b0077e190..118ef2814 100644 --- a/quantus/metrics/faithfulness/road.py +++ b/quantus/metrics/faithfulness/road.py @@ -5,7 +5,7 @@ # Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . - +from __future__ import annotations import sys from typing import Any, Callable, Dict, List, Optional diff --git a/quantus/metrics/faithfulness/sufficiency.py b/quantus/metrics/faithfulness/sufficiency.py index badcea75b..62862b100 100644 --- a/quantus/metrics/faithfulness/sufficiency.py +++ b/quantus/metrics/faithfulness/sufficiency.py @@ -5,7 +5,7 @@ # Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . - +from __future__ import annotations import sys from typing import Any, Callable, Dict, List, Optional, no_type_check diff --git a/quantus/metrics/robustness/consistency.py b/quantus/metrics/robustness/consistency.py index 0766e91b7..08836913f 100644 --- a/quantus/metrics/robustness/consistency.py +++ b/quantus/metrics/robustness/consistency.py @@ -5,7 +5,7 @@ # Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . - +from __future__ import annotations import sys from typing import Any, Callable, Dict, List, Optional, no_type_check From 91c6de54f56edf3b55aa036db1b49a8cb5692deb Mon Sep 17 00:00:00 2001 From: aaarrti Date: Wed, 11 Oct 2023 15:59:38 +0200 Subject: [PATCH 36/58] * --- quantus/metrics/base.py | 11 +++++------ quantus/metrics/faithfulness/pixel_flipping.py | 1 + 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index 2199a01c1..84db890c9 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -892,7 +892,7 @@ def display_progressbar(self) -> bool: self._display_progressbar and # Don't show progress bar in github actions. - os.environ.get("GITHUB_ACTIONS") != "true" + "GITHUB_ACTIONS" not in os.environ and # Don't show progress bar when running unit tests. "PYTEST" not in os.environ @@ -902,11 +902,10 @@ def display_progressbar(self) -> bool: def disable_warnings(self) -> bool: """A helper to avoid polluting test outputs with warnings.""" return ( - not self._disable_warnings - and + self._disable_warnings # Don't show progress bar in github actions. - os.environ.get("GITHUB_ACTIONS") != "true" - and + or "GITHUB_ACTIONS" not in os.environ # Don't show progress bar when running unit tests. - "PYTEST" not in os.environ + or "PYTEST" in os.environ + ) diff --git a/quantus/metrics/faithfulness/pixel_flipping.py b/quantus/metrics/faithfulness/pixel_flipping.py index 04eaf4ff2..ea1e969cd 100644 --- a/quantus/metrics/faithfulness/pixel_flipping.py +++ b/quantus/metrics/faithfulness/pixel_flipping.py @@ -5,6 +5,7 @@ # Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . +from __future__ import annotations import sys from typing import Any, Callable, Dict, List, Optional From 1cce0973ef8f967e6c7b7d5e8a4c88bcd5d6f9f4 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Wed, 11 Oct 2023 16:16:47 +0200 Subject: [PATCH 37/58] * --- tests/metrics/test_localisation_metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/test_localisation_metrics.py b/tests/metrics/test_localisation_metrics.py index 5680038c1..8d83f709f 100644 --- a/tests/metrics/test_localisation_metrics.py +++ b/tests/metrics/test_localisation_metrics.py @@ -768,7 +768,7 @@ def test_top_k_intersection( ) if isinstance(expected, float): - assert all(round(s, 2) == round(expected, 2) for s in scores), "Test failed." + np.testing.assert_allclose(scores, expected, atol=0.01) elif "type" in expected: assert isinstance(scores, expected["type"]), "Test failed." else: From f16bceae83ca378470c2e4ce959ba547d7f1c91a Mon Sep 17 00:00:00 2001 From: aaarrti Date: Wed, 11 Oct 2023 16:20:55 +0200 Subject: [PATCH 38/58] * --- pyproject.toml | 1 + quantus/helpers/perturbation_utils.py | 8 +++++++- quantus/metrics/base.py | 3 +-- quantus/metrics/base_batched.py | 2 +- quantus/metrics/faithfulness/pixel_flipping.py | 1 + quantus/metrics/faithfulness/road.py | 1 + quantus/metrics/faithfulness/sufficiency.py | 1 + quantus/metrics/robustness/consistency.py | 1 + 8 files changed, 14 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b6e293835..a924779bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "scipy>=1.7.3", "tqdm>=4.62.3", "matplotlib>=3.3.4", + "typing_extensions; python_version <= '3.8'" ] dynamic = ["version"] diff --git a/quantus/helpers/perturbation_utils.py b/quantus/helpers/perturbation_utils.py index a17b2c3a2..71b3215a1 100644 --- a/quantus/helpers/perturbation_utils.py +++ b/quantus/helpers/perturbation_utils.py @@ -1,9 +1,15 @@ from __future__ import annotations -from typing import List, TYPE_CHECKING, Callable, Mapping, Protocol +import sys +from typing import List, TYPE_CHECKING, Callable, Mapping import numpy as np import functools +if sys.version_info >= (3, 8): + from typing import Protocol +else: + from typing_extensions import Protocol + if TYPE_CHECKING: from quantus.helpers.model.model_interface import ModelInterface diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index 84db890c9..cbdfe0f12 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -17,10 +17,10 @@ ClassVar, Dict, Generator, + Generic, Sequence, Set, TypeVar, - Generic, ) import matplotlib.pyplot as plt @@ -907,5 +907,4 @@ def disable_warnings(self) -> bool: or "GITHUB_ACTIONS" not in os.environ # Don't show progress bar when running unit tests. or "PYTEST" in os.environ - ) diff --git a/quantus/metrics/base_batched.py b/quantus/metrics/base_batched.py index 0aaa24f75..292358bf2 100644 --- a/quantus/metrics/base_batched.py +++ b/quantus/metrics/base_batched.py @@ -5,9 +5,9 @@ # Quantus project URL: . import abc +import logging from quantus.metrics.base import Metric -import logging """Aliases to smoothen transition to uniform metric API.""" diff --git a/quantus/metrics/faithfulness/pixel_flipping.py b/quantus/metrics/faithfulness/pixel_flipping.py index ea1e969cd..9b0b757cb 100644 --- a/quantus/metrics/faithfulness/pixel_flipping.py +++ b/quantus/metrics/faithfulness/pixel_flipping.py @@ -6,6 +6,7 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . from __future__ import annotations + import sys from typing import Any, Callable, Dict, List, Optional diff --git a/quantus/metrics/faithfulness/road.py b/quantus/metrics/faithfulness/road.py index 118ef2814..854cbb2b5 100644 --- a/quantus/metrics/faithfulness/road.py +++ b/quantus/metrics/faithfulness/road.py @@ -6,6 +6,7 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . from __future__ import annotations + import sys from typing import Any, Callable, Dict, List, Optional diff --git a/quantus/metrics/faithfulness/sufficiency.py b/quantus/metrics/faithfulness/sufficiency.py index 62862b100..2b141b600 100644 --- a/quantus/metrics/faithfulness/sufficiency.py +++ b/quantus/metrics/faithfulness/sufficiency.py @@ -6,6 +6,7 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . from __future__ import annotations + import sys from typing import Any, Callable, Dict, List, Optional, no_type_check diff --git a/quantus/metrics/robustness/consistency.py b/quantus/metrics/robustness/consistency.py index 08836913f..7c152aa15 100644 --- a/quantus/metrics/robustness/consistency.py +++ b/quantus/metrics/robustness/consistency.py @@ -6,6 +6,7 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . from __future__ import annotations + import sys from typing import Any, Callable, Dict, List, Optional, no_type_check From 9809d6312d2dc0e6ef4150c9fbbc40555471ac1d Mon Sep 17 00:00:00 2001 From: aaarrti Date: Tue, 17 Oct 2023 19:25:34 +0200 Subject: [PATCH 39/58] * mypy fixes --- quantus/metrics/axiomatic/completeness.py | 2 +- quantus/metrics/axiomatic/input_invariance.py | 2 +- quantus/metrics/axiomatic/non_sensitivity.py | 2 +- quantus/metrics/base.py | 2 +- quantus/metrics/complexity/complexity.py | 2 +- quantus/metrics/faithfulness/faithfulness_correlation.py | 2 +- quantus/metrics/faithfulness/faithfulness_estimate.py | 2 +- quantus/metrics/faithfulness/infidelity.py | 2 +- quantus/metrics/faithfulness/irof.py | 2 +- quantus/metrics/faithfulness/monotonicity.py | 4 ++-- quantus/metrics/faithfulness/monotonicity_correlation.py | 2 +- quantus/metrics/faithfulness/pixel_flipping.py | 2 +- quantus/metrics/faithfulness/region_perturbation.py | 2 +- quantus/metrics/faithfulness/road.py | 6 +++--- quantus/metrics/faithfulness/selectivity.py | 2 +- quantus/metrics/faithfulness/sensitivity_n.py | 2 +- quantus/metrics/faithfulness/sufficiency.py | 1 + quantus/metrics/localisation/focus.py | 1 + quantus/metrics/robustness/avg_sensitivity.py | 3 +++ quantus/metrics/robustness/continuity.py | 2 +- quantus/metrics/robustness/local_lipschitz_estimate.py | 3 +++ quantus/metrics/robustness/max_sensitivity.py | 3 +++ 22 files changed, 31 insertions(+), 20 deletions(-) diff --git a/quantus/metrics/axiomatic/completeness.py b/quantus/metrics/axiomatic/completeness.py index 31c9becf5..e169781e0 100644 --- a/quantus/metrics/axiomatic/completeness.py +++ b/quantus/metrics/axiomatic/completeness.py @@ -308,11 +308,11 @@ def evaluate_instance( def evaluate_batch( self, - *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, + *args, **kwargs, ) -> List[bool]: """ diff --git a/quantus/metrics/axiomatic/input_invariance.py b/quantus/metrics/axiomatic/input_invariance.py index 19ac2b579..7da362bcd 100644 --- a/quantus/metrics/axiomatic/input_invariance.py +++ b/quantus/metrics/axiomatic/input_invariance.py @@ -287,7 +287,7 @@ def evaluate_batch( ) # Get input shift. - input_shift = self.perturb_func.keywords["input_shift"] + input_shift = self.perturb_func.keywords["input_shift"] # type: ignore x_shifted = model.shape_input( x=x_shifted, shape=x_shifted.shape, diff --git a/quantus/metrics/axiomatic/non_sensitivity.py b/quantus/metrics/axiomatic/non_sensitivity.py index ed5918258..1683c8836 100644 --- a/quantus/metrics/axiomatic/non_sensitivity.py +++ b/quantus/metrics/axiomatic/non_sensitivity.py @@ -361,11 +361,11 @@ def custom_preprocess( def evaluate_batch( self, - *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, + *args, **kwargs, ) -> List[int]: """ diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index cbdfe0f12..fc2ca090e 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -296,7 +296,7 @@ def __call__( # Append the content of the last results to all results. self.all_evaluation_scores.extend(self.evaluation_scores) - return self.evaluation_scores + return self.evaluation_scores # type: ignore @abstractmethod def evaluate_batch( diff --git a/quantus/metrics/complexity/complexity.py b/quantus/metrics/complexity/complexity.py index 841e2027e..15da241d1 100644 --- a/quantus/metrics/complexity/complexity.py +++ b/quantus/metrics/complexity/complexity.py @@ -259,7 +259,7 @@ def evaluate_instance(x: np.ndarray, a: np.ndarray) -> float: return scipy.stats.entropy(pk=a) def evaluate_batch( - self, *args, x_batch: np.ndarray, a_batch: np.ndarray, **kwargs + self, x_batch: np.ndarray, a_batch: np.ndarray, *args, **kwargs ) -> List[float]: """ This method performs XAI evaluation on a single batch of explanations. diff --git a/quantus/metrics/faithfulness/faithfulness_correlation.py b/quantus/metrics/faithfulness/faithfulness_correlation.py index 0bc3dd1c2..d8fa94a36 100644 --- a/quantus/metrics/faithfulness/faithfulness_correlation.py +++ b/quantus/metrics/faithfulness/faithfulness_correlation.py @@ -374,11 +374,11 @@ def custom_preprocess( def evaluate_batch( self, - *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, + *args, **kwargs, ) -> List[float]: """ diff --git a/quantus/metrics/faithfulness/faithfulness_estimate.py b/quantus/metrics/faithfulness/faithfulness_estimate.py index 2dd317f81..afd7b0a26 100644 --- a/quantus/metrics/faithfulness/faithfulness_estimate.py +++ b/quantus/metrics/faithfulness/faithfulness_estimate.py @@ -365,11 +365,11 @@ def custom_preprocess( def evaluate_batch( self, - *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, + *args, **kwargs, ) -> List[float]: """ diff --git a/quantus/metrics/faithfulness/infidelity.py b/quantus/metrics/faithfulness/infidelity.py index 660778f21..b1704c3d6 100644 --- a/quantus/metrics/faithfulness/infidelity.py +++ b/quantus/metrics/faithfulness/infidelity.py @@ -413,11 +413,11 @@ def custom_preprocess( def evaluate_batch( self, - *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, + *args, **kwargs, ) -> List[float]: """ diff --git a/quantus/metrics/faithfulness/irof.py b/quantus/metrics/faithfulness/irof.py index 4ebdee2de..7577d852d 100644 --- a/quantus/metrics/faithfulness/irof.py +++ b/quantus/metrics/faithfulness/irof.py @@ -374,11 +374,11 @@ def get_aoc_score(self): def evaluate_batch( self, - *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, + *args, **kwargs, ) -> List[float]: """ diff --git a/quantus/metrics/faithfulness/monotonicity.py b/quantus/metrics/faithfulness/monotonicity.py index dc4d61a7a..28f679230 100644 --- a/quantus/metrics/faithfulness/monotonicity.py +++ b/quantus/metrics/faithfulness/monotonicity.py @@ -293,7 +293,7 @@ def evaluate_instance( # Copy the input x but fill with baseline values. baseline_value = utils.get_baseline_value( - value=self.perturb_func.keywords["perturb_baseline"], + value=self.perturb_func.keywords["perturb_baseline"], # type: ignore arr=x, return_shape=x.shape, # TODO. Double-check this over using = (1,). ) @@ -356,11 +356,11 @@ def custom_preprocess( def evaluate_batch( self, - *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, + *args, **kwargs, ) -> List[bool]: """ diff --git a/quantus/metrics/faithfulness/monotonicity_correlation.py b/quantus/metrics/faithfulness/monotonicity_correlation.py index 69dd9bf13..b87df2516 100644 --- a/quantus/metrics/faithfulness/monotonicity_correlation.py +++ b/quantus/metrics/faithfulness/monotonicity_correlation.py @@ -385,11 +385,11 @@ def custom_preprocess( def evaluate_batch( self, - *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, + *args, **kwargs, ) -> List[float]: """ diff --git a/quantus/metrics/faithfulness/pixel_flipping.py b/quantus/metrics/faithfulness/pixel_flipping.py index 9b0b757cb..b391ef2e0 100644 --- a/quantus/metrics/faithfulness/pixel_flipping.py +++ b/quantus/metrics/faithfulness/pixel_flipping.py @@ -364,11 +364,11 @@ def get_auc_score(self): def evaluate_batch( self, - *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, + *args, **kwargs, ) -> List[float | np.ndarray]: """ diff --git a/quantus/metrics/faithfulness/region_perturbation.py b/quantus/metrics/faithfulness/region_perturbation.py index e69d977f4..d5a2f1b0e 100644 --- a/quantus/metrics/faithfulness/region_perturbation.py +++ b/quantus/metrics/faithfulness/region_perturbation.py @@ -420,11 +420,11 @@ def get_auc_score(self): def evaluate_batch( self, - *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, + *args, **kwargs, ) -> List[List[float]]: """ diff --git a/quantus/metrics/faithfulness/road.py b/quantus/metrics/faithfulness/road.py index 854cbb2b5..89363197d 100644 --- a/quantus/metrics/faithfulness/road.py +++ b/quantus/metrics/faithfulness/road.py @@ -291,7 +291,7 @@ def evaluate_instance( for p_ix, p in enumerate(self.percentages): top_k_indices = ordered_indices[: int(self.a_size * p / 100)] - x_perturbed = self.perturb_func( + x_perturbed = self.perturb_func( # type: ignore arr=x, indices=top_k_indices, ) @@ -308,7 +308,7 @@ def evaluate_instance( # Return list of booleans for each percentage. return results_instance - def custom_batch_preprocess(self, data_batch: Dict[str, ...]) -> Dict[str, ...]: + def custom_batch_preprocess(self, data_batch: Dict[str, ...]) -> None: """ROAD requires `a_size` property to be set to `image_height` * `image_width` of an explanation.""" if self.a_size is None: self.a_size = data_batch["a_batch"][0, :, :].size @@ -351,11 +351,11 @@ def custom_postprocess( def evaluate_batch( self, - *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, + *args, **kwargs, ) -> List[List[float]]: """ diff --git a/quantus/metrics/faithfulness/selectivity.py b/quantus/metrics/faithfulness/selectivity.py index 99077518e..2cf498183 100644 --- a/quantus/metrics/faithfulness/selectivity.py +++ b/quantus/metrics/faithfulness/selectivity.py @@ -380,11 +380,11 @@ def get_auc_score(self): def evaluate_batch( self, - *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, + *args, **kwargs, ) -> List[List[float]]: """ diff --git a/quantus/metrics/faithfulness/sensitivity_n.py b/quantus/metrics/faithfulness/sensitivity_n.py index fbfcfe1f5..75477bbee 100644 --- a/quantus/metrics/faithfulness/sensitivity_n.py +++ b/quantus/metrics/faithfulness/sensitivity_n.py @@ -433,11 +433,11 @@ def custom_postprocess( def evaluate_batch( self, - *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, + *args, **kwargs, ) -> List[Dict[str, List[float]]]: """ diff --git a/quantus/metrics/faithfulness/sufficiency.py b/quantus/metrics/faithfulness/sufficiency.py index 2b141b600..79ebd8248 100644 --- a/quantus/metrics/faithfulness/sufficiency.py +++ b/quantus/metrics/faithfulness/sufficiency.py @@ -312,6 +312,7 @@ def custom_batch_preprocess( "y_pred_classes": y_pred_classes, } + @no_type_check def evaluate_batch( self, *args, i_batch, a_sim_vector_batch, y_pred_classes, **kwargs ) -> List[float]: diff --git a/quantus/metrics/localisation/focus.py b/quantus/metrics/localisation/focus.py index 974ae1af5..a170c5f51 100644 --- a/quantus/metrics/localisation/focus.py +++ b/quantus/metrics/localisation/focus.py @@ -389,6 +389,7 @@ def quadrant_bottom_right(self, a: np.ndarray) -> np.ndarray: ] return quandrant_a + @no_type_check def evaluate_batch( self, *args, a_batch: np.ndarray, c_batch: np.ndarray, **kwargs ) -> List[float]: diff --git a/quantus/metrics/robustness/avg_sensitivity.py b/quantus/metrics/robustness/avg_sensitivity.py index 48ab9d79b..bdb5733d2 100644 --- a/quantus/metrics/robustness/avg_sensitivity.py +++ b/quantus/metrics/robustness/avg_sensitivity.py @@ -303,6 +303,7 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, + *args, **kwargs, ) -> np.ndarray: """ @@ -318,6 +319,8 @@ def evaluate_batch( The output to be evaluated on an instance-basis. a_batch: np.ndarray The explanation to be evaluated on an instance-basis. + args: + Unused. kwargs: Unused. diff --git a/quantus/metrics/robustness/continuity.py b/quantus/metrics/robustness/continuity.py index 10d8af9e2..7ab1a0cf0 100644 --- a/quantus/metrics/robustness/continuity.py +++ b/quantus/metrics/robustness/continuity.py @@ -438,10 +438,10 @@ def aggregated_score(self): def evaluate_batch( self, - *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, + *args, **kwargs, ) -> List[Dict[str, int]]: """ diff --git a/quantus/metrics/robustness/local_lipschitz_estimate.py b/quantus/metrics/robustness/local_lipschitz_estimate.py index 17e24bbb4..1f782892b 100644 --- a/quantus/metrics/robustness/local_lipschitz_estimate.py +++ b/quantus/metrics/robustness/local_lipschitz_estimate.py @@ -309,6 +309,7 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, + *args, **kwargs, ) -> np.ndarray: """ @@ -324,6 +325,8 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. + args: + Unused. kwargs: Unused. diff --git a/quantus/metrics/robustness/max_sensitivity.py b/quantus/metrics/robustness/max_sensitivity.py index 0b2514e53..9903ca5e3 100644 --- a/quantus/metrics/robustness/max_sensitivity.py +++ b/quantus/metrics/robustness/max_sensitivity.py @@ -303,6 +303,7 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, + *args, **kwargs, ) -> np.ndarray: """ @@ -318,6 +319,8 @@ def evaluate_batch( The output to be evaluated on an instance-basis. a_batch: np.ndarray The explanation to be evaluated on an instance-basis. + args: + Unused. kwargs: Unused. From e37de152a4e104e595ccc8f0ad3f6868d71ed519 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Tue, 17 Oct 2023 19:30:09 +0200 Subject: [PATCH 40/58] * add xfail --- tests/metrics/test_localisation_metrics.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/metrics/test_localisation_metrics.py b/tests/metrics/test_localisation_metrics.py index 8d83f709f..7648db7a9 100644 --- a/tests/metrics/test_localisation_metrics.py +++ b/tests/metrics/test_localisation_metrics.py @@ -852,7 +852,7 @@ def test_top_k_intersection( }, {"min": 0.5, "max": 1.0}, ), - ( + pytest.param( lazy_fixture("load_1d_1ch_conv_model"), lazy_fixture("none_in_gt_zeros_1d_3ch"), { @@ -862,6 +862,7 @@ def test_top_k_intersection( }, }, 0.0, + marks=pytest.mark.xfail ), ( lazy_fixture("load_mnist_model"), @@ -874,7 +875,7 @@ def test_top_k_intersection( }, 0.0, ), - ( + pytest.param( lazy_fixture("load_1d_1ch_conv_model"), lazy_fixture("half_in_gt_zeros_1d_3ch"), { @@ -884,6 +885,7 @@ def test_top_k_intersection( }, }, 0.5, + marks=pytest.mark.xfail ), ( lazy_fixture("load_mnist_model"), From fdfe768a3b348e35afff2a17d0b45f8cda161475 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Tue, 17 Oct 2023 19:41:41 +0200 Subject: [PATCH 41/58] * add xfail --- tests/metrics/test_localisation_metrics.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/metrics/test_localisation_metrics.py b/tests/metrics/test_localisation_metrics.py index 7648db7a9..f4c413b64 100644 --- a/tests/metrics/test_localisation_metrics.py +++ b/tests/metrics/test_localisation_metrics.py @@ -649,7 +649,7 @@ def test_pointing_game( }, {"min": 0.1, "max": 0.25}, ), - ( + pytest.mark.xfail( lazy_fixture("load_1d_1ch_conv_model"), lazy_fixture("half_in_gt_zeros_1d_3ch"), { @@ -660,6 +660,7 @@ def test_pointing_game( }, }, 0.9800000000000001, # TODO: verify correctness + mark=pytest.mark.xfail ), ( lazy_fixture("load_mnist_model"), @@ -852,7 +853,7 @@ def test_top_k_intersection( }, {"min": 0.5, "max": 1.0}, ), - pytest.param( + ( lazy_fixture("load_1d_1ch_conv_model"), lazy_fixture("none_in_gt_zeros_1d_3ch"), { @@ -862,7 +863,6 @@ def test_top_k_intersection( }, }, 0.0, - marks=pytest.mark.xfail ), ( lazy_fixture("load_mnist_model"), @@ -875,7 +875,7 @@ def test_top_k_intersection( }, 0.0, ), - pytest.param( + ( lazy_fixture("load_1d_1ch_conv_model"), lazy_fixture("half_in_gt_zeros_1d_3ch"), { @@ -885,7 +885,6 @@ def test_top_k_intersection( }, }, 0.5, - marks=pytest.mark.xfail ), ( lazy_fixture("load_mnist_model"), From adf3482f1167ecc22a5806df43b427ee54e1f0ee Mon Sep 17 00:00:00 2001 From: aaarrti Date: Tue, 17 Oct 2023 19:45:48 +0200 Subject: [PATCH 42/58] * add xfail --- tests/metrics/test_localisation_metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/test_localisation_metrics.py b/tests/metrics/test_localisation_metrics.py index f4c413b64..bfa692e53 100644 --- a/tests/metrics/test_localisation_metrics.py +++ b/tests/metrics/test_localisation_metrics.py @@ -649,7 +649,7 @@ def test_pointing_game( }, {"min": 0.1, "max": 0.25}, ), - pytest.mark.xfail( + pytest.param( lazy_fixture("load_1d_1ch_conv_model"), lazy_fixture("half_in_gt_zeros_1d_3ch"), { From f85310a9c12a605e9aff6a184410b4b578fe6db1 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Tue, 17 Oct 2023 19:45:57 +0200 Subject: [PATCH 43/58] * add xfail --- tests/metrics/test_localisation_metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/test_localisation_metrics.py b/tests/metrics/test_localisation_metrics.py index bfa692e53..506bdb651 100644 --- a/tests/metrics/test_localisation_metrics.py +++ b/tests/metrics/test_localisation_metrics.py @@ -660,7 +660,7 @@ def test_pointing_game( }, }, 0.9800000000000001, # TODO: verify correctness - mark=pytest.mark.xfail + marks=pytest.mark.xfail ), ( lazy_fixture("load_mnist_model"), From 0433a490a4a36e17607555ea5eb0d9136916fbad Mon Sep 17 00:00:00 2001 From: aaarrti Date: Tue, 17 Oct 2023 19:54:08 +0200 Subject: [PATCH 44/58] * add xfail --- .github/workflows/python-package.yml | 3 +++ tests/metrics/test_localisation_metrics.py | 5 +++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 9ef21d883..f0925035a 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -10,6 +10,9 @@ on: pull_request: workflow_dispatch: +concurrency: + group: ${{ github.workflow }}-${{ github.event_name }}-${{ github.ref }} + cancel-in-progress: true jobs: diff --git a/tests/metrics/test_localisation_metrics.py b/tests/metrics/test_localisation_metrics.py index 506bdb651..4089b1b0e 100644 --- a/tests/metrics/test_localisation_metrics.py +++ b/tests/metrics/test_localisation_metrics.py @@ -625,7 +625,7 @@ def test_pointing_game( }, 0.0, ), - ( + pytest.param( lazy_fixture("load_1d_1ch_conv_model"), lazy_fixture("none_in_gt_zeros_1d_3ch"), { @@ -636,6 +636,7 @@ def test_pointing_game( }, }, 0.38, # TODO: verify correctness + marks=pytest.mark.xfail, ), ( lazy_fixture("load_mnist_model"), @@ -660,7 +661,7 @@ def test_pointing_game( }, }, 0.9800000000000001, # TODO: verify correctness - marks=pytest.mark.xfail + marks=pytest.mark.xfail, ), ( lazy_fixture("load_mnist_model"), From e537708b48af543457a5094d6e6913d647a5c513 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Tue, 17 Oct 2023 20:02:25 +0200 Subject: [PATCH 45/58] * --- quantus/metrics/base.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index fc2ca090e..380a6efb1 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -42,11 +42,6 @@ else: from typing_extensions import final -if sys.version_info >= (3, 11): - from typing import LiteralString -else: - from typing_extensions import LiteralString - D = TypeVar("D", bound=Dict[str, Any]) log = logging.getLogger(__name__) @@ -60,7 +55,7 @@ class Metric(Generic[R]): """ # Class attributes. - name: ClassVar[LiteralString] + name: ClassVar[str] data_applicability: ClassVar[Set[DataType]] model_applicability: ClassVar[Set[ModelType]] score_direction: ClassVar[ScoreDirection] From f0f90d9fc706fb0e4f95c4aa5e60d68a5018cb00 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Wed, 18 Oct 2023 14:43:32 +0200 Subject: [PATCH 46/58] * revert typing changes, update docs --- .github/workflows/codecov.yml | 3 + .github/workflows/lint.yml | 3 + docs/Makefile | 8 ++- .../source/docs_api/quantus.helpers.enums.rst | 7 ++ .../quantus.helpers.perturbation_utils.rst | 7 ++ docs/source/docs_api/quantus.helpers.rst | 2 + mypy.ini | 1 + quantus/metrics/base.py | 67 ++++++++++--------- 8 files changed, 63 insertions(+), 35 deletions(-) create mode 100644 docs/source/docs_api/quantus.helpers.enums.rst create mode 100644 docs/source/docs_api/quantus.helpers.perturbation_utils.rst diff --git a/.github/workflows/codecov.yml b/.github/workflows/codecov.yml index 5af493be2..53b18ab67 100644 --- a/.github/workflows/codecov.yml +++ b/.github/workflows/codecov.yml @@ -6,6 +6,9 @@ on: pull_request: workflow_dispatch: +concurrency: + group: ${{ github.workflow }}-${{ github.event_name }}-${{ github.ref }} + cancel-in-progress: true jobs: run: diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 655e24b12..20f33899c 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -4,6 +4,9 @@ on: pull_request: workflow_dispatch: +concurrency: + group: ${{ github.workflow }}-${{ github.event_name }}-${{ github.ref }} + cancel-in-progress: true jobs: lint: diff --git a/docs/Makefile b/docs/Makefile index 65d2f3b15..e68457991 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -12,12 +12,16 @@ BUILDDIR = build help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -.PHONY: help Makefile +.PHONY: help Makefile clean # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -rst: +rst: clean @sphinx-apidoc -o source/docs_api ../quantus --module-first --separate --force + + +clean: + rm -rf source/docs_api diff --git a/docs/source/docs_api/quantus.helpers.enums.rst b/docs/source/docs_api/quantus.helpers.enums.rst new file mode 100644 index 000000000..083197318 --- /dev/null +++ b/docs/source/docs_api/quantus.helpers.enums.rst @@ -0,0 +1,7 @@ +quantus.helpers.enums module +============================ + +.. automodule:: quantus.helpers.enums + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/docs_api/quantus.helpers.perturbation_utils.rst b/docs/source/docs_api/quantus.helpers.perturbation_utils.rst new file mode 100644 index 000000000..34e462121 --- /dev/null +++ b/docs/source/docs_api/quantus.helpers.perturbation_utils.rst @@ -0,0 +1,7 @@ +quantus.helpers.perturbation\_utils module +========================================== + +.. automodule:: quantus.helpers.perturbation_utils + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/docs_api/quantus.helpers.rst b/docs/source/docs_api/quantus.helpers.rst index df7447c63..85daab91a 100644 --- a/docs/source/docs_api/quantus.helpers.rst +++ b/docs/source/docs_api/quantus.helpers.rst @@ -22,6 +22,8 @@ Submodules quantus.helpers.asserts quantus.helpers.constants + quantus.helpers.enums + quantus.helpers.perturbation_utils quantus.helpers.plotting quantus.helpers.utils quantus.helpers.warn diff --git a/mypy.ini b/mypy.ini index 09c5feae7..1b94e61f6 100644 --- a/mypy.ini +++ b/mypy.ini @@ -12,6 +12,7 @@ ignore_missing_imports = True no_site_packages = True show_none_errors = False ignore_errors = False +plugins = numpy.typing.mypy_plugin [mypy-quantus.*] disallow_untyped_defs = False diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index 380a6efb1..e522d84e3 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -21,6 +21,7 @@ Sequence, Set, TypeVar, + Optional, ) import matplotlib.pyplot as plt @@ -73,10 +74,10 @@ def __init__( abs: bool, normalise: bool, normalise_func: Callable, - normalise_func_kwargs: Dict[str, ...] | None, + normalise_func_kwargs: Optional[Dict[str, Any]], return_aggregate: bool, aggregate_func: Callable, - default_plot_func: Callable[[...], None] | None, + default_plot_func: Optional[Callable], disable_warnings: bool, display_progressbar: bool, **kwargs, @@ -145,15 +146,15 @@ def __call__( self, model, x_batch: np.ndarray, - y_batch: np.ndarray | None, - a_batch: np.ndarray | None, - s_batch: np.ndarray | None, - channel_first: bool | None, - explain_func: Callable[[...], None] | None, - explain_func_kwargs: Dict[str, ...] | None, - model_predict_kwargs: Dict[str, ...] | None, - softmax: bool | None, - device: str | None = None, + y_batch: Optional[np.ndarray], + a_batch: Optional[np.ndarray], + s_batch: Optional[np.ndarray], + channel_first: Optional[bool], + explain_func: Optional[Callable], + explain_func_kwargs: Optional[Dict], + model_predict_kwargs: Optional[Dict], + softmax: Optional[bool], + device: Optional[str] = None, batch_size: int = 64, custom_batch: Any = None, **kwargs, @@ -301,7 +302,7 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - s_batch: np.ndarray | None, + s_batch: Optional[np.ndarray], **kwargs, ): """ @@ -334,17 +335,17 @@ def general_preprocess( self, model, x_batch: np.ndarray, - y_batch: np.ndarray | None, - a_batch: np.ndarray | None, - s_batch: np.ndarray | None, - channel_first: bool | None, + y_batch: Optional[np.ndarray], + a_batch: Optional[np.ndarray], + s_batch: Optional[np.ndarray], + channel_first: Optional[bool], explain_func: Callable, - explain_func_kwargs: Dict[str, ...] | None, - model_predict_kwargs: Dict[str, ...] | None, + explain_func_kwargs: Optional[Dict[str, Any]], + model_predict_kwargs: Optional[Dict[str, Any]], softmax: bool, - device: str | None, - custom_batch: np.ndarray | None, - ) -> Dict[str, ...]: + device: Optional[str], + custom_batch: Optional[np.ndarray], + ) -> Dict[str, Any]: """ Prepares all necessary variables for evaluation. @@ -463,11 +464,11 @@ def custom_preprocess( self, model: ModelInterface, x_batch: np.ndarray, - y_batch: np.ndarray | None, - a_batch: np.ndarray | None, - s_batch: np.ndarray | None, + y_batch: Optional[np.ndarray], + a_batch: Optional[np.ndarray], + s_batch: Optional[np.ndarray], custom_batch: Any, - ) -> Dict[str, ...] | None: + ) -> Optional[Dict[str, Any]]: """ Implement this method if you need custom preprocessing of data, model alteration or simply for creating/initialising additional @@ -606,9 +607,9 @@ def custom_postprocess( self, model: ModelInterface, x_batch: np.ndarray, - y_batch: np.ndarray | None, - a_batch: np.ndarray | None, - s_batch: np.ndarray | None, + y_batch: Optional[np.ndarray], + a_batch: Optional[np.ndarray], + s_batch: Optional[np.ndarray], **kwargs, ): """ @@ -712,9 +713,9 @@ def generate_batches( def plot( self, - plot_func: Callable[[...], None] | None = None, + plot_func: Optional[Callable] = None, show: bool = True, - path_to_save: str | None = None, + path_to_save: Optional[str] = None, *args, **kwargs, ) -> None: @@ -782,7 +783,7 @@ def get_params(self) -> Dict[str, Any]: return {k: v for k, v in self.__dict__.items() if k not in attr_exclude} @final - def batch_preprocess(self, data_batch: Dict[str, ...]) -> Dict[str, ...]: + def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: """ If `data_batch` has no `a_batch`, will compute explanations. This needs to be done on batch level to avoid OOM. Additionally will set `a_axes` property if it is None, @@ -809,8 +810,8 @@ def batch_preprocess(self, data_batch: Dict[str, ...]) -> Dict[str, ...]: return data_batch def custom_batch_preprocess( - self, data_batch: Dict[str, ...] - ) -> Dict[str, ...] | None: + self, data_batch: Dict[str, Any] + ) -> Optional[Dict[str, ...]]: """ Implement this method if you need custom preprocessing of data or simply for creating/initialising additional attributes or assertions From aea87ad09d6ac99836ee776d611e2a142c71e0ab Mon Sep 17 00:00:00 2001 From: aaarrti Date: Wed, 18 Oct 2023 14:56:37 +0200 Subject: [PATCH 47/58] * cleanup --- quantus/metrics/faithfulness/pixel_flipping.py | 1 - quantus/metrics/faithfulness/road.py | 1 - quantus/metrics/faithfulness/sufficiency.py | 1 - quantus/metrics/robustness/avg_sensitivity.py | 8 +++++++- quantus/metrics/robustness/consistency.py | 1 - 5 files changed, 7 insertions(+), 5 deletions(-) diff --git a/quantus/metrics/faithfulness/pixel_flipping.py b/quantus/metrics/faithfulness/pixel_flipping.py index b391ef2e0..96c46e8cd 100644 --- a/quantus/metrics/faithfulness/pixel_flipping.py +++ b/quantus/metrics/faithfulness/pixel_flipping.py @@ -5,7 +5,6 @@ # Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from __future__ import annotations import sys from typing import Any, Callable, Dict, List, Optional diff --git a/quantus/metrics/faithfulness/road.py b/quantus/metrics/faithfulness/road.py index 89363197d..6cdb8c75d 100644 --- a/quantus/metrics/faithfulness/road.py +++ b/quantus/metrics/faithfulness/road.py @@ -5,7 +5,6 @@ # Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from __future__ import annotations import sys from typing import Any, Callable, Dict, List, Optional diff --git a/quantus/metrics/faithfulness/sufficiency.py b/quantus/metrics/faithfulness/sufficiency.py index 79ebd8248..a69615c4a 100644 --- a/quantus/metrics/faithfulness/sufficiency.py +++ b/quantus/metrics/faithfulness/sufficiency.py @@ -5,7 +5,6 @@ # Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from __future__ import annotations import sys from typing import Any, Callable, Dict, List, Optional, no_type_check diff --git a/quantus/metrics/robustness/avg_sensitivity.py b/quantus/metrics/robustness/avg_sensitivity.py index bdb5733d2..170a21dab 100644 --- a/quantus/metrics/robustness/avg_sensitivity.py +++ b/quantus/metrics/robustness/avg_sensitivity.py @@ -1,5 +1,11 @@ """This module contains the implementation of the Avg-Sensitivity metric.""" -import functools + +# This file is part of Quantus. +# Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. +# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. +# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . +# Quantus project URL: . + import sys from typing import Any, Callable, Dict, List, Optional diff --git a/quantus/metrics/robustness/consistency.py b/quantus/metrics/robustness/consistency.py index 7c152aa15..0766e91b7 100644 --- a/quantus/metrics/robustness/consistency.py +++ b/quantus/metrics/robustness/consistency.py @@ -5,7 +5,6 @@ # Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -from __future__ import annotations import sys from typing import Any, Callable, Dict, List, Optional, no_type_check From c23a0ee1bb0b4763a33e0b44559bfdb24e273c3b Mon Sep 17 00:00:00 2001 From: aaarrti Date: Thu, 19 Oct 2023 17:58:45 +0200 Subject: [PATCH 48/58] * typing fix --- quantus/metrics/base.py | 2 +- quantus/metrics/faithfulness/road.py | 2 +- quantus/metrics/faithfulness/sufficiency.py | 2 +- quantus/metrics/robustness/consistency.py | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index e522d84e3..83041ea46 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -811,7 +811,7 @@ def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: def custom_batch_preprocess( self, data_batch: Dict[str, Any] - ) -> Optional[Dict[str, ...]]: + ) -> Optional[Dict[str, Any]]: """ Implement this method if you need custom preprocessing of data or simply for creating/initialising additional attributes or assertions diff --git a/quantus/metrics/faithfulness/road.py b/quantus/metrics/faithfulness/road.py index 6cdb8c75d..b362cb3af 100644 --- a/quantus/metrics/faithfulness/road.py +++ b/quantus/metrics/faithfulness/road.py @@ -307,7 +307,7 @@ def evaluate_instance( # Return list of booleans for each percentage. return results_instance - def custom_batch_preprocess(self, data_batch: Dict[str, ...]) -> None: + def custom_batch_preprocess(self, data_batch: Dict[str, Any]) -> None: """ROAD requires `a_size` property to be set to `image_height` * `image_width` of an explanation.""" if self.a_size is None: self.a_size = data_batch["a_batch"][0, :, :].size diff --git a/quantus/metrics/faithfulness/sufficiency.py b/quantus/metrics/faithfulness/sufficiency.py index a69615c4a..98f7f32f5 100644 --- a/quantus/metrics/faithfulness/sufficiency.py +++ b/quantus/metrics/faithfulness/sufficiency.py @@ -286,7 +286,7 @@ def evaluate_instance( return np.sum(pred_low_dist_a == pred_a) / len(low_dist_a) def custom_batch_preprocess( - self, data_batch: Dict[str, ...] + self, data_batch: Dict[str, Any] ) -> Dict[str, np.ndarray]: """Compute additional arguments required for Sufficiency evaluation on batch-level.""" model = data_batch["model"] diff --git a/quantus/metrics/robustness/consistency.py b/quantus/metrics/robustness/consistency.py index 0766e91b7..21c95fdb6 100644 --- a/quantus/metrics/robustness/consistency.py +++ b/quantus/metrics/robustness/consistency.py @@ -277,8 +277,8 @@ def evaluate_instance( return np.sum(pred_same_a == pred_a) / len(diff_a) def custom_batch_preprocess( - self, data_batch: Dict[str, ...] - ) -> Dict[str, ...] | None: + self, data_batch: Dict[str, Any] + ) -> Dict[str, np.ndarray]: """Compute additional arguments required for Consistency on batch-level.""" model = data_batch["model"] x_batch = data_batch["x_batch"] From 36dcd758c16a46f97c8d5399c826490952b5fd46 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Tue, 24 Oct 2023 10:27:23 +0200 Subject: [PATCH 49/58] * typing fix --- quantus/metrics/faithfulness/pixel_flipping.py | 1 + 1 file changed, 1 insertion(+) diff --git a/quantus/metrics/faithfulness/pixel_flipping.py b/quantus/metrics/faithfulness/pixel_flipping.py index 96c46e8cd..30ee9b6ef 100644 --- a/quantus/metrics/faithfulness/pixel_flipping.py +++ b/quantus/metrics/faithfulness/pixel_flipping.py @@ -6,6 +6,7 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . +from __future__ import annotations import sys from typing import Any, Callable, Dict, List, Optional From d0abbf80ab5c56d253f7e6af13984059a00692cd Mon Sep 17 00:00:00 2001 From: aaarrti Date: Fri, 27 Oct 2023 12:49:23 +0200 Subject: [PATCH 50/58] * cleanup --- pytest.ini | 3 --- tox.ini | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/pytest.ini b/pytest.ini index 1a32be04c..00a92ec87 100644 --- a/pytest.ini +++ b/pytest.ini @@ -22,6 +22,3 @@ filterwarnings = error ignore::UserWarning ignore::DeprecationWarning - -# Don't suppress logs. -addopts = -s -v diff --git a/tox.ini b/tox.ini index c57927bd5..745a8d6f9 100644 --- a/tox.ini +++ b/tox.ini @@ -13,7 +13,7 @@ deps = pass_env = TF_XLA_FLAGS commands = - pytest {posargs} + pytest -s -v {posargs} [testenv:coverage] description = Run the tests with coverage From 2124abf89122147072589812cc21a81271e468d0 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Fri, 27 Oct 2023 15:34:26 +0200 Subject: [PATCH 51/58] * bump up the version --- quantus/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/quantus/__init__.py b/quantus/__init__.py index 00ad3ef44..c566fa8ea 100644 --- a/quantus/__init__.py +++ b/quantus/__init__.py @@ -5,7 +5,7 @@ # Quantus project URL: . # Set the correct version. -__version__ = "0.4.4" +__version__ = "0.4.5" # Expose quantus.evaluate to the user. from quantus.evaluation import evaluate From 88e1ad04ca5eb19879c96e085656f1aaf889f5d3 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Fri, 3 Nov 2023 15:39:55 +0100 Subject: [PATCH 52/58] code review fixes --- pyproject.toml | 2 +- quantus/helpers/utils.py | 9 +- quantus/metrics/axiomatic/completeness.py | 23 ++- quantus/metrics/axiomatic/input_invariance.py | 17 +-- quantus/metrics/axiomatic/non_sensitivity.py | 36 ++--- quantus/metrics/base.py | 23 ++- quantus/metrics/complexity/complexity.py | 16 +-- .../complexity/effective_complexity.py | 17 +-- quantus/metrics/complexity/sparseness.py | 16 +-- .../faithfulness/faithfulness_correlation.py | 42 ++---- .../faithfulness/faithfulness_estimate.py | 40 ++---- quantus/metrics/faithfulness/infidelity.py | 37 ++--- quantus/metrics/faithfulness/irof.py | 22 +-- quantus/metrics/faithfulness/monotonicity.py | 40 ++---- .../faithfulness/monotonicity_correlation.py | 38 ++--- .../metrics/faithfulness/pixel_flipping.py | 46 +++--- .../faithfulness/region_perturbation.py | 20 ++- quantus/metrics/faithfulness/road.py | 38 ++--- quantus/metrics/faithfulness/selectivity.py | 19 ++- quantus/metrics/faithfulness/sensitivity_n.py | 48 ++----- quantus/metrics/faithfulness/sufficiency.py | 21 ++- .../localisation/attribution_localisation.py | 33 ++--- quantus/metrics/localisation/auc.py | 32 ++--- quantus/metrics/localisation/focus.py | 34 ++--- quantus/metrics/localisation/pointing_game.py | 26 +--- .../localisation/relevance_mass_accuracy.py | 32 ++--- .../localisation/relevance_rank_accuracy.py | 30 ++-- .../localisation/top_k_intersection.py | 34 ++--- .../model_parameter_randomisation.py | 131 ++++++++++-------- quantus/metrics/randomisation/random_logit.py | 52 ++----- quantus/metrics/robustness/avg_sensitivity.py | 37 +---- quantus/metrics/robustness/consistency.py | 27 ++-- quantus/metrics/robustness/continuity.py | 42 ++---- .../robustness/local_lipschitz_estimate.py | 40 ++---- quantus/metrics/robustness/max_sensitivity.py | 41 ++---- .../robustness/relative_input_stability.py | 11 +- .../robustness/relative_output_stability.py | 11 +- .../relative_representation_stability.py | 9 +- tox.ini | 1 + 39 files changed, 400 insertions(+), 793 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a924779bc..81e5fd290 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,4 +108,4 @@ full = [ [build-system] requires = ["flit-core >= 3.4"] -build-backend = "flit_core.buildapi" \ No newline at end of file +build-backend = "flit_core.buildapi" diff --git a/quantus/helpers/utils.py b/quantus/helpers/utils.py index 196bdb33f..2f2c9458d 100644 --- a/quantus/helpers/utils.py +++ b/quantus/helpers/utils.py @@ -9,7 +9,7 @@ import copy import re from importlib import util -from typing import Any, Dict, Optional, Sequence, Tuple, Union, List +from typing import Any, Dict, Optional, Sequence, Tuple, Union, List, TypeVar import numpy as np from skimage.segmentation import slic, felzenszwalb @@ -995,3 +995,10 @@ def calculate_auc(values: np.array, dx: int = 1): Definite integral of values. """ return np.trapz(np.array(values), dx=dx) + + +T = TypeVar("T") + + +def identity(x: T) -> T: + return x diff --git a/quantus/metrics/axiomatic/completeness.py b/quantus/metrics/axiomatic/completeness.py index e169781e0..e59670cde 100644 --- a/quantus/metrics/axiomatic/completeness.py +++ b/quantus/metrics/axiomatic/completeness.py @@ -11,7 +11,6 @@ import numpy as np -from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices from quantus.helpers import warn from quantus.helpers.enums import ( @@ -23,6 +22,7 @@ from quantus.helpers.model.model_interface import ModelInterface from quantus.helpers.perturbation_utils import make_perturb_func from quantus.metrics.base import Metric +from quantus.helpers.utils import identity if sys.version_info >= (3, 8): from typing import final @@ -71,12 +71,12 @@ def __init__( normalise: bool = True, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - output_func: Optional[Callable] = lambda x: x, + output_func: Optional[Callable] = None, perturb_baseline: str = "black", - perturb_func: Callable = baseline_replacement_by_indices, + perturb_func: Optional[Callable] = None, perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -120,8 +120,6 @@ def __init__( kwargs: optional Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max super().__init__( abs=abs, normalise=normalise, @@ -134,10 +132,12 @@ def __init__( disable_warnings=disable_warnings, **kwargs, ) + if perturb_func is None: + perturb_func = baseline_replacement_by_indices # Save metric-specific attributes. if output_func is None: - output_func = lambda x: x + output_func = identity self.output_func = output_func self.perturb_func = make_perturb_func( perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline @@ -160,8 +160,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -171,7 +171,6 @@ def __call__( softmax: Optional[bool] = False, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -260,6 +259,7 @@ def __call__( softmax=softmax, device=device, model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, **kwargs, ) @@ -312,7 +312,6 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - *args, **kwargs, ) -> List[bool]: """ @@ -329,8 +328,6 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - args: - Unused. kwargs: Unused. diff --git a/quantus/metrics/axiomatic/input_invariance.py b/quantus/metrics/axiomatic/input_invariance.py index 7da362bcd..b2ed9bb3f 100644 --- a/quantus/metrics/axiomatic/input_invariance.py +++ b/quantus/metrics/axiomatic/input_invariance.py @@ -11,7 +11,6 @@ import numpy as np -from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_shift, perturb_batch from quantus.helpers import asserts, warn from quantus.helpers.enums import ( @@ -63,10 +62,10 @@ def __init__( normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, input_shift: Union[int, float] = -1, - perturb_func=baseline_replacement_by_shift, + perturb_func: Optional[Callable] = None, perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -107,9 +106,6 @@ def __init__( if abs: warn.warn_absolute_operation(word="not ") - if normalise_func is None: - normalise_func = normalise_by_max - super().__init__( abs=abs, normalise=normalise, @@ -123,6 +119,9 @@ def __init__( **kwargs, ) + if perturb_func is None: + perturb_func = baseline_replacement_by_shift + self.perturb_func = make_perturb_func( perturb_func, perturb_func_kwargs, input_shift=input_shift ) @@ -152,7 +151,6 @@ def __call__( softmax: Optional[bool] = None, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -251,7 +249,6 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - s_batch: np.ndarray, **kwargs, ) -> np.ndarray: """ @@ -267,8 +264,6 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - s_batch: np.ndarray - The segmentation to be evaluated on a batch-basis. Returns ------- @@ -311,6 +306,6 @@ def evaluate_batch( return score - def custom_preprocess(self, *args, **kwargs) -> None: + def custom_preprocess(self, **kwargs) -> None: """Additional explain_func assert, as the one in prepare() won't be executed when `a_batch != None.`""" asserts.assert_explain_func(explain_func=self.explain_func) diff --git a/quantus/metrics/axiomatic/non_sensitivity.py b/quantus/metrics/axiomatic/non_sensitivity.py index 1683c8836..e3c78f021 100644 --- a/quantus/metrics/axiomatic/non_sensitivity.py +++ b/quantus/metrics/axiomatic/non_sensitivity.py @@ -11,7 +11,6 @@ import numpy as np -from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices from quantus.helpers import asserts, warn from quantus.helpers.enums import ( @@ -70,7 +69,7 @@ def __init__( normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, perturb_baseline: str = "black", - perturb_func: Callable = baseline_replacement_by_indices, + perturb_func: Optional[Callable] = None, perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, aggregate_func: Callable = np.mean, @@ -119,9 +118,6 @@ def __init__( Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max - super().__init__( abs=abs, normalise=normalise, @@ -135,6 +131,9 @@ def __init__( **kwargs, ) + if perturb_func is None: + perturb_func = baseline_replacement_by_indices + # Save metric-specific attributes. self.eps = eps self.n_samples = n_samples @@ -161,8 +160,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -172,7 +171,6 @@ def __call__( softmax: Optional[bool] = True, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -261,6 +259,7 @@ def __call__( softmax=softmax, device=device, model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, **kwargs, ) @@ -324,30 +323,18 @@ def evaluate_instance( def custom_preprocess( self, - model: ModelInterface, x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: np.ndarray, - custom_batch: Optional[np.ndarray] = None, + **kwargs, ) -> None: """ Implementation of custom_preprocess_batch. Parameters ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. x_batch: np.ndarray A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - custom_batch: any - Gives flexibility to the user to use for evaluation, can hold any variable. + kwargs: + Unused. Returns ------- @@ -365,7 +352,6 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - *args, **kwargs, ) -> List[int]: """ @@ -382,8 +368,6 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - args: - Unused. kwargs: Unused. diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index 83041ea46..112adde79 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -7,6 +7,7 @@ from __future__ import annotations +import functools import logging import os import sys @@ -30,6 +31,7 @@ from tqdm.auto import tqdm from quantus.helpers import asserts, utils, warn +from quantus.functions.normalise_func import normalise_by_max from quantus.helpers.enums import ( DataType, EvaluationCategory, @@ -64,16 +66,17 @@ class Metric(Generic[R]): evaluation_category: ClassVar[EvaluationCategory] # Instance attributes. explain_func: Callable - explain_func_kwargs: Dict[str, ...] + explain_func_kwargs: Dict[str, Any] a_axes: Sequence[int] evaluation_scores: Any all_evaluation_scores: Any + normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] def __init__( self, abs: bool, normalise: bool, - normalise_func: Callable, + normalise_func: Optional[Callable], normalise_func_kwargs: Optional[Dict[str, Any]], return_aggregate: bool, aggregate_func: Callable, @@ -121,6 +124,16 @@ def __init__( kwargs: optional Keyword arguments. """ + + if aggregate_func is None: + aggregate_func = np.mean + + if normalise_func is None: + normalise_func = normalise_by_max + + if normalise_func_kwargs is not None: + normalise_func = functools.partial(normalise_func, **normalise_func_kwargs) + # Run deprecation warnings. warn.deprecation_warnings(kwargs) warn.check_kwargs(kwargs) @@ -130,7 +143,6 @@ def __init__( self.return_aggregate = return_aggregate self.aggregate_func = aggregate_func self.normalise_func = normalise_func - self.normalise_func_kwargs = normalise_func_kwargs or {} self.default_plot_func = default_plot_func # We need underscores here to avoid conflict with @property descriptor. @@ -297,7 +309,6 @@ def __call__( @abstractmethod def evaluate_batch( self, - *, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, @@ -429,7 +440,7 @@ def general_preprocess( # Normalise with specified keyword arguments if requested. if self.normalise: - a_batch = self.normalise_func(a_batch, **self.normalise_func_kwargs) + a_batch = self.normalise_func(a_batch) # Take absolute if requested. if self.abs: @@ -874,7 +885,7 @@ def explain_batch( # Normalise and take absolute values of the attributions, if configured during metric instantiation. if self.normalise: - a_batch = self.normalise_func(a_batch, **self.normalise_func_kwargs) + a_batch = self.normalise_func(a_batch) if self.abs: a_batch = np.abs(a_batch) diff --git a/quantus/metrics/complexity/complexity.py b/quantus/metrics/complexity/complexity.py index 15da241d1..1a1a1e07a 100644 --- a/quantus/metrics/complexity/complexity.py +++ b/quantus/metrics/complexity/complexity.py @@ -11,7 +11,6 @@ import numpy as np import scipy -from quantus.functions.normalise_func import normalise_by_max from quantus.helpers import warn from quantus.helpers.enums import ( DataType, @@ -62,7 +61,7 @@ def __init__( normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -96,9 +95,6 @@ def __init__( if not abs: warn.warn_absolute_operation() - if normalise_func is None: - normalise_func = normalise_by_max - super().__init__( abs=abs, normalise=normalise, @@ -129,8 +125,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -140,7 +136,6 @@ def __call__( softmax: Optional[bool] = False, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -229,6 +224,7 @@ def __call__( softmax=softmax, device=device, model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, **kwargs, ) @@ -259,7 +255,7 @@ def evaluate_instance(x: np.ndarray, a: np.ndarray) -> float: return scipy.stats.entropy(pk=a) def evaluate_batch( - self, x_batch: np.ndarray, a_batch: np.ndarray, *args, **kwargs + self, x_batch: np.ndarray, a_batch: np.ndarray, **kwargs ) -> List[float]: """ This method performs XAI evaluation on a single batch of explanations. @@ -271,8 +267,6 @@ def evaluate_batch( The input to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - args: - Unused. kwargs: Unused. diff --git a/quantus/metrics/complexity/effective_complexity.py b/quantus/metrics/complexity/effective_complexity.py index be3378a3f..e37daa30f 100644 --- a/quantus/metrics/complexity/effective_complexity.py +++ b/quantus/metrics/complexity/effective_complexity.py @@ -10,7 +10,6 @@ import numpy as np -from quantus.functions.normalise_func import normalise_by_max from quantus.helpers import warn from quantus.helpers.enums import ( DataType, @@ -60,7 +59,7 @@ def __init__( normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -96,9 +95,6 @@ def __init__( if not abs: warn.warn_absolute_operation() - if normalise_func is None: - normalise_func = normalise_by_max - super().__init__( abs=abs, normalise=normalise, @@ -132,8 +128,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -143,7 +139,6 @@ def __call__( softmax: Optional[bool] = False, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -232,6 +227,7 @@ def __call__( softmax=softmax, device=device, model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, **kwargs, ) @@ -253,7 +249,7 @@ def evaluate_instance(self, a: np.ndarray) -> int: a = a.flatten() return int(np.sum(a > self.eps)) - def evaluate_batch(self, *args, a_batch: np.ndarray, **kwargs) -> List[int]: + def evaluate_batch(self, a_batch: np.ndarray, **kwargs) -> List[int]: """ This method performs XAI evaluation on a single batch of explanations. For more information on the specific logic, we refer the metric’s initialisation docstring. @@ -262,9 +258,6 @@ def evaluate_batch(self, *args, a_batch: np.ndarray, **kwargs) -> List[int]: ---------- a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - - args: - Unused. kwargs: Unused. diff --git a/quantus/metrics/complexity/sparseness.py b/quantus/metrics/complexity/sparseness.py index f63d33699..345d69537 100644 --- a/quantus/metrics/complexity/sparseness.py +++ b/quantus/metrics/complexity/sparseness.py @@ -10,7 +10,6 @@ import numpy as np -from quantus.functions.normalise_func import normalise_by_max from quantus.helpers import warn from quantus.helpers.enums import ( DataType, @@ -64,7 +63,7 @@ def __init__( normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -100,9 +99,6 @@ def __init__( if not abs: warn.warn_absolute_operation() - if normalise_func is None: - normalise_func = normalise_by_max - super().__init__( abs=abs, normalise=normalise, @@ -134,8 +130,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -145,7 +141,6 @@ def __call__( softmax: Optional[bool] = False, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -234,6 +229,7 @@ def __call__( softmax=softmax, device=device, model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, **kwargs, ) @@ -268,7 +264,7 @@ def evaluate_instance(x: np.ndarray, a: np.ndarray) -> float: return score def evaluate_batch( - self, *args, x_batch: np.ndarray, a_batch: np.ndarray, **kwargs + self, x_batch: np.ndarray, a_batch: np.ndarray, **kwargs ) -> List[float]: """ This method performs XAI evaluation on a single batch of explanations. @@ -280,8 +276,6 @@ def evaluate_batch( The input to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - args: - Unused. kwargs: Unused. diff --git a/quantus/metrics/faithfulness/faithfulness_correlation.py b/quantus/metrics/faithfulness/faithfulness_correlation.py index d8fa94a36..f444f4d33 100644 --- a/quantus/metrics/faithfulness/faithfulness_correlation.py +++ b/quantus/metrics/faithfulness/faithfulness_correlation.py @@ -10,7 +10,6 @@ import numpy as np -from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices from quantus.functions.similarity_func import correlation_pearson from quantus.helpers import asserts, warn @@ -75,11 +74,11 @@ def __init__( normalise: bool = True, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = baseline_replacement_by_indices, + perturb_func: Optional[Callable] = None, perturb_baseline: str = "black", perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = True, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -125,9 +124,6 @@ def __init__( kwargs: optional Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max - super().__init__( abs=abs, normalise=normalise, @@ -142,8 +138,12 @@ def __init__( ) # Save metric-specific attributes. + if perturb_func is None: + perturb_func = baseline_replacement_by_indices + if similarity_func is None: similarity_func = correlation_pearson + self.similarity_func = similarity_func self.nr_runs = nr_runs self.subset_size = subset_size @@ -169,8 +169,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -272,6 +272,7 @@ def __call__( softmax=softmax, device=device, model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, **kwargs, ) @@ -334,32 +335,16 @@ def evaluate_instance( return similarity - def custom_preprocess( - self, - model: ModelInterface, - x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: np.ndarray, - custom_batch: Optional[np.ndarray], - ) -> None: + def custom_preprocess(self, x_batch: np.ndarray, **kwargs) -> None: """ Implementation of custom_preprocess_batch. Parameters ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. x_batch: np.ndarray A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - custom_batch: any - Gives flexibility ot the user to use for evaluation, can hold any variable. + kwargs: + Unused. Returns ------- @@ -378,7 +363,6 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - *args, **kwargs, ) -> List[float]: """ @@ -395,8 +379,6 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - args: - Unused. kwargs: Unused. diff --git a/quantus/metrics/faithfulness/faithfulness_estimate.py b/quantus/metrics/faithfulness/faithfulness_estimate.py index afd7b0a26..a6cce708c 100644 --- a/quantus/metrics/faithfulness/faithfulness_estimate.py +++ b/quantus/metrics/faithfulness/faithfulness_estimate.py @@ -10,7 +10,6 @@ import numpy as np -from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices from quantus.functions.similarity_func import correlation_pearson from quantus.helpers import asserts, warn @@ -64,11 +63,11 @@ def __init__( normalise: bool = True, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = baseline_replacement_by_indices, + perturb_func: Optional[Callable] = None, perturb_baseline: str = "black", perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -112,9 +111,6 @@ def __init__( kwargs: optional Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max - super().__init__( abs=abs, normalise=normalise, @@ -131,6 +127,8 @@ def __init__( # Save metric-specific attributes. if similarity_func is None: similarity_func = correlation_pearson + if perturb_func is None: + perturb_func = baseline_replacement_by_indices self.similarity_func = similarity_func self.features_in_step = features_in_step self.perturb_func = make_perturb_func( @@ -154,8 +152,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -257,6 +255,7 @@ def __call__( model_predict_kwargs=model_predict_kwargs, softmax=softmax, device=device, + batch_size=batch_size, **kwargs, ) @@ -324,32 +323,16 @@ def evaluate_instance( similarity = self.similarity_func(a=att_sums, b=pred_deltas) return similarity - def custom_preprocess( - self, - model: ModelInterface, - x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: np.ndarray, - custom_batch: Optional[np.ndarray], - ) -> None: + def custom_preprocess(self, x_batch: np.ndarray, **kwargs) -> None: """ Implementation of custom_preprocess_batch. Parameters ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. x_batch: np.ndarray A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - custom_batch: any - Gives flexibility ot the user to use for evaluation, can hold any variable. + kwargs: + Unused. Returns ------- @@ -369,7 +352,6 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - *args, **kwargs, ) -> List[float]: """ @@ -386,8 +368,6 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - args: - Unused. kwargs: Unused. diff --git a/quantus/metrics/faithfulness/infidelity.py b/quantus/metrics/faithfulness/infidelity.py index b1704c3d6..2f1a20124 100644 --- a/quantus/metrics/faithfulness/infidelity.py +++ b/quantus/metrics/faithfulness/infidelity.py @@ -11,7 +11,6 @@ import numpy as np from quantus.functions.loss_func import mse -from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices from quantus.helpers import utils, warn from quantus.helpers.enums import ( @@ -75,11 +74,11 @@ def __init__( normalise: bool = False, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = baseline_replacement_by_indices, + perturb_func: Optional[Callable] = None, perturb_baseline: str = "black", perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -126,8 +125,6 @@ def __init__( kwargs: optional Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max super().__init__( abs=abs, @@ -150,6 +147,9 @@ def __init__( raise ValueError(f"loss_func must be in ['mse'] but is: {loss_func}") self.loss_func = loss_func + if perturb_func is None: + perturb_func = baseline_replacement_by_indices + if perturb_patch_sizes is None: perturb_patch_sizes = [4] self.perturb_patch_sizes = perturb_patch_sizes @@ -180,8 +180,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -191,7 +191,6 @@ def __call__( softmax: Optional[bool] = False, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -280,6 +279,7 @@ def __call__( model_predict_kwargs=model_predict_kwargs, softmax=softmax, device=device, + batch_size=batch_size, **kwargs, ) @@ -379,30 +379,18 @@ def evaluate_instance( def custom_preprocess( self, - model: ModelInterface, x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: np.ndarray, - custom_batch: Optional[np.ndarray], + **kwargs, ) -> None: """ Implementation of custom_preprocess_batch. Parameters ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. x_batch: np.ndarray A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - custom_batch: any - Gives flexibility ot the user to use for evaluation, can hold any variable. + kwargs: + Unused. Returns ------- @@ -417,7 +405,6 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - *args, **kwargs, ) -> List[float]: """ @@ -434,8 +421,6 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - args: - Unused. kwargs: Unused. diff --git a/quantus/metrics/faithfulness/irof.py b/quantus/metrics/faithfulness/irof.py index 7577d852d..1959cf15c 100644 --- a/quantus/metrics/faithfulness/irof.py +++ b/quantus/metrics/faithfulness/irof.py @@ -10,7 +10,6 @@ import numpy as np -from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices from quantus.helpers import asserts, utils, warn from quantus.helpers.enums import ( @@ -68,11 +67,11 @@ def __init__( normalise: bool = True, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = baseline_replacement_by_indices, + perturb_func: Optional[Callable] = None, perturb_baseline: str = "mean", perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = True, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -113,9 +112,6 @@ def __init__( kwargs: optional Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max - super().__init__( abs=abs, normalise=normalise, @@ -129,6 +125,9 @@ def __init__( **kwargs, ) + if perturb_func is None: + perturb_func = baseline_replacement_by_indices + # Save metric-specific attributes. self.segmentation_method = segmentation_method self.nr_channels = None @@ -168,7 +167,6 @@ def __call__( softmax: Optional[bool] = True, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -257,6 +255,7 @@ def __call__( softmax=softmax, device=device, model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, **kwargs, ) @@ -335,12 +334,8 @@ def evaluate_instance( def custom_preprocess( self, - model: ModelInterface, x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: np.ndarray, - custom_batch: Optional[np.ndarray] = None, + **kwargs, ) -> None: """ Implementation of custom_preprocess_batch. @@ -378,7 +373,6 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - *args, **kwargs, ) -> List[float]: """ @@ -397,8 +391,6 @@ def evaluate_batch( The explanation to be evaluated on a batch-basis. kwargs: Unused. - args: - Unused. Returns ------- diff --git a/quantus/metrics/faithfulness/monotonicity.py b/quantus/metrics/faithfulness/monotonicity.py index 28f679230..2135e19b9 100644 --- a/quantus/metrics/faithfulness/monotonicity.py +++ b/quantus/metrics/faithfulness/monotonicity.py @@ -10,7 +10,6 @@ import numpy as np -from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices from quantus.helpers import asserts, utils, warn from quantus.helpers.enums import ( @@ -69,11 +68,11 @@ def __init__( normalise: bool = True, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = baseline_replacement_by_indices, + perturb_func: Optional[Callable] = None, perturb_baseline: str = "black", perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -114,9 +113,6 @@ def __init__( kwargs: optional Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max - super().__init__( abs=abs, normalise=normalise, @@ -130,6 +126,9 @@ def __init__( **kwargs, ) + if perturb_func is None: + perturb_func = baseline_replacement_by_indices + # Save metric-specific attributes. self.features_in_step = features_in_step self.perturb_func = make_perturb_func( @@ -153,8 +152,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -164,7 +163,6 @@ def __call__( softmax: Optional[bool] = True, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -253,6 +251,7 @@ def __call__( softmax=softmax, device=device, model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, **kwargs, ) @@ -319,30 +318,18 @@ def evaluate_instance( def custom_preprocess( self, - model: ModelInterface, x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: np.ndarray, - custom_batch: Optional[np.ndarray] = None, + **kwargs, ) -> None: """ Implementation of custom_preprocess_batch. Parameters ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. x_batch: np.ndarray A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - custom_batch: any - Gives flexibility ot the user to use for evaluation, can hold any variable. + kwargs: + Unused. Returns ------- @@ -360,9 +347,8 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - *args, **kwargs, - ) -> List[bool]: + ) -> List[float]: """ This method performs XAI evaluation on a single batch of explanations. For more information on the specific logic, we refer the metric’s initialisation docstring. @@ -377,8 +363,6 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - args: - Unused. kwargs: Unused. diff --git a/quantus/metrics/faithfulness/monotonicity_correlation.py b/quantus/metrics/faithfulness/monotonicity_correlation.py index b87df2516..e01fabd72 100644 --- a/quantus/metrics/faithfulness/monotonicity_correlation.py +++ b/quantus/metrics/faithfulness/monotonicity_correlation.py @@ -11,7 +11,6 @@ import numpy as np -from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices from quantus.functions.similarity_func import correlation_spearman from quantus.helpers import asserts, warn @@ -72,11 +71,11 @@ def __init__( normalise: bool = True, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = baseline_replacement_by_indices, + perturb_func: Optional[Callable] = None, perturb_baseline: str = "uniform", perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -124,9 +123,6 @@ def __init__( kwargs: optional Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max - super().__init__( abs=abs, normalise=normalise, @@ -143,6 +139,10 @@ def __init__( # Save metric-specific attributes. if similarity_func is None: similarity_func = correlation_spearman + + if perturb_func is None: + perturb_func = baseline_replacement_by_indices + self.similarity_func = similarity_func self.eps = eps @@ -169,8 +169,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -272,6 +272,7 @@ def __call__( softmax=softmax, device=device, model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, **kwargs, ) @@ -348,30 +349,18 @@ def evaluate_instance( def custom_preprocess( self, - model: ModelInterface, x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: np.ndarray, - custom_batch: Optional[np.ndarray] = None, + **kwargs, ) -> None: """ Implementation of custom_preprocess_batch. Parameters ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. x_batch: np.ndarray A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - custom_batch: any - Gives flexibility ot the user to use for evaluation, can hold any variable. + kwargs: + Unused. Returns ------- @@ -389,7 +378,6 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - *args, **kwargs, ) -> List[float]: """ @@ -406,8 +394,6 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - args: - Unused. kwargs: Unused. diff --git a/quantus/metrics/faithfulness/pixel_flipping.py b/quantus/metrics/faithfulness/pixel_flipping.py index 30ee9b6ef..826cb53cb 100644 --- a/quantus/metrics/faithfulness/pixel_flipping.py +++ b/quantus/metrics/faithfulness/pixel_flipping.py @@ -5,14 +5,11 @@ # Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . - -from __future__ import annotations import sys -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union import numpy as np -from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices from quantus.helpers import asserts, plotting, utils, warn from quantus.helpers.enums import ( @@ -66,11 +63,11 @@ def __init__( normalise: bool = True, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = baseline_replacement_by_indices, + perturb_func: Optional[Callable] = None, perturb_baseline: str = "black", perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, return_auc_per_sample: bool = False, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, @@ -114,9 +111,6 @@ def __init__( kwargs: optional Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max - if default_plot_func is None: default_plot_func = plotting.plot_pixel_flipping_experiment @@ -133,6 +127,9 @@ def __init__( **kwargs, ) + if perturb_func is None: + perturb_func = baseline_replacement_by_indices + # Save metric-specific attributes. self.features_in_step = features_in_step self.return_auc_per_sample = return_auc_per_sample @@ -155,8 +152,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -166,7 +163,6 @@ def __call__( softmax: Optional[bool] = True, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -255,6 +251,7 @@ def __call__( softmax=softmax, device=device, model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, **kwargs, ) @@ -314,36 +311,24 @@ def evaluate_instance( preds[i_ix] = y_pred_perturb if self.return_auc_per_sample: - return utils.calculate_auc(preds) + return float(utils.calculate_auc(preds)) return preds def custom_preprocess( self, - model: ModelInterface, x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: np.ndarray, - custom_batch: Optional[np.ndarray] = None, + **kwargs, ) -> None: """ Implementation of custom_preprocess_batch. Parameters ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. x_batch: np.ndarray A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - custom_batch: any - Gives flexibility ot the user to use for evaluation, can hold any variable. + kwargs: + Unused. Returns ------- @@ -368,9 +353,8 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - *args, **kwargs, - ) -> List[float | np.ndarray]: + ) -> List[List[float]]: """ This method performs XAI evaluation on a single batch of explanations. For more information on the specific logic, we refer the metric’s initialisation docstring. @@ -385,6 +369,8 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. + kwargs: + Unused. Returns ------- diff --git a/quantus/metrics/faithfulness/region_perturbation.py b/quantus/metrics/faithfulness/region_perturbation.py index d5a2f1b0e..db203dfc9 100644 --- a/quantus/metrics/faithfulness/region_perturbation.py +++ b/quantus/metrics/faithfulness/region_perturbation.py @@ -12,7 +12,6 @@ import numpy as np -from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices from quantus.helpers import asserts, plotting, utils, warn from quantus.helpers.enums import ( @@ -74,11 +73,11 @@ def __init__( normalise: bool = True, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = baseline_replacement_by_indices, + perturb_func: Optional[Callable] = None, perturb_baseline: str = "black", perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -124,9 +123,6 @@ def __init__( kwargs: optional Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max - if default_plot_func is None: default_plot_func = plotting.plot_region_perturbation_experiment @@ -143,6 +139,9 @@ def __init__( **kwargs, ) + if perturb_func is None: + perturb_func = baseline_replacement_by_indices + # Save metric-specific attributes. self.patch_size = patch_size self.order = order.lower() @@ -174,8 +173,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -185,7 +184,6 @@ def __call__( softmax: Optional[bool] = True, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -274,6 +272,7 @@ def __call__( softmax=softmax, device=device, model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, **kwargs, ) @@ -424,7 +423,6 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - *args, **kwargs, ) -> List[List[float]]: """ @@ -441,8 +439,6 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - args: - Unused. kwargs: Unused. diff --git a/quantus/metrics/faithfulness/road.py b/quantus/metrics/faithfulness/road.py index b362cb3af..105027e10 100644 --- a/quantus/metrics/faithfulness/road.py +++ b/quantus/metrics/faithfulness/road.py @@ -11,7 +11,6 @@ import numpy as np -from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import noisy_linear_imputation from quantus.helpers import warn from quantus.helpers.enums import ( @@ -70,10 +69,10 @@ def __init__( normalise: bool = True, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = noisy_linear_imputation, + perturb_func: Optional[Callable] = None, perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -111,9 +110,6 @@ def __init__( kwargs: optional Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max - super().__init__( abs=abs, normalise=normalise, @@ -130,6 +126,10 @@ def __init__( # Save metric-specific attributes. if percentages is None: percentages = list(range(1, 100, 2)) + + if perturb_func is None: + perturb_func = noisy_linear_imputation + self.percentages = percentages self.a_size = None self.perturb_func = make_perturb_func( @@ -153,8 +153,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -164,7 +164,6 @@ def __call__( softmax: Optional[bool] = True, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -253,6 +252,7 @@ def __call__( softmax=softmax, device=device, model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, **kwargs, ) @@ -314,11 +314,6 @@ def custom_batch_preprocess(self, data_batch: Dict[str, Any]) -> None: def custom_postprocess( self, - model: ModelInterface, - x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: np.ndarray, **kwargs, ) -> None: """ @@ -326,16 +321,8 @@ def custom_postprocess( Parameters ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. - x_batch: np.ndarray - A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. + kwargs: + Unused. Returns ------- @@ -354,7 +341,6 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - *args, **kwargs, ) -> List[List[float]]: """ @@ -371,8 +357,6 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - args: - Unused. kwargs: Unused. diff --git a/quantus/metrics/faithfulness/selectivity.py b/quantus/metrics/faithfulness/selectivity.py index 2cf498183..4b5739add 100644 --- a/quantus/metrics/faithfulness/selectivity.py +++ b/quantus/metrics/faithfulness/selectivity.py @@ -12,7 +12,6 @@ import numpy as np -from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import baseline_replacement_by_indices from quantus.helpers import plotting, utils, warn from quantus.helpers.enums import ( @@ -72,11 +71,11 @@ def __init__( normalise: bool = True, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = baseline_replacement_by_indices, + perturb_func: Optional[Callable] = None, perturb_baseline: str = "black", perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -117,8 +116,6 @@ def __init__( kwargs: optional Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max if default_plot_func is None: default_plot_func = plotting.plot_selectivity_experiment @@ -136,6 +133,9 @@ def __init__( **kwargs, ) + if perturb_func is None: + perturb_func = baseline_replacement_by_indices + # Save metric-specific attributes. self.patch_size = patch_size self.perturb_func = make_perturb_func( @@ -163,8 +163,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -174,7 +174,6 @@ def __call__( softmax: Optional[bool] = True, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -263,6 +262,7 @@ def __call__( softmax=softmax, device=device, model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, **kwargs, ) @@ -384,7 +384,6 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - *args, **kwargs, ) -> List[List[float]]: """ @@ -401,8 +400,6 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - args: - Unused. kwargs: Unused. diff --git a/quantus/metrics/faithfulness/sensitivity_n.py b/quantus/metrics/faithfulness/sensitivity_n.py index 75477bbee..820c60425 100644 --- a/quantus/metrics/faithfulness/sensitivity_n.py +++ b/quantus/metrics/faithfulness/sensitivity_n.py @@ -71,11 +71,11 @@ def __init__( normalise: bool = True, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = baseline_replacement_by_indices, + perturb_func: Callable = None, perturb_baseline: str = "black", perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = True, - aggregate_func: Callable = np.mean, + aggregate_func: Callable = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -140,6 +140,9 @@ def __init__( **kwargs, ) + if perturb_func is None: + perturb_func = baseline_replacement_by_indices + # Save metric-specific attributes. if similarity_func is None: similarity_func = correlation_pearson @@ -169,8 +172,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -180,7 +183,6 @@ def __call__( softmax: Optional[bool] = True, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -269,6 +271,7 @@ def __call__( softmax=softmax, device=device, model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, **kwargs, ) @@ -337,30 +340,18 @@ def evaluate_instance( def custom_preprocess( self, - model: ModelInterface, x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: np.ndarray, - custom_batch: Optional[np.ndarray], + **kwargs, ) -> None: """ Implementation of custom_preprocess_batch. Parameters ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. x_batch: np.ndarray A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - custom_batch: any - Gives flexibility ot the user to use for evaluation, can hold any variable. + kwargs: + Unused. Returns ------- @@ -374,11 +365,7 @@ def custom_preprocess( def custom_postprocess( self, - model: ModelInterface, x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: np.ndarray, **kwargs, ) -> None: """ @@ -386,16 +373,10 @@ def custom_postprocess( Parameters ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. x_batch: np.ndarray A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. + kwargs: + Unused. Returns ------- @@ -437,7 +418,6 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - *args, **kwargs, ) -> List[Dict[str, List[float]]]: """ @@ -454,8 +434,6 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - args: - Unused. kwargs: Unused. diff --git a/quantus/metrics/faithfulness/sufficiency.py b/quantus/metrics/faithfulness/sufficiency.py index 98f7f32f5..952037779 100644 --- a/quantus/metrics/faithfulness/sufficiency.py +++ b/quantus/metrics/faithfulness/sufficiency.py @@ -12,7 +12,6 @@ import numpy as np from scipy.spatial.distance import cdist -from quantus.functions.normalise_func import normalise_by_max from quantus.helpers import warn from quantus.helpers.enums import ( DataType, @@ -69,7 +68,7 @@ def __init__( normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -110,8 +109,6 @@ def __init__( kwargs: optional Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max super().__init__( abs=abs, @@ -148,8 +145,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -159,7 +156,6 @@ def __call__( softmax: Optional[bool] = True, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -248,14 +244,15 @@ def __call__( softmax=softmax, device=device, model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, **kwargs, ) @staticmethod def evaluate_instance( - i: int = None, - a_sim_vector: np.ndarray = None, - y_pred_classes: np.ndarray = None, + i: int, + a_sim_vector: np.ndarray, + y_pred_classes: np.ndarray, ) -> float: """ Evaluate instance gets model and data for a single instance as input and returns the evaluation result. @@ -313,7 +310,7 @@ def custom_batch_preprocess( @no_type_check def evaluate_batch( - self, *args, i_batch, a_sim_vector_batch, y_pred_classes, **kwargs + self, i_batch, a_sim_vector_batch, y_pred_classes, **kwargs ) -> List[float]: """ This method performs XAI evaluation on a single batch of explanations. @@ -327,8 +324,6 @@ def evaluate_batch( The custom input to be evaluated on an instance-basis. y_pred_classes: The class predictions of the complete input dataset. - args: - Unused. kwargs: Unused. diff --git a/quantus/metrics/localisation/attribution_localisation.py b/quantus/metrics/localisation/attribution_localisation.py index 61f650dac..053d7d633 100644 --- a/quantus/metrics/localisation/attribution_localisation.py +++ b/quantus/metrics/localisation/attribution_localisation.py @@ -7,11 +7,10 @@ # Quantus project URL: . import sys -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional import numpy as np -from quantus.functions.normalise_func import normalise_by_max from quantus.helpers import asserts, warn from quantus.helpers.enums import ( DataType, @@ -19,7 +18,6 @@ ModelType, ScoreDirection, ) -from quantus.helpers.model.model_interface import ModelInterface from quantus.metrics.base import Metric if sys.version_info >= (3, 8): @@ -64,7 +62,7 @@ def __init__( normalise_func: Optional[Callable] = None, normalise_func_kwargs: Optional[Dict] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, display_progressbar: bool = False, disable_warnings: bool = False, @@ -100,8 +98,6 @@ def __init__( kwargs: optional Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max if not abs: warn.warn_absolute_operation() @@ -141,8 +137,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -152,7 +148,6 @@ def __call__( softmax: Optional[bool] = False, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -241,6 +236,7 @@ def __call__( softmax=softmax, device=device, model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, **kwargs, ) @@ -298,31 +294,21 @@ def evaluate_instance( def custom_preprocess( self, - model: ModelInterface, x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], s_batch: np.ndarray, - custom_batch: Optional[np.ndarray] = None, + **kwargs, ) -> None: """ Implementation of custom_preprocess_batch. Parameters ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. x_batch: np.ndarray A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. s_batch: np.ndarray, optional A np.ndarray which contains segmentation masks that matches the input. - custom_batch: any - Gives flexibility ot the user to use for evaluation, can hold any variable. - + kwargs: + Unused. Returns ------- None @@ -332,7 +318,6 @@ def custom_preprocess( def evaluate_batch( self, - *args, x_batch: np.ndarray, a_batch: np.ndarray, s_batch: np.ndarray, @@ -350,8 +335,6 @@ def evaluate_batch( A np.ndarray which contains pre-computed attributions i.e., explanations. s_batch: np.ndarray A np.ndarray which contains segmentation masks that matches the input. - args: - Unused. kwargs: Unused. diff --git a/quantus/metrics/localisation/auc.py b/quantus/metrics/localisation/auc.py index 5901f1aed..55163416c 100644 --- a/quantus/metrics/localisation/auc.py +++ b/quantus/metrics/localisation/auc.py @@ -12,7 +12,6 @@ import numpy as np from sklearn.metrics import auc, roc_curve -from quantus.functions.normalise_func import normalise_by_max from quantus.helpers import asserts, warn from quantus.helpers.enums import ( DataType, @@ -20,7 +19,6 @@ ModelType, ScoreDirection, ) -from quantus.helpers.model.model_interface import ModelInterface from quantus.metrics.base import Metric if sys.version_info >= (3, 8): @@ -60,7 +58,7 @@ def __init__( normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -91,9 +89,6 @@ def __init__( kwargs: optional Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max - super().__init__( abs=abs, normalise=normalise, @@ -124,8 +119,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -135,7 +130,6 @@ def __call__( softmax: Optional[bool] = False, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -222,6 +216,7 @@ def __call__( softmax=softmax, device=device, model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, **kwargs, ) @@ -258,30 +253,21 @@ def evaluate_instance(a: np.ndarray, s: np.ndarray) -> float: def custom_preprocess( self, - model: ModelInterface, x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], s_batch: np.ndarray, - custom_batch: Optional[np.ndarray] = None, + **kwargs, ) -> None: """ Implementation of custom_preprocess_batch. Parameters ---------- - model: (Union[torch.nn.Module, tf.keras.Model]) - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. x_batch: np.ndarray A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray - A Union[np.ndarray, None] which contains pre-computed attributions i.e., explanations. s_batch: np.ndarray A Union[np.ndarray, None] which contains segmentation masks that matches the input. - custom_batch: np.ndarray, optional - Gives flexibility ot the user to use for evaluation, can hold any variable. + kwargs: + Unused. Returns ------- @@ -291,7 +277,7 @@ def custom_preprocess( asserts.assert_segmentations(x_batch=x_batch, s_batch=s_batch) def evaluate_batch( - self, *args, a_batch: np.ndarray, s_batch: np.ndarray, **kwargs + self, a_batch: np.ndarray, s_batch: np.ndarray, **kwargs ) -> List[float]: """ This method performs XAI evaluation on a single batch of explanations. @@ -303,8 +289,6 @@ def evaluate_batch( A np.ndarray which contains pre-computed attributions i.e., explanations. s_batch: A np.ndarray which contains segmentation masks that matches the input. - args: - Unused. kwargs: Unused. diff --git a/quantus/metrics/localisation/focus.py b/quantus/metrics/localisation/focus.py index a170c5f51..6382b4679 100644 --- a/quantus/metrics/localisation/focus.py +++ b/quantus/metrics/localisation/focus.py @@ -11,7 +11,6 @@ import numpy as np -from quantus.functions.normalise_func import normalise_by_max from quantus.helpers import plotting, warn from quantus.helpers.enums import ( DataType, @@ -66,8 +65,8 @@ def __init__( normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, - default_plot_func: Optional[Callable] = plotting.plot_focus, + aggregate_func: Optional[Callable] = None, + default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, **kwargs, @@ -97,8 +96,8 @@ def __init__( kwargs: optional Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max + if default_plot_func is None: + default_plot_func = plotting.plot_focus # Save metric-specific attributes. self.mosaic_shape = mosaic_shape @@ -135,8 +134,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -266,13 +265,14 @@ def __call__( softmax=softmax, device=device, model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, **kwargs, ) def evaluate_instance( self, a: np.ndarray, - c: np.ndarray = None, + c: np.ndarray, ) -> float: """ Evaluate instance gets model and data for a single instance as input and returns the evaluation result. @@ -319,10 +319,9 @@ def custom_preprocess( self, model: ModelInterface, x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: np.ndarray, - custom_batch: Optional[np.ndarray], + y_batch: np.ndarray, + custom_batch: np.ndarray, + **kwargs, ) -> Dict[str, Any]: """ Implementation of custom_preprocess_batch. @@ -335,12 +334,10 @@ def custom_preprocess( A np.ndarray which contains the input data that are explained. y_batch: np.ndarray A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. custom_batch: any Gives flexibility ot the user to use for evaluation, can hold any variable. + kwargs: + Unused. Returns ------- @@ -389,9 +386,8 @@ def quadrant_bottom_right(self, a: np.ndarray) -> np.ndarray: ] return quandrant_a - @no_type_check def evaluate_batch( - self, *args, a_batch: np.ndarray, c_batch: np.ndarray, **kwargs + self, a_batch: np.ndarray, c_batch: np.ndarray, **kwargs ) -> List[float]: """ This method performs XAI evaluation on a single batch of explanations. @@ -403,8 +399,6 @@ def evaluate_batch( A np.ndarray which contains pre-computed attributions i.e., explanations. c_batch: The custom input to be evaluated on an batch-basis. - args: - Unused. kwargs: Unused. diff --git a/quantus/metrics/localisation/pointing_game.py b/quantus/metrics/localisation/pointing_game.py index eb6f6b90b..04a945f12 100644 --- a/quantus/metrics/localisation/pointing_game.py +++ b/quantus/metrics/localisation/pointing_game.py @@ -11,7 +11,6 @@ import numpy as np -from quantus.functions.normalise_func import normalise_by_max from quantus.helpers import asserts, warn from quantus.helpers.enums import ( DataType, @@ -63,7 +62,7 @@ def __init__( normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -96,8 +95,6 @@ def __init__( kwargs: optional Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max super().__init__( abs=abs, @@ -134,8 +131,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -145,7 +142,6 @@ def __call__( softmax: Optional[bool] = False, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -234,6 +230,7 @@ def __call__( softmax=softmax, device=device, model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, **kwargs, ) @@ -241,7 +238,7 @@ def evaluate_instance( self, a: np.ndarray, s: np.ndarray, - ) -> bool: + ) -> float: """ Evaluate instance gets model and data for a single instance as input and returns the evaluation result. @@ -280,30 +277,19 @@ def evaluate_instance( def custom_preprocess( self, - model: ModelInterface, x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], s_batch: np.ndarray, - custom_batch: Optional[np.ndarray] = None, + **kwargs, ) -> None: """ Implementation of custom_preprocess_batch. Parameters ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. x_batch: np.ndarray A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. s_batch: np.ndarray, optional A np.ndarray which contains segmentation masks that matches the input. - custom_batch: np.ndarray, optional - Gives flexibility ot the user to use for evaluation, can hold any variable. Returns ------- diff --git a/quantus/metrics/localisation/relevance_mass_accuracy.py b/quantus/metrics/localisation/relevance_mass_accuracy.py index a367423bd..532f62596 100644 --- a/quantus/metrics/localisation/relevance_mass_accuracy.py +++ b/quantus/metrics/localisation/relevance_mass_accuracy.py @@ -11,7 +11,6 @@ import numpy as np -from quantus.functions.normalise_func import normalise_by_max from quantus.helpers import asserts, warn from quantus.helpers.enums import ( DataType, @@ -19,7 +18,6 @@ ModelType, ScoreDirection, ) -from quantus.helpers.model.model_interface import ModelInterface from quantus.metrics.base import Metric if sys.version_info >= (3, 8): @@ -62,7 +60,7 @@ def __init__( normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -93,8 +91,6 @@ def __init__( kwargs: optional Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max super().__init__( abs=abs, @@ -128,8 +124,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -139,7 +135,6 @@ def __call__( softmax: Optional[bool] = False, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -228,6 +223,7 @@ def __call__( softmax=softmax, device=device, model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, **kwargs, ) @@ -271,31 +267,21 @@ def evaluate_instance( def custom_preprocess( self, - model: ModelInterface, x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], s_batch: np.ndarray, - custom_batch: Optional[np.ndarray], + **kwargs, ) -> None: """ Implementation of custom_preprocess_batch. Parameters ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. x_batch: np.ndarray A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. s_batch: np.ndarray, optional A np.ndarray which contains segmentation masks that matches the input. - custom_batch: any - Gives flexibility ot the user to use for evaluation, can hold any variable. - + kwargs: + Unused. Returns ------- None @@ -304,7 +290,7 @@ def custom_preprocess( asserts.assert_segmentations(x_batch=x_batch, s_batch=s_batch) def evaluate_batch( - self, *args, a_batch: np.ndarray, s_batch: np.ndarray, **kwargs + self, a_batch: np.ndarray, s_batch: np.ndarray, **kwargs ) -> List[float]: """ This method performs XAI evaluation on a single batch of explanations. @@ -316,8 +302,6 @@ def evaluate_batch( A np.ndarray which contains pre-computed attributions i.e., explanations. s_batch: A np.ndarray which contains segmentation masks that matches the input. - args: - Unused. kwargs: Unused. diff --git a/quantus/metrics/localisation/relevance_rank_accuracy.py b/quantus/metrics/localisation/relevance_rank_accuracy.py index 41ab75037..91079068f 100644 --- a/quantus/metrics/localisation/relevance_rank_accuracy.py +++ b/quantus/metrics/localisation/relevance_rank_accuracy.py @@ -64,7 +64,7 @@ def __init__( normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -95,8 +95,6 @@ def __init__( kwargs: optional Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max super().__init__( abs=abs, @@ -130,8 +128,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -141,7 +139,6 @@ def __call__( softmax: Optional[bool] = False, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -230,6 +227,7 @@ def __call__( softmax=softmax, device=device, model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, **kwargs, ) @@ -280,31 +278,21 @@ def evaluate_instance( def custom_preprocess( self, - model: ModelInterface, x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], s_batch: np.ndarray, - custom_batch: Optional[np.ndarray] = None, + **kwargs, ) -> None: """ Implementation of custom_preprocess_batch. Parameters ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. x_batch: np.ndarray A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. s_batch: np.ndarray, optional A np.ndarray which contains segmentation masks that matches the input. - custom_batch: any - Gives flexibility ot the user to use for evaluation, can hold any variable. - + kwargs: + Unused. Returns ------- None @@ -313,7 +301,7 @@ def custom_preprocess( asserts.assert_segmentations(x_batch=x_batch, s_batch=s_batch) def evaluate_batch( - self, *args, a_batch: np.ndarray, s_batch: np.ndarray, **kwargs + self, a_batch: np.ndarray, s_batch: np.ndarray, **kwargs ) -> List[float]: """ This method performs XAI evaluation on a single batch of explanations. @@ -325,8 +313,6 @@ def evaluate_batch( A np.ndarray which contains pre-computed attributions i.e., explanations. s_batch: A np.ndarray which contains segmentation masks that matches the input. - args: - Unused. kwargs: Unused diff --git a/quantus/metrics/localisation/top_k_intersection.py b/quantus/metrics/localisation/top_k_intersection.py index bd7dc4612..60a69b599 100644 --- a/quantus/metrics/localisation/top_k_intersection.py +++ b/quantus/metrics/localisation/top_k_intersection.py @@ -11,7 +11,6 @@ import numpy as np -from quantus.functions.normalise_func import normalise_by_max from quantus.helpers import asserts, warn from quantus.helpers.enums import ( DataType, @@ -19,7 +18,6 @@ ModelType, ScoreDirection, ) -from quantus.helpers.model.model_interface import ModelInterface from quantus.metrics.base import Metric if sys.version_info >= (3, 8): @@ -64,7 +62,7 @@ def __init__( normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -99,8 +97,6 @@ def __init__( kwargs: optional Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max super().__init__( abs=abs, @@ -139,8 +135,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -150,7 +146,6 @@ def __call__( softmax: Optional[bool] = False, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -239,6 +234,7 @@ def __call__( softmax=softmax, device=device, model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, **kwargs, ) @@ -286,31 +282,19 @@ def evaluate_instance( return tki def custom_preprocess( - self, - model: ModelInterface, - x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: np.ndarray, - custom_batch: Optional[np.ndarray], + self, x_batch: np.ndarray, s_batch: np.ndarray, **kwargs ) -> None: """ Implementation of custom_preprocess_batch. Parameters ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. x_batch: np.ndarray A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. a_batch: np.ndarray, optional A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - custom_batch: any - Gives flexibility ot the user to use for evaluation, can hold any variable. + kwargs: + Unused. Returns ------- @@ -323,7 +307,7 @@ def custom_preprocess( ) def evaluate_batch( - self, *args, a_batch: np.ndarray, s_batch: np.ndarray, **kwargs + self, a_batch: np.ndarray, s_batch: np.ndarray, **kwargs ) -> List[float]: """ This method performs XAI evaluation on a single batch of explanations. @@ -335,8 +319,6 @@ def evaluate_batch( A np.ndarray which contains pre-computed attributions i.e., explanations. s_batch: A np.ndarray which contains segmentation masks that matches the input. - args: - Unused. kwargs: Unused. diff --git a/quantus/metrics/randomisation/model_parameter_randomisation.py b/quantus/metrics/randomisation/model_parameter_randomisation.py index 0fb677b5a..65b3b848e 100644 --- a/quantus/metrics/randomisation/model_parameter_randomisation.py +++ b/quantus/metrics/randomisation/model_parameter_randomisation.py @@ -7,12 +7,13 @@ # Quantus project URL: . import sys -from typing import Any, Callable, Collection, Dict, List, Optional, Union +from typing import Any, Callable, Collection, Dict, List, Optional, Union, Generator + import numpy as np from tqdm.auto import tqdm +from sklearn.utils import gen_batches -from quantus.functions.normalise_func import normalise_by_max from quantus.functions.similarity_func import correlation_spearman from quantus.helpers import asserts, warn from quantus.helpers.enums import ( @@ -62,7 +63,7 @@ class ModelParameterRandomisation(Metric): def __init__( self, - similarity_func: Callable = None, + similarity_func: Optional[Callable] = None, layer_order: str = "independent", seed: int = 42, return_sample_correlation: bool = False, @@ -71,7 +72,7 @@ def __init__( normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = None, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -113,8 +114,6 @@ def __init__( kwargs: optional Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max super().__init__( abs=abs, @@ -159,8 +158,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -170,7 +169,6 @@ def __call__( softmax: Optional[bool] = False, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> Union[List[float], float, Dict[str, List[float]], Collection[Any]]: """ @@ -253,7 +251,7 @@ def __call__( # Run deprecation warnings. warn.deprecation_warnings(kwargs) warn.check_kwargs(kwargs) - + self.batch_size = batch_size data = self.general_preprocess( model=model, x_batch=x_batch, @@ -268,54 +266,52 @@ def __call__( softmax=softmax, device=device, ) - model = data["model"] - x_batch = data["x_batch"] - y_batch = data["y_batch"] - a_batch = data["a_batch"] - - if a_batch is None: - a_batch = self.explain_batch(model, x_batch, y_batch) - + model: ModelInterface = data["model"] + # Here _batch refers to full dataset. + x_full_dataset = data["x_batch"] + y_full_dataset = data["y_batch"] + a_full_dataset = data["a_batch"] # Results are returned/saved as a dictionary not as a list as in the super-class. self.evaluation_scores = {} # Get number of iterations from number of layers. n_layers = len(list(model.get_random_layer_generator(order=self.layer_order))) - - model_iterator = tqdm( - model.get_random_layer_generator(order=self.layer_order, seed=self.seed), - total=n_layers, - disable=not self.display_progressbar, + pbar = tqdm( + total=n_layers * len(x_full_dataset), disable=not self.display_progressbar ) - - for layer_name, random_layer_model in model_iterator: - similarity_scores = [None for _ in x_batch] - - # Generate an explanation with perturbed model. - a_batch_perturbed = self.explain_batch( - model=random_layer_model, - x_batch=x_batch, - y_batch=y_batch, - ) - - batch_iterator = enumerate(zip(a_batch, a_batch_perturbed)) - for instance_id, (a_instance, a_instance_perturbed) in batch_iterator: - result = self.similarity_func( - a_instance_perturbed.flatten(), a_instance.flatten() + if self.display_progressbar: + # Set property to False, so we display only 1 pbar. + self._display_progressbar = False + + def generate_y_batches(): + for batch in gen_batches(len(a_full_dataset), batch_size): + yield a_full_dataset[batch.start : batch.stop] + + with pbar as pbar: + for layer_name, random_layer_model in model.get_random_layer_generator( + order=self.layer_order, seed=self.seed + ): + pbar.desc = layer_name + + similarity_scores = [] + # Generate explanations on modified model in batches + a_perturbed_generator = self.generate_explanations( + random_layer_model, x_full_dataset, y_full_dataset, batch_size ) - similarity_scores[instance_id] = result - # Save similarity scores in a result dictionary. - self.evaluation_scores[layer_name] = similarity_scores - - # Call post-processing. - self.custom_postprocess( - model=model, - x_batch=x_batch, - y_batch=y_batch, - a_batch=a_batch, - s_batch=s_batch, - ) + for a_batch, a_batch_perturbed in zip( + generate_y_batches(), a_perturbed_generator + ): + for a_instance, a_instance_perturbed in zip( + a_batch, a_batch_perturbed + ): + result = self.similarity_func( + a_instance_perturbed.flatten(), a_instance.flatten() + ) + similarity_scores.append(result) + pbar.update(1) + # Save similarity scores in a result dictionary. + self.evaluation_scores[layer_name] = similarity_scores if self.return_sample_correlation: self.evaluation_scores = self.compute_correlation_per_sample() @@ -335,11 +331,10 @@ def custom_preprocess( self, model: ModelInterface, x_batch: np.ndarray, - y_batch: Optional[np.ndarray], + y_batch: np.ndarray, a_batch: Optional[np.ndarray], - s_batch: np.ndarray, - custom_batch: Optional[np.ndarray], - ) -> None: + **kwargs, + ) -> Optional[Dict[str, np.ndarray]]: """ Implementation of custom_preprocess_batch. @@ -353,11 +348,8 @@ def custom_preprocess( A np.ndarray which contains the output labels that are explained. a_batch: np.ndarray, optional A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - custom_batch: any - Gives flexibility ot the user to use for evaluation, can hold any variable. - + kwargs: + Unused. Returns ------- None @@ -365,6 +357,13 @@ def custom_preprocess( # Additional explain_func assert, as the one in general_preprocess() # won't be executed when a_batch != None. asserts.assert_explain_func(explain_func=self.explain_func) + if a_batch is None: + a_batch_chunks = [] + for a_chunk in self.generate_explanations( + model, x_batch, y_batch, self.batch_size + ): + a_batch_chunks.extend(a_chunk) + return dict(a_batch=np.asarray(a_batch_chunks)) def compute_correlation_per_sample( self, @@ -388,6 +387,20 @@ def compute_correlation_per_sample( return corr_coeffs + def generate_explanations( + self, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, + batch_size: int, + ) -> Generator[np.ndarray, None, None]: + """Iterate over dataset in batches and generate explanations for complete dataset""" + for i in gen_batches(len(x_batch), batch_size): + x = x_batch[i.start : i.stop] + y = y_batch[i.start : i.stop] + a = self.explain_batch(model, x, y) + yield a + def evaluate_batch(self, *args, **kwargs): raise RuntimeError( "`evaluate_batch` must never be called for `ModelParameterRandomisation`." diff --git a/quantus/metrics/randomisation/random_logit.py b/quantus/metrics/randomisation/random_logit.py index ab22ea022..7a7985d91 100644 --- a/quantus/metrics/randomisation/random_logit.py +++ b/quantus/metrics/randomisation/random_logit.py @@ -11,7 +11,6 @@ import numpy as np -from quantus.functions.normalise_func import normalise_by_max from quantus.functions.similarity_func import ssim from quantus.helpers import asserts, warn from quantus.helpers.enums import ( @@ -65,7 +64,7 @@ def __init__( normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -102,9 +101,6 @@ def __init__( kwargs: optional Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max - super().__init__( abs=abs, normalise=normalise, @@ -140,8 +136,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -151,7 +147,6 @@ def __call__( softmax: Optional[bool] = False, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -240,6 +235,7 @@ def __call__( softmax=softmax, device=device, model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, **kwargs, ) @@ -278,50 +274,21 @@ def evaluate_instance( ) ] ) - # Explain against a random class. - a_perturbed = self.explain_func( - model=model.get_model(), - inputs=np.expand_dims(x, axis=0), - targets=y_off, - **self.explain_func_kwargs, - ) - - # Normalise and take absolute values of the attributions, if True. - if self.normalise: - a_perturbed = self.normalise_func(a_perturbed, **self.normalise_func_kwargs) - - if self.abs: - a_perturbed = np.abs(a_perturbed) - + a_perturbed = self.explain_batch(model, np.expand_dims(x, axis=0), y_off) return self.similarity_func(a.flatten(), a_perturbed.flatten()) def custom_preprocess( self, - model: ModelInterface, - x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: np.ndarray, - custom_batch: Optional[np.ndarray], + **kwargs, ) -> None: """ Implementation of custom_preprocess_batch. Parameters ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. - x_batch: np.ndarray - A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - custom_batch: any - Gives flexibility ot the user to use for evaluation, can hold any variable. + kwargs: + Unused. Returns ------- @@ -333,7 +300,6 @@ def custom_preprocess( def evaluate_batch( self, - *args, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, @@ -354,8 +320,6 @@ def evaluate_batch( A np.ndarray which contains the output labels that are explained. a_batch: A np.ndarray which contains pre-computed attributions i.e., explanations. - args: - Unused. kwargs: Unused. diff --git a/quantus/metrics/robustness/avg_sensitivity.py b/quantus/metrics/robustness/avg_sensitivity.py index 170a21dab..e1065f4b8 100644 --- a/quantus/metrics/robustness/avg_sensitivity.py +++ b/quantus/metrics/robustness/avg_sensitivity.py @@ -85,7 +85,7 @@ def __init__( upper_bound: Optional[float] = None, perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -137,9 +137,6 @@ def __init__( kwargs: optional Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max - super().__init__( abs=abs, normalise=normalise, @@ -199,8 +196,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -210,7 +207,6 @@ def __call__( softmax: Optional[bool] = False, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -309,7 +305,6 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - *args, **kwargs, ) -> np.ndarray: """ @@ -325,8 +320,6 @@ def evaluate_batch( The output to be evaluated on an instance-basis. a_batch: np.ndarray The explanation to be evaluated on an instance-basis. - args: - Unused. kwargs: Unused. @@ -377,32 +370,14 @@ def evaluate_batch( return self.mean_func(similarities, axis=1) - def custom_preprocess( - self, - model: ModelInterface, - x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: np.ndarray, - custom_batch: Optional[np.ndarray], - ) -> None: + def custom_preprocess(self, **kwargs) -> None: """ Implementation of custom_preprocess_batch. Parameters ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. - x_batch: np.ndarray - A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - custom_batch: any - Gives flexibility ot the user to use for evaluation, can hold any variable. + kwargs: + Unused. Returns ------- diff --git a/quantus/metrics/robustness/consistency.py b/quantus/metrics/robustness/consistency.py index 21c95fdb6..3b721307e 100644 --- a/quantus/metrics/robustness/consistency.py +++ b/quantus/metrics/robustness/consistency.py @@ -66,7 +66,7 @@ def __init__( normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -98,9 +98,6 @@ def __init__( kwargs: optional Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max - super().__init__( abs=abs, normalise=normalise, @@ -137,8 +134,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -148,7 +145,6 @@ def __call__( softmax: Optional[bool] = True, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -237,15 +233,16 @@ def __call__( softmax=softmax, device=device, model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, **kwargs, ) @staticmethod def evaluate_instance( a: np.ndarray, - i: int = None, - a_label: np.ndarray = None, - y_pred_classes: np.ndarray = None, + i: int, + a_label: np.ndarray, + y_pred_classes: np.ndarray, ) -> float: """ Evaluate instance gets model and data for a single instance as input and returns the evaluation result. @@ -302,7 +299,6 @@ def custom_batch_preprocess( @no_type_check def evaluate_batch( self, - *args, a_batch: np.ndarray, i_batch: np.ndarray, a_label_batch: np.ndarray, @@ -315,13 +311,10 @@ def evaluate_batch( Parameters ---------- - args: - Unused. a_batch: - - i_batch - a_label_batch - y_pred_classes + i_batch: + a_label_batch: + y_pred_classes: kwargs Returns diff --git a/quantus/metrics/robustness/continuity.py b/quantus/metrics/robustness/continuity.py index 7ab1a0cf0..55147f7dd 100644 --- a/quantus/metrics/robustness/continuity.py +++ b/quantus/metrics/robustness/continuity.py @@ -72,11 +72,11 @@ def __init__( normalise: bool = True, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = translation_x_direction, + perturb_func: Optional[Callable] = None, perturb_baseline: str = "black", perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = np.mean, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -127,9 +127,6 @@ def __init__( kwargs: optional Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max - super().__init__( abs=abs, normalise=normalise, @@ -143,6 +140,9 @@ def __init__( **kwargs, ) + if perturb_func is None: + perturb_func = translation_x_direction + # Save metric-specific attributes. if similarity_func is None: similarity_func = lipschitz_constant @@ -179,8 +179,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -190,7 +190,6 @@ def __call__( softmax: Optional[bool] = False, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -279,6 +278,7 @@ def __call__( softmax=softmax, device=device, model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, **kwargs, ) @@ -358,9 +358,7 @@ def evaluate_instance( # a_perturbed = utils.expand_attribution_channel(a_perturbed, x_input)[0] if self.normalise: - a_perturbed_patch = self.normalise_func( - a_perturbed_patch.flatten(), **self.normalise_func_kwargs - ) + a_perturbed_patch = self.normalise_func(a_perturbed_patch.flatten()) if self.abs: a_perturbed_patch = np.abs(a_perturbed_patch.flatten()) @@ -373,31 +371,18 @@ def evaluate_instance( def custom_preprocess( self, - model: ModelInterface, x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: np.ndarray, - custom_batch: Optional[np.ndarray], + **kwargs, ) -> None: """ Implementation of custom_preprocess_batch. Parameters ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. x_batch: np.ndarray A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - custom_batch: any - Gives flexibility ot the user to use for evaluation, can hold any variable. - + kwargs: + Unused. Returns ------- None. @@ -441,7 +426,6 @@ def evaluate_batch( model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, - *args, **kwargs, ) -> List[Dict[str, int]]: """ @@ -458,8 +442,6 @@ def evaluate_batch( A np.ndarray which contains the output labels that are explained. kwargs: Unused. - args: - Unused. Returns ------- diff --git a/quantus/metrics/robustness/local_lipschitz_estimate.py b/quantus/metrics/robustness/local_lipschitz_estimate.py index 1f782892b..baa2e5622 100644 --- a/quantus/metrics/robustness/local_lipschitz_estimate.py +++ b/quantus/metrics/robustness/local_lipschitz_estimate.py @@ -11,7 +11,6 @@ import numpy as np -from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import gaussian_noise, perturb_batch from quantus.functions.similarity_func import distance_euclidean, lipschitz_constant from quantus.helpers import asserts, warn @@ -75,12 +74,12 @@ def __init__( normalise: bool = True, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = gaussian_noise, + perturb_func: Optional[Callable] = None, perturb_mean: float = 0.0, perturb_std: float = 0.1, perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -134,9 +133,6 @@ def __init__( kwargs: optional Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max - super().__init__( abs=abs, normalise=normalise, @@ -150,6 +146,9 @@ def __init__( **kwargs, ) + if perturb_func is None: + perturb_func = gaussian_noise + # Save metric-specific attributes. self.nr_samples = nr_samples @@ -199,8 +198,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -210,7 +209,6 @@ def __call__( softmax: Optional[bool] = True, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -309,7 +307,6 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - *args, **kwargs, ) -> np.ndarray: """ @@ -325,8 +322,6 @@ def evaluate_batch( The output to be evaluated on a batch-basis. a_batch: np.ndarray The explanation to be evaluated on a batch-basis. - args: - Unused. kwargs: Unused. @@ -380,30 +375,15 @@ def evaluate_batch( def custom_preprocess( self, - model: ModelInterface, - x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: np.ndarray, - custom_batch: Optional[np.ndarray], + **kwargs, ) -> None: """ Implementation of custom_preprocess_batch. Parameters ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. - x_batch: np.ndarray - A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - custom_batch: any - Gives flexibility ot the user to use for evaluation, can hold any variable. + kwargs: + Unused. Returns ------- diff --git a/quantus/metrics/robustness/max_sensitivity.py b/quantus/metrics/robustness/max_sensitivity.py index 9903ca5e3..5836175ac 100644 --- a/quantus/metrics/robustness/max_sensitivity.py +++ b/quantus/metrics/robustness/max_sensitivity.py @@ -10,9 +10,7 @@ from typing import Any, Callable, Dict, List, Optional import numpy as np - from quantus.functions import norm_func -from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import perturb_batch, uniform_noise from quantus.functions.similarity_func import difference from quantus.helpers import asserts, warn @@ -73,12 +71,12 @@ def __init__( normalise: bool = False, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Callable = uniform_noise, + perturb_func: Optional[Callable] = None, lower_bound: float = 0.2, upper_bound: Optional[float] = None, perturb_func_kwargs: Optional[Dict[str, Any]] = None, return_aggregate: bool = False, - aggregate_func: Callable = np.mean, + aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: bool = False, display_progressbar: bool = False, @@ -130,9 +128,6 @@ def __init__( kwargs: optional Keyword arguments. """ - if normalise_func is None: - normalise_func = normalise_by_max - super().__init__( abs=abs, normalise=normalise, @@ -146,6 +141,9 @@ def __init__( **kwargs, ) + if perturb_func is None: + perturb_func = uniform_noise + # Save metric-specific attributes. self.nr_samples = nr_samples @@ -193,8 +191,8 @@ def __init__( def __call__( self, model, - x_batch: np.array, - y_batch: np.array, + x_batch: np.ndarray, + y_batch: np.ndarray, a_batch: Optional[np.ndarray] = None, s_batch: Optional[np.ndarray] = None, channel_first: Optional[bool] = None, @@ -204,7 +202,6 @@ def __call__( softmax: Optional[bool] = False, device: Optional[str] = None, batch_size: int = 64, - custom_batch: Optional[Any] = None, **kwargs, ) -> List[float]: """ @@ -303,7 +300,6 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - *args, **kwargs, ) -> np.ndarray: """ @@ -319,8 +315,6 @@ def evaluate_batch( The output to be evaluated on an instance-basis. a_batch: np.ndarray The explanation to be evaluated on an instance-basis. - args: - Unused. kwargs: Unused. @@ -373,30 +367,15 @@ def evaluate_batch( def custom_preprocess( self, - model: ModelInterface, - x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: np.ndarray, - custom_batch: Optional[np.ndarray], + **kwargs, ) -> None: """ Implementation of custom_preprocess_batch. Parameters ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. - x_batch: np.ndarray - A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - custom_batch: any - Gives flexibility ot the user to use for evaluation, can hold any variable. + kwargs: + Unused. Returns ------- diff --git a/quantus/metrics/robustness/relative_input_stability.py b/quantus/metrics/robustness/relative_input_stability.py index 533c3545d..1ef145947 100644 --- a/quantus/metrics/robustness/relative_input_stability.py +++ b/quantus/metrics/robustness/relative_input_stability.py @@ -71,10 +71,10 @@ def __init__( normalise: bool = False, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, ...]] = None, - perturb_func: Callable = uniform_noise, + perturb_func: Optional[Callable] = None, perturb_func_kwargs: Optional[Dict[str, ...]] = None, return_aggregate: bool = False, - aggregate_func: Optional[Callable[[np.ndarray], np.float]] = np.mean, + aggregate_func: Optional[Callable[[np.ndarray], np.float]] = None, disable_warnings: bool = False, display_progressbar: bool = False, eps_min: float = 1e-6, @@ -130,6 +130,10 @@ def __init__( disable_warnings=disable_warnings, **kwargs, ) + + if perturb_func is None: + perturb_func = uniform_noise + self._nr_samples = nr_samples self._eps_min = eps_min self.perturb_func = make_perturb_func( @@ -275,7 +279,6 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - *args, **kwargs, ) -> np.ndarray: """ @@ -289,8 +292,6 @@ def evaluate_batch( 1D tensor, representing predicted labels for the x_batch. a_batch: np.ndarray, optional 4D tensor with pre-computed explanations for the x_batch. - args: - Unused. kwargs: Unused. diff --git a/quantus/metrics/robustness/relative_output_stability.py b/quantus/metrics/robustness/relative_output_stability.py index 784edfdbc..7a207a363 100644 --- a/quantus/metrics/robustness/relative_output_stability.py +++ b/quantus/metrics/robustness/relative_output_stability.py @@ -73,10 +73,10 @@ def __init__( normalise: bool = False, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, ...]] = None, - perturb_func: Callable = uniform_noise, + perturb_func: Optional[Callable] = None, perturb_func_kwargs: Optional[Dict[str, ...]] = None, return_aggregate: bool = False, - aggregate_func: Optional[Callable[[np.ndarray], np.float]] = np.mean, + aggregate_func: Optional[Callable[[np.ndarray], np.float]] = None, disable_warnings: bool = False, display_progressbar: bool = False, eps_min: float = 1e-6, @@ -132,6 +132,10 @@ def __init__( disable_warnings=disable_warnings, **kwargs, ) + + if perturb_func is None: + perturb_func = uniform_noise + self._nr_samples = nr_samples self._eps_min = eps_min self.perturb_func = make_perturb_func( @@ -282,7 +286,6 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - *args, **kwargs, ) -> np.ndarray: """ @@ -296,8 +299,6 @@ def evaluate_batch( 1D tensor, representing predicted labels for the x_batch. a_batch: np.ndarray, optional 4D tensor with pre-computed explanations for the x_batch. - args: - Unused. kwargs: Unused. diff --git a/quantus/metrics/robustness/relative_representation_stability.py b/quantus/metrics/robustness/relative_representation_stability.py index ff49aa2ed..a0f45b720 100644 --- a/quantus/metrics/robustness/relative_representation_stability.py +++ b/quantus/metrics/robustness/relative_representation_stability.py @@ -74,10 +74,10 @@ def __init__( normalise: bool = False, normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, normalise_func_kwargs: Optional[Dict[str, ...]] = None, - perturb_func: Callable = uniform_noise, + perturb_func: Optional[Callable] = None, perturb_func_kwargs: Optional[Dict[str, ...]] = None, return_aggregate: bool = False, - aggregate_func: Optional[Callable[[np.ndarray], np.ndarray]] = np.mean, + aggregate_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, disable_warnings: bool = False, display_progressbar: bool = False, eps_min: float = 1e-6, @@ -139,6 +139,8 @@ def __init__( disable_warnings=disable_warnings, **kwargs, ) + if perturb_func is None: + perturb_func = uniform_noise self._nr_samples = nr_samples self._eps_min = eps_min if layer_names is not None and layer_indices is not None: @@ -294,7 +296,6 @@ def evaluate_batch( x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, - *args, **kwargs, ) -> np.ndarray: """ @@ -308,8 +309,6 @@ def evaluate_batch( 1D tensor, representing predicted labels for the x_batch. a_batch: np.ndarray, optional 4D tensor with pre-computed explanations for the x_batch. - args: - Unused. kwargs: Unused. diff --git a/tox.ini b/tox.ini index 745a8d6f9..7df6bd370 100644 --- a/tox.ini +++ b/tox.ini @@ -40,6 +40,7 @@ commands = description = Check the code style deps = flake8 + flake8-bugbear commands = python3 -m flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics python3 -m flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics From fbefe31b989f79a98d11f749869be5485b3238ec Mon Sep 17 00:00:00 2001 From: aaarrti Date: Fri, 3 Nov 2023 20:18:07 +0100 Subject: [PATCH 53/58] * code review comments --- quantus/helpers/model/model_interface.py | 28 ++++++++-- quantus/helpers/model/models.py | 1 - quantus/helpers/model/pytorch_model.py | 21 +++++-- quantus/helpers/model/tf_model.py | 12 +++- quantus/metrics/axiomatic/completeness.py | 5 +- quantus/metrics/axiomatic/input_invariance.py | 3 +- quantus/metrics/axiomatic/non_sensitivity.py | 6 +- quantus/metrics/complexity/complexity.py | 4 +- .../complexity/effective_complexity.py | 3 +- quantus/metrics/complexity/sparseness.py | 3 +- .../faithfulness/faithfulness_correlation.py | 3 +- .../faithfulness/faithfulness_estimate.py | 4 +- quantus/metrics/faithfulness/infidelity.py | 3 +- quantus/metrics/faithfulness/irof.py | 4 +- quantus/metrics/faithfulness/monotonicity.py | 3 +- .../faithfulness/monotonicity_correlation.py | 2 +- .../metrics/faithfulness/pixel_flipping.py | 2 +- .../faithfulness/region_perturbation.py | 2 +- quantus/metrics/faithfulness/road.py | 2 +- quantus/metrics/faithfulness/selectivity.py | 2 +- quantus/metrics/faithfulness/sensitivity_n.py | 3 +- quantus/metrics/faithfulness/sufficiency.py | 7 ++- .../localisation/attribution_localisation.py | 2 +- quantus/metrics/localisation/auc.py | 2 +- quantus/metrics/localisation/focus.py | 2 +- quantus/metrics/localisation/pointing_game.py | 2 +- .../localisation/relevance_mass_accuracy.py | 2 +- .../localisation/relevance_rank_accuracy.py | 2 +- .../localisation/top_k_intersection.py | 2 +- .../model_parameter_randomisation.py | 56 +++++++++---------- quantus/metrics/randomisation/random_logit.py | 3 +- quantus/metrics/robustness/avg_sensitivity.py | 2 +- quantus/metrics/robustness/consistency.py | 12 +++- quantus/metrics/robustness/continuity.py | 5 +- .../robustness/local_lipschitz_estimate.py | 2 +- quantus/metrics/robustness/max_sensitivity.py | 2 +- 36 files changed, 126 insertions(+), 93 deletions(-) diff --git a/quantus/helpers/model/model_interface.py b/quantus/helpers/model/model_interface.py index 4c0a73cf9..e8876e25a 100644 --- a/quantus/helpers/model/model_interface.py +++ b/quantus/helpers/model/model_interface.py @@ -7,17 +7,19 @@ # Quantus project URL: . from abc import ABC, abstractmethod -from typing import Any, Dict, Optional, Tuple, List, Union +from typing import Any, Dict, Optional, Tuple, List, Union, Generator, TypeVar, Generic import numpy as np +M = TypeVar("M") -class ModelInterface(ABC): + +class ModelInterface(ABC, Generic[M]): """Base ModelInterface for torch and tensorflow models.""" def __init__( self, - model, + model: M, channel_first: bool = True, softmax: bool = False, model_predict_kwargs: Optional[Dict[str, Any]] = None, @@ -107,7 +109,9 @@ def state_dict(self): raise NotImplementedError @abstractmethod - def get_random_layer_generator(self): + def get_random_layer_generator( + self, order: str = "top_down", seed: int = 42 + ) -> Generator[Tuple[str, M], None, None]: """ In every iteration yields a copy of the model with one additional layer's parameters randomized. For cascading randomization, set order (str) to 'top_down'. For independent randomization, @@ -171,3 +175,19 @@ def get_hidden_representations( 2D tensor with shape (batch_size, None) """ raise NotImplementedError() + + @abstractmethod + @property + def random_layer_generator_length(self) -> int: + """ + Count number of randomisable layers for `Model Parameter Randomisation`. + This property is needed to avoid `len(model.get_random_layer_generator())`, + because meterializing bigger models `num_layers` times in memory at ones + has shown to cause OOM errors. + + Returns + ------- + n: + Number of layers in model, which can be randomised. + """ + raise NotImplementedError diff --git a/quantus/helpers/model/models.py b/quantus/helpers/model/models.py index 38d0ca004..2feb32c55 100644 --- a/quantus/helpers/model/models.py +++ b/quantus/helpers/model/models.py @@ -12,7 +12,6 @@ # Import different models depending on which deep learning framework is installed. if util.find_spec("torch"): - import torch import torch.nn as nn diff --git a/quantus/helpers/model/pytorch_model.py b/quantus/helpers/model/pytorch_model.py index af383074a..e273fb440 100644 --- a/quantus/helpers/model/pytorch_model.py +++ b/quantus/helpers/model/pytorch_model.py @@ -8,24 +8,25 @@ import copy from contextlib import suppress from copy import deepcopy -from typing import Any, Dict, Optional, Tuple, List, Union +from typing import Any, Dict, Optional, Tuple, List, Union, Generator import warnings import logging import numpy as np import torch +from torch import nn from functools import lru_cache from quantus.helpers import utils from quantus.helpers.model.model_interface import ModelInterface -class PyTorchModel(ModelInterface): +class PyTorchModel(ModelInterface[nn.Module]): """Interface for torch models.""" def __init__( self, - model, + model: nn.Module, channel_first: bool = True, softmax: bool = False, model_predict_kwargs: Optional[Dict[str, Any]] = None, @@ -232,7 +233,9 @@ def state_dict(self) -> dict: """ return self.model.state_dict() - def get_random_layer_generator(self, order: str = "top_down", seed: int = 42): + def get_random_layer_generator( + self, order: str = "top_down", seed: int = 42 + ) -> Generator[Tuple[str, nn.Module], None, None]: """ In every iteration yields a copy of the model with one additional layer's parameters randomized. For cascading randomization, set order (str) to 'top_down'. For independent randomization, @@ -446,3 +449,13 @@ def hook(module, module_in, module_out): # Cleanup. [i.remove() for i in new_hooks] return np.hstack(hidden_outputs) + + @property + def random_layer_generator_length(self) -> int: + return len( + [ + i + for i in self.model.named_modules() + if (hasattr(i[1], "reset_parameters")) + ] + ) diff --git a/quantus/helpers/model/tf_model.py b/quantus/helpers/model/tf_model.py index 229442cc1..0e75098cb 100644 --- a/quantus/helpers/model/tf_model.py +++ b/quantus/helpers/model/tf_model.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Dict, Optional, Tuple, List, Union +from typing import Dict, Optional, Tuple, List, Union, Generator from keras.layers import Dense from keras import activations from keras import Model @@ -23,7 +23,7 @@ from quantus.helpers import utils -class TensorFlowModel(ModelInterface): +class TensorFlowModel(ModelInterface[Model]): """Interface for tensorflow models.""" # All kwargs supported by Keras API https://keras.io/api/models/model_training_apis/. @@ -235,7 +235,9 @@ def load_state_dict(self, original_parameters): """Set model's learnable parameters.""" self.model.set_weights(original_parameters) - def get_random_layer_generator(self, order: str = "top_down", seed: int = 42): + def get_random_layer_generator( + self, order: str = "top_down", seed: int = 42 + ) -> Generator[Tuple[str, Model], None, None]: """ In every iteration yields a copy of the model with one additional layer's parameters randomized. For cascading randomization, set order (str) to 'top_down'. For independent randomization, @@ -418,3 +420,7 @@ def get_hidden_representations( i.reshape((input_batch_size, -1)) for i in internal_representation ] return np.hstack(internal_representation) + + @property + def random_layer_generator_length(self) -> int: + return len([i for i in self.model.layers if len(i.get_weights()) > 0]) diff --git a/quantus/metrics/axiomatic/completeness.py b/quantus/metrics/axiomatic/completeness.py index e59670cde..b831d42ed 100644 --- a/quantus/metrics/axiomatic/completeness.py +++ b/quantus/metrics/axiomatic/completeness.py @@ -286,7 +286,7 @@ def evaluate_instance( Returns ------- - : boolean + score: boolean The evaluation results. """ x_baseline = self.perturb_func( @@ -333,9 +333,8 @@ def evaluate_batch( Returns ------- - scores_batch: - List of booleans. + The evaluation results. """ return [ diff --git a/quantus/metrics/axiomatic/input_invariance.py b/quantus/metrics/axiomatic/input_invariance.py index b2ed9bb3f..b1d1225a1 100644 --- a/quantus/metrics/axiomatic/input_invariance.py +++ b/quantus/metrics/axiomatic/input_invariance.py @@ -267,9 +267,8 @@ def evaluate_batch( Returns ------- - : np.ndarray + scores_batch: np.ndarray The evaluation results. - """ batch_size = x_batch.shape[0] diff --git a/quantus/metrics/axiomatic/non_sensitivity.py b/quantus/metrics/axiomatic/non_sensitivity.py index e3c78f021..63877eda5 100644 --- a/quantus/metrics/axiomatic/non_sensitivity.py +++ b/quantus/metrics/axiomatic/non_sensitivity.py @@ -288,7 +288,7 @@ def evaluate_instance( Returns ------- - integer + integer: The evaluation results. """ a = a.flatten() @@ -373,10 +373,8 @@ def evaluate_batch( Returns ------- - scores_batch: - List of integers. - + The evaluation results. """ return [ diff --git a/quantus/metrics/complexity/complexity.py b/quantus/metrics/complexity/complexity.py index 1a1a1e07a..f858e1257 100644 --- a/quantus/metrics/complexity/complexity.py +++ b/quantus/metrics/complexity/complexity.py @@ -272,9 +272,7 @@ def evaluate_batch( Returns ------- - scores_batch: - List of floats. - + The evaluation results. """ return [self.evaluate_instance(x=x, a=a) for x, a in zip(x_batch, a_batch)] diff --git a/quantus/metrics/complexity/effective_complexity.py b/quantus/metrics/complexity/effective_complexity.py index e37daa30f..3d8157423 100644 --- a/quantus/metrics/complexity/effective_complexity.py +++ b/quantus/metrics/complexity/effective_complexity.py @@ -263,8 +263,7 @@ def evaluate_batch(self, a_batch: np.ndarray, **kwargs) -> List[int]: Returns ------- - scores_batch: - List of integers. + The evaluation results. """ return [self.evaluate_instance(a) for a in a_batch] diff --git a/quantus/metrics/complexity/sparseness.py b/quantus/metrics/complexity/sparseness.py index 345d69537..ea2d43ade 100644 --- a/quantus/metrics/complexity/sparseness.py +++ b/quantus/metrics/complexity/sparseness.py @@ -281,8 +281,7 @@ def evaluate_batch( Returns ------- - scores_batch: - List of floats. + The evaluation results. """ return [self.evaluate_instance(x=x, a=a) for x, a in zip(x_batch, a_batch)] diff --git a/quantus/metrics/faithfulness/faithfulness_correlation.py b/quantus/metrics/faithfulness/faithfulness_correlation.py index f444f4d33..b0ffe5a78 100644 --- a/quantus/metrics/faithfulness/faithfulness_correlation.py +++ b/quantus/metrics/faithfulness/faithfulness_correlation.py @@ -384,9 +384,8 @@ def evaluate_batch( Returns ------- - scores_batch: - List of floats. + The evaluation results. """ return [ self.evaluate_instance(model=model, x=x, y=y, a=a) diff --git a/quantus/metrics/faithfulness/faithfulness_estimate.py b/quantus/metrics/faithfulness/faithfulness_estimate.py index a6cce708c..282579d56 100644 --- a/quantus/metrics/faithfulness/faithfulness_estimate.py +++ b/quantus/metrics/faithfulness/faithfulness_estimate.py @@ -373,11 +373,9 @@ def evaluate_batch( Returns ------- - scores_batch: - List of floats. + The evaluation results. """ - return [ self.evaluate_instance(model=model, x=x, y=y, a=a) for x, y, a in zip(x_batch, y_batch, a_batch) diff --git a/quantus/metrics/faithfulness/infidelity.py b/quantus/metrics/faithfulness/infidelity.py index 2f1a20124..37717266b 100644 --- a/quantus/metrics/faithfulness/infidelity.py +++ b/quantus/metrics/faithfulness/infidelity.py @@ -426,9 +426,8 @@ def evaluate_batch( Returns ------- - scores_batch: - List of floats. + The evaluation results. """ return [ diff --git a/quantus/metrics/faithfulness/irof.py b/quantus/metrics/faithfulness/irof.py index 1959cf15c..61df3bf52 100644 --- a/quantus/metrics/faithfulness/irof.py +++ b/quantus/metrics/faithfulness/irof.py @@ -394,10 +394,8 @@ def evaluate_batch( Returns ------- - scores_batch: - List of floats. - + The evaluation results. """ return [ self.evaluate_instance(model=model, x=x, y=y, a=a) diff --git a/quantus/metrics/faithfulness/monotonicity.py b/quantus/metrics/faithfulness/monotonicity.py index 2135e19b9..76c13d9e2 100644 --- a/quantus/metrics/faithfulness/monotonicity.py +++ b/quantus/metrics/faithfulness/monotonicity.py @@ -368,9 +368,8 @@ def evaluate_batch( Returns ------- - scores_batch: - List of floats. + The evaluation results. """ return [ self.evaluate_instance(model=model, x=x, y=y, a=a) diff --git a/quantus/metrics/faithfulness/monotonicity_correlation.py b/quantus/metrics/faithfulness/monotonicity_correlation.py index e01fabd72..a1d48aa25 100644 --- a/quantus/metrics/faithfulness/monotonicity_correlation.py +++ b/quantus/metrics/faithfulness/monotonicity_correlation.py @@ -399,7 +399,7 @@ def evaluate_batch( Returns ------- - List[float] + scores_batch: The evaluation results. """ diff --git a/quantus/metrics/faithfulness/pixel_flipping.py b/quantus/metrics/faithfulness/pixel_flipping.py index 826cb53cb..224f05308 100644 --- a/quantus/metrics/faithfulness/pixel_flipping.py +++ b/quantus/metrics/faithfulness/pixel_flipping.py @@ -374,7 +374,7 @@ def evaluate_batch( Returns ------- - list + scores_batch: The evaluation results. """ return [ diff --git a/quantus/metrics/faithfulness/region_perturbation.py b/quantus/metrics/faithfulness/region_perturbation.py index db203dfc9..560283d6c 100644 --- a/quantus/metrics/faithfulness/region_perturbation.py +++ b/quantus/metrics/faithfulness/region_perturbation.py @@ -444,7 +444,7 @@ def evaluate_batch( Returns ------- - list + scores_batch: The evaluation results. """ return [ diff --git a/quantus/metrics/faithfulness/road.py b/quantus/metrics/faithfulness/road.py index 105027e10..08261b254 100644 --- a/quantus/metrics/faithfulness/road.py +++ b/quantus/metrics/faithfulness/road.py @@ -362,7 +362,7 @@ def evaluate_batch( Returns ------- - list + scores_batch: The evaluation results. """ return [ diff --git a/quantus/metrics/faithfulness/selectivity.py b/quantus/metrics/faithfulness/selectivity.py index 4b5739add..0281679bd 100644 --- a/quantus/metrics/faithfulness/selectivity.py +++ b/quantus/metrics/faithfulness/selectivity.py @@ -405,7 +405,7 @@ def evaluate_batch( Returns ------- - list + scores_batch: The evaluation results. """ return [ diff --git a/quantus/metrics/faithfulness/sensitivity_n.py b/quantus/metrics/faithfulness/sensitivity_n.py index 820c60425..48ec7430a 100644 --- a/quantus/metrics/faithfulness/sensitivity_n.py +++ b/quantus/metrics/faithfulness/sensitivity_n.py @@ -439,10 +439,9 @@ def evaluate_batch( Returns ------- - list + scores_batch: The evaluation results. """ - return [ self.evaluate_instance(model=model, x=x, y=y, a=a) for x, y, a in zip(x_batch, y_batch, a_batch) diff --git a/quantus/metrics/faithfulness/sufficiency.py b/quantus/metrics/faithfulness/sufficiency.py index 952037779..f543f5386 100644 --- a/quantus/metrics/faithfulness/sufficiency.py +++ b/quantus/metrics/faithfulness/sufficiency.py @@ -310,7 +310,11 @@ def custom_batch_preprocess( @no_type_check def evaluate_batch( - self, i_batch, a_sim_vector_batch, y_pred_classes, **kwargs + self, + i_batch: np.ndarray, + a_sim_vector_batch: np.ndarray, + y_pred_classes: np.ndarray, + **kwargs, ) -> List[float]: """ This method performs XAI evaluation on a single batch of explanations. @@ -329,7 +333,6 @@ def evaluate_batch( Returns ------- - evaluation_scores: List of measured sufficiency for each entry in the batch. """ diff --git a/quantus/metrics/localisation/attribution_localisation.py b/quantus/metrics/localisation/attribution_localisation.py index 053d7d633..3cece8ad6 100644 --- a/quantus/metrics/localisation/attribution_localisation.py +++ b/quantus/metrics/localisation/attribution_localisation.py @@ -340,7 +340,7 @@ def evaluate_batch( Returns ------- - retval: + scores_batch: Evaluation result for batch. """ return [ diff --git a/quantus/metrics/localisation/auc.py b/quantus/metrics/localisation/auc.py index 55163416c..6ca02f822 100644 --- a/quantus/metrics/localisation/auc.py +++ b/quantus/metrics/localisation/auc.py @@ -294,7 +294,7 @@ def evaluate_batch( Returns ------- - retval: + scores_batch: Evaluation result for batch. """ return [self.evaluate_instance(a=a, s=s) for a, s in zip(a_batch, s_batch)] diff --git a/quantus/metrics/localisation/focus.py b/quantus/metrics/localisation/focus.py index 6382b4679..08f7a3b46 100644 --- a/quantus/metrics/localisation/focus.py +++ b/quantus/metrics/localisation/focus.py @@ -404,7 +404,7 @@ def evaluate_batch( Returns ------- - retval: + score_batch: Evaluation result for batch. """ return [self.evaluate_instance(a=a, c=c) for a, c in zip(a_batch, c_batch)] diff --git a/quantus/metrics/localisation/pointing_game.py b/quantus/metrics/localisation/pointing_game.py index 04a945f12..10f413794 100644 --- a/quantus/metrics/localisation/pointing_game.py +++ b/quantus/metrics/localisation/pointing_game.py @@ -319,7 +319,7 @@ def evaluate_batch( Returns ------- - retval: + scores_batch: Evaluation result for batch. """ return [self.evaluate_instance(a=a, s=s) for a, s in zip(a_batch, s_batch)] diff --git a/quantus/metrics/localisation/relevance_mass_accuracy.py b/quantus/metrics/localisation/relevance_mass_accuracy.py index 532f62596..251d48e6a 100644 --- a/quantus/metrics/localisation/relevance_mass_accuracy.py +++ b/quantus/metrics/localisation/relevance_mass_accuracy.py @@ -307,7 +307,7 @@ def evaluate_batch( Returns ------- - retval: + scores_batch: A list of Any with the evaluation scores for the batch. """ return [self.evaluate_instance(a=a, s=s) for a, s in zip(a_batch, s_batch)] diff --git a/quantus/metrics/localisation/relevance_rank_accuracy.py b/quantus/metrics/localisation/relevance_rank_accuracy.py index 91079068f..3cf4621a2 100644 --- a/quantus/metrics/localisation/relevance_rank_accuracy.py +++ b/quantus/metrics/localisation/relevance_rank_accuracy.py @@ -318,7 +318,7 @@ def evaluate_batch( Returns ------- - retval: + scores_batch: Evaluation result for batch. """ return [self.evaluate_instance(a=a, s=s) for a, s in zip(a_batch, s_batch)] diff --git a/quantus/metrics/localisation/top_k_intersection.py b/quantus/metrics/localisation/top_k_intersection.py index 60a69b599..1d4095443 100644 --- a/quantus/metrics/localisation/top_k_intersection.py +++ b/quantus/metrics/localisation/top_k_intersection.py @@ -324,7 +324,7 @@ def evaluate_batch( Returns ------- - retval: + scores_batch: Evaluation result for batch. """ return [self.evaluate_instance(a, s) for a, s in zip(a_batch, s_batch)] diff --git a/quantus/metrics/randomisation/model_parameter_randomisation.py b/quantus/metrics/randomisation/model_parameter_randomisation.py index 65b3b848e..0864adb95 100644 --- a/quantus/metrics/randomisation/model_parameter_randomisation.py +++ b/quantus/metrics/randomisation/model_parameter_randomisation.py @@ -275,7 +275,7 @@ def __call__( self.evaluation_scores = {} # Get number of iterations from number of layers. - n_layers = len(list(model.get_random_layer_generator(order=self.layer_order))) + n_layers = model.random_layer_generator_length pbar = tqdm( total=n_layers * len(x_full_dataset), disable=not self.display_progressbar ) @@ -327,6 +327,33 @@ def generate_y_batches(): return self.evaluation_scores + def compute_correlation_per_sample( + self, + ) -> Union[List[List[Any]], Dict[int, List[Any]]]: + assert isinstance(self.evaluation_scores, dict), ( + "To compute the average correlation coefficient per sample for " + "Model Parameter Randomisation Test, 'last_result' " + "must be of type dict." + ) + layer_length = len( + self.evaluation_scores[list(self.evaluation_scores.keys())[0]] + ) + results: Dict[int, list] = {sample: [] for sample in range(layer_length)} + + for sample in results: + for layer in self.evaluation_scores: + results[sample].append(float(self.evaluation_scores[layer][sample])) + results[sample] = np.mean(results[sample]) + + corr_coeffs = list(results.values()) + + return corr_coeffs + + def evaluate_batch(self, *args, **kwargs): + raise RuntimeError( + "`evaluate_batch` must never be called for `ModelParameterRandomisation`." + ) + def custom_preprocess( self, model: ModelInterface, @@ -365,28 +392,6 @@ def custom_preprocess( a_batch_chunks.extend(a_chunk) return dict(a_batch=np.asarray(a_batch_chunks)) - def compute_correlation_per_sample( - self, - ) -> Union[List[List[Any]], Dict[int, List[Any]]]: - assert isinstance(self.evaluation_scores, dict), ( - "To compute the average correlation coefficient per sample for " - "Model Parameter Randomisation Test, 'last_result' " - "must be of type dict." - ) - layer_length = len( - self.evaluation_scores[list(self.evaluation_scores.keys())[0]] - ) - results: Dict[int, list] = {sample: [] for sample in range(layer_length)} - - for sample in results: - for layer in self.evaluation_scores: - results[sample].append(float(self.evaluation_scores[layer][sample])) - results[sample] = np.mean(results[sample]) - - corr_coeffs = list(results.values()) - - return corr_coeffs - def generate_explanations( self, model: ModelInterface, @@ -400,8 +405,3 @@ def generate_explanations( y = y_batch[i.start : i.stop] a = self.explain_batch(model, x, y) yield a - - def evaluate_batch(self, *args, **kwargs): - raise RuntimeError( - "`evaluate_batch` must never be called for `ModelParameterRandomisation`." - ) diff --git a/quantus/metrics/randomisation/random_logit.py b/quantus/metrics/randomisation/random_logit.py index 7a7985d91..b5e9c9e0d 100644 --- a/quantus/metrics/randomisation/random_logit.py +++ b/quantus/metrics/randomisation/random_logit.py @@ -325,7 +325,8 @@ def evaluate_batch( Returns ------- - + scores_batch: + Evaluation results. """ return [ self.evaluate_instance(model, x, y, a) diff --git a/quantus/metrics/robustness/avg_sensitivity.py b/quantus/metrics/robustness/avg_sensitivity.py index e1065f4b8..5d4b6195a 100644 --- a/quantus/metrics/robustness/avg_sensitivity.py +++ b/quantus/metrics/robustness/avg_sensitivity.py @@ -325,7 +325,7 @@ def evaluate_batch( Returns ------- - : np.ndarray + scores_batch: np.ndarray The batched evaluation results. """ batch_size = x_batch.shape[0] diff --git a/quantus/metrics/robustness/consistency.py b/quantus/metrics/robustness/consistency.py index 3b721307e..93240fd6e 100644 --- a/quantus/metrics/robustness/consistency.py +++ b/quantus/metrics/robustness/consistency.py @@ -302,7 +302,7 @@ def evaluate_batch( a_batch: np.ndarray, i_batch: np.ndarray, a_label_batch: np.ndarray, - y_pred_classes, + y_pred_classes: np.ndarray, **kwargs, ) -> List[float]: """ @@ -312,14 +312,20 @@ def evaluate_batch( Parameters ---------- a_batch: + Batch of explanation to be evaluated. i_batch: + Batch of segmentations to be evaluated. a_label_batch: + Batch of discretised attribution labels. y_pred_classes: - kwargs + The class predictions of the complete input dataset. + kwargs: + Unused. Returns ------- - + scores_batch: + Evaluation results. """ return [ diff --git a/quantus/metrics/robustness/continuity.py b/quantus/metrics/robustness/continuity.py index 55147f7dd..14e8fc1f1 100644 --- a/quantus/metrics/robustness/continuity.py +++ b/quantus/metrics/robustness/continuity.py @@ -415,7 +415,7 @@ def aggregated_score(self): self.similarity_func( self.evaluation_scores[sample][self.nr_patches], self.evaluation_scores[sample][ix_patch], - ) + ) # noqa for ix_patch in range(self.nr_patches) for sample in self.evaluation_scores.keys() ] @@ -445,7 +445,8 @@ def evaluate_batch( Returns ------- - + scores_batch: + Evaluation results. """ return [ self.evaluate_instance(model=model, x=x, y=y) diff --git a/quantus/metrics/robustness/local_lipschitz_estimate.py b/quantus/metrics/robustness/local_lipschitz_estimate.py index baa2e5622..737b93bfa 100644 --- a/quantus/metrics/robustness/local_lipschitz_estimate.py +++ b/quantus/metrics/robustness/local_lipschitz_estimate.py @@ -327,7 +327,7 @@ def evaluate_batch( Returns ------- - : np.ndarray + scores_batch: np.ndarray The batched evaluation results. """ diff --git a/quantus/metrics/robustness/max_sensitivity.py b/quantus/metrics/robustness/max_sensitivity.py index 5836175ac..8ab236386 100644 --- a/quantus/metrics/robustness/max_sensitivity.py +++ b/quantus/metrics/robustness/max_sensitivity.py @@ -320,7 +320,7 @@ def evaluate_batch( Returns ------- - : np.ndarray + scores_batch: np.ndarray The batched evaluation results. """ batch_size = x_batch.shape[0] From 27e908b1f0146870ff4b8771110bf154e1dc7f94 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Fri, 3 Nov 2023 20:19:49 +0100 Subject: [PATCH 54/58] * code review comments --- quantus/helpers/model/model_interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/quantus/helpers/model/model_interface.py b/quantus/helpers/model/model_interface.py index e8876e25a..3d3fc7605 100644 --- a/quantus/helpers/model/model_interface.py +++ b/quantus/helpers/model/model_interface.py @@ -176,8 +176,8 @@ def get_hidden_representations( """ raise NotImplementedError() - @abstractmethod @property + @abstractmethod def random_layer_generator_length(self) -> int: """ Count number of randomisable layers for `Model Parameter Randomisation`. From 8284075109db91032b8135a2c79f7f0263bbc276 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Fri, 3 Nov 2023 20:59:50 +0100 Subject: [PATCH 55/58] * code review comments --- quantus/helpers/model/pytorch_model.py | 28 +++-- quantus/helpers/model/tf_model.py | 1 - quantus/metrics/base.py | 128 +++++++++++++------- quantus/metrics/faithfulness/road.py | 4 +- quantus/metrics/faithfulness/sufficiency.py | 6 +- quantus/metrics/robustness/consistency.py | 8 +- tox.ini | 1 - 7 files changed, 105 insertions(+), 71 deletions(-) diff --git a/quantus/helpers/model/pytorch_model.py b/quantus/helpers/model/pytorch_model.py index e273fb440..60f1c01ad 100644 --- a/quantus/helpers/model/pytorch_model.py +++ b/quantus/helpers/model/pytorch_model.py @@ -87,8 +87,11 @@ def _get_model_with_linear_top(self) -> torch.nn: if isinstance(named_module[1], torch.nn.Softmax): setattr(linear_model, named_module[0], torch.nn.Identity()) - logging.info("Argument softmax=False passed, but the passed model contains a module of type " - "torch.nn.Softmax. Module {} has been replaced with torch.nn.Identity().", named_module[0]) + logging.info( + "Argument softmax=False passed, but the passed model contains a module of type " + "torch.nn.Softmax. Module {} has been replaced with torch.nn.Identity().", + named_module[0], + ) break return linear_model @@ -119,8 +122,10 @@ def get_softmax_arg_model(self) -> torch.nn: return self.model # Case 1 if self.softmax and not last_softmax: - logging.info("Argument softmax=True passed, but the passed model contains no module of type " - "torch.nn.Softmax. torch.nn.Softmax module is added as the output layer.") + logging.info( + "Argument softmax=True passed, but the passed model contains no module of type " + "torch.nn.Softmax. torch.nn.Softmax module is added as the output layer." + ) return torch.nn.Sequential(self.model, torch.nn.Softmax(dim=-1)) # Case 3 if not self.softmax and not last_softmax: @@ -134,12 +139,14 @@ def get_softmax_arg_model(self) -> torch.nn: ) # Warning for cases 2, 4, 5 if self.softmax and last_softmax != -1: - logging.info("Argument softmax=True passed. The passed model contains a module of type " - "torch.nn.Softmax, but it is not the last in the list of model's children (" - "self.model.modules()). torch.nn.Softmax module is added as the output layer." - "Make sure that the torch.nn.Softmax layer is the last module in the list " - "of model's children (self.model.modules()) if and only if it is the actual last module " - "applied before output.") + logging.info( + "Argument softmax=True passed. The passed model contains a module of type " + "torch.nn.Softmax, but it is not the last in the list of model's children (" + "self.model.modules()). torch.nn.Softmax module is added as the output layer." + "Make sure that the torch.nn.Softmax layer is the last module in the list " + "of model's children (self.model.modules()) if and only if it is the actual last module " + "applied before output." + ) return torch.nn.Sequential(self.model, torch.nn.Softmax(dim=-1)) # Case 2 @@ -366,7 +373,6 @@ def get_hidden_representations( layer_names: Optional[List[str]] = None, layer_indices: Optional[List[int]] = None, ) -> np.ndarray: - """ Compute the model's internal representation of input x. In practice, this means, executing a forward pass and then, capturing the output of layers (of interest). diff --git a/quantus/helpers/model/tf_model.py b/quantus/helpers/model/tf_model.py index 0e75098cb..22a7a8c85 100644 --- a/quantus/helpers/model/tf_model.py +++ b/quantus/helpers/model/tf_model.py @@ -363,7 +363,6 @@ def get_hidden_representations( layer_indices: Optional[List[int]] = None, **kwargs, ) -> np.ndarray: - """ Compute the model's internal representation of input x. In practice, this means, executing a forward pass and then, capturing the output of layers (of interest). diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index 112adde79..14861af8c 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -23,6 +23,9 @@ Set, TypeVar, Optional, + Union, + TYPE_CHECKING, + TypedDict, ) import matplotlib.pyplot as plt @@ -45,6 +48,10 @@ else: from typing_extensions import final +if TYPE_CHECKING: + import keras + from torch import nn + D = TypeVar("D", bound=Dict[str, Any]) log = logging.getLogger(__name__) @@ -156,9 +163,9 @@ def __init__( def __call__( self, - model, + model: Union[keras.Model, nn.Module, None], x_batch: np.ndarray, - y_batch: Optional[np.ndarray], + y_batch: np.ndarray, a_batch: Optional[np.ndarray], s_batch: Optional[np.ndarray], channel_first: Optional[bool], @@ -344,9 +351,9 @@ def evaluate_batch( @final def general_preprocess( self, - model, + model: Union[keras.Model, nn.Module, None], x_batch: np.ndarray, - y_batch: Optional[np.ndarray], + y_batch: np.ndarray, a_batch: Optional[np.ndarray], s_batch: Optional[np.ndarray], channel_first: Optional[bool], @@ -433,7 +440,6 @@ def general_preprocess( self.explain_func_kwargs["device"] = device if a_batch is not None: - # If no explanations provided, we compute them ob batch-level to avoid OOM. a_batch = utils.expand_attribution_channel(a_batch, x_batch) asserts.assert_attributions(x_batch=x_batch, a_batch=a_batch) self.a_axes = utils.infer_attribution_axes(a_batch, x_batch) @@ -447,6 +453,7 @@ def general_preprocess( a_batch = np.abs(a_batch) else: + # If no explanations provided, we will compute them on batch-level to avoid OOM. asserts.assert_explain_func(explain_func=self.explain_func) # Initialize data dictionary. @@ -475,7 +482,7 @@ def custom_preprocess( self, model: ModelInterface, x_batch: np.ndarray, - y_batch: Optional[np.ndarray], + y_batch: np.ndarray, a_batch: Optional[np.ndarray], s_batch: Optional[np.ndarray], custom_batch: Any, @@ -523,10 +530,10 @@ def custom_preprocess( >>> self, >>> model: ModelInterface, >>> x_batch: np.ndarray, - >>> y_batch: np.ndarray | None, - >>> a_batch: np.ndarray | None, + >>> y_batch: Optional[np.ndarray], + >>> a_batch: Optional[np.ndarray], >>> s_batch: np.ndarray, - >>> custom_batch: np.ndarray | None, + >>> custom_batch: Optional[np.ndarray], >>> ) -> Dict[str, Any]: >>> return {'my_new_variable': np.mean(x_batch)} >>> @@ -534,8 +541,8 @@ def custom_preprocess( >>> self, >>> model: ModelInterface, >>> x: np.ndarray, - >>> y: np.ndarray | None, - >>> a: np.ndarray | None, + >>> y: Optional[np.ndarray], + >>> a: Optional[np.ndarray], >>> s: np.ndarray, >>> my_new_variable: np.float, >>> ) -> float: @@ -545,10 +552,10 @@ def custom_preprocess( >>> self, >>> model: ModelInterface, >>> x_batch: np.ndarray, - >>> y_batch: np.ndarray | None, - >>> a_batch: np.ndarray | None, + >>> y_batch: Optional[np.ndarray], + >>> a_batch: Optional[np.ndarray], >>> s_batch: np.ndarray, - >>> custom_batch: np.ndarray | None, + >>> custom_batch: Optional[np.ndarray], >>> ) -> Dict[str, Any]: >>> return {'my_new_variable_batch': np.arange(len(x_batch))} >>> @@ -556,8 +563,8 @@ def custom_preprocess( >>> self, >>> model: ModelInterface, >>> x: np.ndarray, - >>> y: np.ndarray | None, - >>> a: np.ndarray | None, + >>> y: Optional[np.ndarray], + >>> a: Optional[np.ndarray], >>> s: np.ndarray, >>> my_new_variable: np.int, >>> ) -> float: @@ -568,10 +575,10 @@ def custom_preprocess( >>> self, >>> model: ModelInterface, >>> x_batch: np.ndarray, - >>> y_batch: np.ndarray | None, - >>> a_batch: np.ndarray | None, + >>> y_batch: Optional[np.ndarray], + >>> a_batch: Optional[np.ndarray], >>> s_batch: np.ndarray, - >>> custom_batch: np.ndarray | None, + >>> custom_batch: Optional[np.ndarray], >>> ) -> Dict[str, Any]: >>> return {'x_batch': x_batch - np.mean(x_batch, axis=0)} >>> @@ -579,8 +586,8 @@ def custom_preprocess( >>> self, >>> model: ModelInterface, >>> x: np.ndarray, - >>> y: np.ndarray | None, - >>> a: np.ndarray | None, + >>> y: Optional[np.ndarray], + >>> a: Optional[np.ndarray], >>> s: np.ndarray, >>> ) -> float: @@ -590,10 +597,10 @@ def custom_preprocess( >>> self, >>> model: ModelInterface, >>> x_batch: np.ndarray, - >>> y_batch: np.ndarray | None, - >>> a_batch: np.ndarray | None, + >>> y_batch: Optional[np.ndarray], + >>> a_batch: Optional[np.ndarray], >>> s_batch: np.ndarray, - >>> custom_batch: np.ndarray | None, + >>> custom_batch: Optional[np.ndarray], >>> ) -> None: >>> if np.any(np.all(a_batch < 0, axis=0)): >>> raise ValueError("Attributions must not be all negative") @@ -606,8 +613,8 @@ def custom_preprocess( >>> self, >>> model: ModelInterface, >>> x: np.ndarray, - >>> y: np.ndarray | None, - >>> a: np.ndarray | None, + >>> y: Optional[np.ndarray], + >>> a: Optional[np.ndarray], >>> s: np.ndarray, >>> ) -> float: @@ -644,7 +651,7 @@ def custom_postprocess( Returns ------- - any + any: Can be implemented, optionally by the child class. """ pass @@ -680,7 +687,7 @@ def generate_batches( Returns ------- - iterator + iterator: Each iterator output element is a keyword argument dictionary (string keys). """ @@ -769,9 +776,7 @@ def plot( plt.savefig(fname=path_to_save, dpi=400) def interpret_scores(self): - """ - Get an interpretation of the scores. - """ + """Get an interpretation of the scores.""" print(self.__init__.__doc__.split(".")[1].split("References")[0]) @property @@ -781,7 +786,7 @@ def get_params(self) -> Dict[str, Any]: Returns ------- - dict + dict: A dictionary with attributes if not excluded from pre-determined list. """ attr_exclude = [ @@ -799,6 +804,16 @@ def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: If `data_batch` has no `a_batch`, will compute explanations. This needs to be done on batch level to avoid OOM. Additionally will set `a_axes` property if it is None, this can be done earliest after we have first `a_batch`. + + Parameters + ---------- + data_batch: + A single entry yielded from the generator return by `self.generate_batches(...)` + + Returns + ------- + data_batch: + Dictionary, which is ready to be passed down to `self.evaluate_batch`. """ x_batch = data_batch["x_batch"] @@ -815,13 +830,18 @@ def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: if self.a_axes is None: self.a_axes = utils.infer_attribution_axes(a_batch, x_batch) - custom_batch = self.custom_batch_preprocess(data_batch) + custom_batch = self.custom_batch_preprocess(**data_batch) if custom_batch is not None: data_batch.update(custom_batch) return data_batch def custom_batch_preprocess( - self, data_batch: Dict[str, Any] + self, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: np.ndarray, + **kwargs, ) -> Optional[Dict[str, Any]]: """ Implement this method if you need custom preprocessing of data @@ -830,18 +850,28 @@ def custom_batch_preprocess( Parameters ---------- - data_batch + model: + A model that is subject to explanation. + x_batch: + A np.ndarray which contains the input data that are explained. + y_batch: + A np.ndarray which contains the output labels that are explained. + a_batch: + A np.ndarray which contains pre-computed attributions i.e., explanations. + kwargs: + Optional, metric-specific parameters. Returns ------- - + dict: + Optional dictionary with additional kwargs, which will be passed to `self.evaluate_batch(...)` """ pass @final def explain_batch( self, - model: ModelInterface, + model: Union[ModelInterface, keras.Model, nn.Module], x_batch: np.ndarray, y_batch: np.ndarray, ) -> np.ndarray: @@ -849,22 +879,28 @@ def explain_batch( Compute explanations, normalize and take absolute (if was configured so during metric initialization.) This method should primarily be used if you need to generate additional explanation in metrics body. It encapsulates typical for Quantus pre- and postprocessing approach. + It will do few things: + - call model.shape_input (if ModelInterface instance was provided) + - unwrap model (if ModelInterface instance was provided) + - call explain_func + - expand attribution channel + - (optionally) normalize a_batch + - (optionally) take np.abs of a_batch + Parameters ------- - model: + A model that is subject to explanation. x_batch: + A np.ndarray which contains the input data that are explained. y_batch: + A np.ndarray which contains the output labels that are explained. - - It will do few things: - - call model.shape_input - - unwrap model - - call explain_func - - expand attribution channel - - (optionally) normalize a_batch - - (optionally) take np.abs of a_batch + Returns + ------- + a_batch: + Batch of explanations ready to be evaluated. """ if isinstance(model, ModelInterface): diff --git a/quantus/metrics/faithfulness/road.py b/quantus/metrics/faithfulness/road.py index 08261b254..7e4da4106 100644 --- a/quantus/metrics/faithfulness/road.py +++ b/quantus/metrics/faithfulness/road.py @@ -307,10 +307,10 @@ def evaluate_instance( # Return list of booleans for each percentage. return results_instance - def custom_batch_preprocess(self, data_batch: Dict[str, Any]) -> None: + def custom_batch_preprocess(self, a_batch: np.ndarray, **kwargs) -> None: """ROAD requires `a_size` property to be set to `image_height` * `image_width` of an explanation.""" if self.a_size is None: - self.a_size = data_batch["a_batch"][0, :, :].size + self.a_size = a_batch[0, :, :].size def custom_postprocess( self, diff --git a/quantus/metrics/faithfulness/sufficiency.py b/quantus/metrics/faithfulness/sufficiency.py index f543f5386..c931bb82e 100644 --- a/quantus/metrics/faithfulness/sufficiency.py +++ b/quantus/metrics/faithfulness/sufficiency.py @@ -12,6 +12,7 @@ import numpy as np from scipy.spatial.distance import cdist +from quantus.helpers.model.model_interface import ModelInterface from quantus.helpers import warn from quantus.helpers.enums import ( DataType, @@ -283,12 +284,9 @@ def evaluate_instance( return np.sum(pred_low_dist_a == pred_a) / len(low_dist_a) def custom_batch_preprocess( - self, data_batch: Dict[str, Any] + self, model: ModelInterface, x_batch: np.ndarray, a_batch: np.ndarray, **kwargs ) -> Dict[str, np.ndarray]: """Compute additional arguments required for Sufficiency evaluation on batch-level.""" - model = data_batch["model"] - a_batch = data_batch["a_batch"] - x_batch = data_batch["x_batch"] a_batch_flat = a_batch.reshape(a_batch.shape[0], -1) dist_matrix = cdist(a_batch_flat, a_batch_flat, self.distance_func, V=None) diff --git a/quantus/metrics/robustness/consistency.py b/quantus/metrics/robustness/consistency.py index 93240fd6e..f67c8638c 100644 --- a/quantus/metrics/robustness/consistency.py +++ b/quantus/metrics/robustness/consistency.py @@ -12,7 +12,7 @@ import numpy as np from quantus.functions.discretise_func import top_n_sign -from quantus.functions.normalise_func import normalise_by_max +from quantus.helpers.model.model_interface import ModelInterface from quantus.helpers import warn from quantus.helpers.enums import ( DataType, @@ -274,16 +274,12 @@ def evaluate_instance( return np.sum(pred_same_a == pred_a) / len(diff_a) def custom_batch_preprocess( - self, data_batch: Dict[str, Any] + self, model: ModelInterface, x_batch: np.ndarray, a_batch: np.ndarray, **kwargs ) -> Dict[str, np.ndarray]: """Compute additional arguments required for Consistency on batch-level.""" - model = data_batch["model"] - x_batch = data_batch["x_batch"] x_input = model.shape_input( x_batch, x_batch[0].shape, channel_first=True, batched=True ) - - a_batch = data_batch["a_batch"] a_batch_flat = a_batch.reshape(a_batch.shape[0], -1) a_labels = np.array(list(map(self.discretise_func, a_batch_flat))) diff --git a/tox.ini b/tox.ini index 7df6bd370..745a8d6f9 100644 --- a/tox.ini +++ b/tox.ini @@ -40,7 +40,6 @@ commands = description = Check the code style deps = flake8 - flake8-bugbear commands = python3 -m flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics python3 -m flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics From 669c8ffcc7b6bdee3d29f7cd3d406297d8c3d287 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Fri, 3 Nov 2023 21:26:45 +0100 Subject: [PATCH 56/58] * remove TypedDict --- quantus/metrics/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index 14861af8c..1ad7852c4 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -25,7 +25,6 @@ Optional, Union, TYPE_CHECKING, - TypedDict, ) import matplotlib.pyplot as plt From 0f535f07b0680013d7b5527a967991fc6ba4fa59 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Fri, 3 Nov 2023 21:27:48 +0100 Subject: [PATCH 57/58] * --- quantus/metrics/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index 1ad7852c4..e5817c610 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -321,7 +321,7 @@ def evaluate_batch( a_batch: np.ndarray, s_batch: Optional[np.ndarray], **kwargs, - ): + ) -> R: """ Evaluates model and attributes on a single data batch and returns the batched evaluation result. From a268d105cc7940018737bab1c8fbd5438748737c Mon Sep 17 00:00:00 2001 From: aaarrti Date: Sat, 4 Nov 2023 00:39:45 +0100 Subject: [PATCH 58/58] mypy fixes --- quantus/metrics/base.py | 10 +++++ .../metrics/faithfulness/pixel_flipping.py | 6 +-- quantus/metrics/localisation/focus.py | 2 +- quantus/metrics/localisation/pointing_game.py | 1 - .../localisation/relevance_rank_accuracy.py | 2 - .../model_parameter_randomisation.py | 40 ++++++++++++------- quantus/metrics/robustness/avg_sensitivity.py | 1 - quantus/metrics/robustness/continuity.py | 1 - 8 files changed, 40 insertions(+), 23 deletions(-) diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index e5817c610..9a6cece5c 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -25,6 +25,7 @@ Optional, Union, TYPE_CHECKING, + no_type_check, ) import matplotlib.pyplot as plt @@ -160,6 +161,7 @@ def __init__( self.evaluation_scores = [] self.all_evaluation_scores = [] + @no_type_check def __call__( self, model: Union[keras.Model, nn.Module, None], @@ -313,6 +315,7 @@ def __call__( return self.evaluation_scores # type: ignore @abstractmethod + @no_type_check def evaluate_batch( self, model: ModelInterface, @@ -479,12 +482,14 @@ def general_preprocess( def custom_preprocess( self, + *, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, a_batch: Optional[np.ndarray], s_batch: Optional[np.ndarray], custom_batch: Any, + **kwargs, ) -> Optional[Dict[str, Any]]: """ Implement this method if you need custom preprocessing of data, @@ -514,6 +519,9 @@ def custom_preprocess( A np.ndarray which contains segmentation masks that matches the input. custom_batch: any Gives flexibility to the inheriting metric to use for evaluation, can hold any variable. + kwargs: + Optional, metric-specific parameters. + Returns ------- @@ -622,6 +630,7 @@ def custom_preprocess( def custom_postprocess( self, + *, model: ModelInterface, x_batch: np.ndarray, y_batch: Optional[np.ndarray], @@ -836,6 +845,7 @@ def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]: def custom_batch_preprocess( self, + *, model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, diff --git a/quantus/metrics/faithfulness/pixel_flipping.py b/quantus/metrics/faithfulness/pixel_flipping.py index 224f05308..a5d173aea 100644 --- a/quantus/metrics/faithfulness/pixel_flipping.py +++ b/quantus/metrics/faithfulness/pixel_flipping.py @@ -29,7 +29,7 @@ @final -class PixelFlipping(Metric[List[float]]): +class PixelFlipping(Metric[Union[float, List[float]]]): """ Implementation of Pixel-Flipping experiment by Bach et al., 2015. @@ -261,7 +261,7 @@ def evaluate_instance( x: np.ndarray, y: np.ndarray, a: np.ndarray, - ) -> List[float]: + ) -> Union[float, List[float]]: """ Evaluate instance gets model and data for a single instance as input and returns the evaluation result. @@ -354,7 +354,7 @@ def evaluate_batch( y_batch: np.ndarray, a_batch: np.ndarray, **kwargs, - ) -> List[List[float]]: + ) -> List[Union[float, List[float]]]: """ This method performs XAI evaluation on a single batch of explanations. For more information on the specific logic, we refer the metric’s initialisation docstring. diff --git a/quantus/metrics/localisation/focus.py b/quantus/metrics/localisation/focus.py index 08f7a3b46..f9522efa2 100644 --- a/quantus/metrics/localisation/focus.py +++ b/quantus/metrics/localisation/focus.py @@ -7,7 +7,7 @@ # Quantus project URL: . import sys -from typing import Any, Callable, Dict, List, Optional, no_type_check +from typing import Any, Callable, Dict, List, Optional import numpy as np diff --git a/quantus/metrics/localisation/pointing_game.py b/quantus/metrics/localisation/pointing_game.py index 10f413794..328277155 100644 --- a/quantus/metrics/localisation/pointing_game.py +++ b/quantus/metrics/localisation/pointing_game.py @@ -18,7 +18,6 @@ ModelType, ScoreDirection, ) -from quantus.helpers.model.model_interface import ModelInterface from quantus.metrics.base import Metric if sys.version_info >= (3, 8): diff --git a/quantus/metrics/localisation/relevance_rank_accuracy.py b/quantus/metrics/localisation/relevance_rank_accuracy.py index 3cf4621a2..9bd80d6ed 100644 --- a/quantus/metrics/localisation/relevance_rank_accuracy.py +++ b/quantus/metrics/localisation/relevance_rank_accuracy.py @@ -11,7 +11,6 @@ import numpy as np -from quantus.functions.normalise_func import normalise_by_max from quantus.helpers import asserts, warn from quantus.helpers.enums import ( DataType, @@ -19,7 +18,6 @@ ModelType, ScoreDirection, ) -from quantus.helpers.model.model_interface import ModelInterface from quantus.metrics.base import Metric if sys.version_info >= (3, 8): diff --git a/quantus/metrics/randomisation/model_parameter_randomisation.py b/quantus/metrics/randomisation/model_parameter_randomisation.py index 0864adb95..9abb99a61 100644 --- a/quantus/metrics/randomisation/model_parameter_randomisation.py +++ b/quantus/metrics/randomisation/model_parameter_randomisation.py @@ -7,7 +7,16 @@ # Quantus project URL: . import sys -from typing import Any, Callable, Collection, Dict, List, Optional, Union, Generator +from typing import ( + Any, + Callable, + Collection, + Dict, + List, + Optional, + Union, + Generator, +) import numpy as np @@ -266,7 +275,7 @@ def __call__( softmax=softmax, device=device, ) - model: ModelInterface = data["model"] + model: ModelInterface = data["model"] # type: ignore # Here _batch refers to full dataset. x_full_dataset = data["x_batch"] y_full_dataset = data["y_batch"] @@ -349,11 +358,6 @@ def compute_correlation_per_sample( return corr_coeffs - def evaluate_batch(self, *args, **kwargs): - raise RuntimeError( - "`evaluate_batch` must never be called for `ModelParameterRandomisation`." - ) - def custom_preprocess( self, model: ModelInterface, @@ -384,13 +388,16 @@ def custom_preprocess( # Additional explain_func assert, as the one in general_preprocess() # won't be executed when a_batch != None. asserts.assert_explain_func(explain_func=self.explain_func) - if a_batch is None: - a_batch_chunks = [] - for a_chunk in self.generate_explanations( - model, x_batch, y_batch, self.batch_size - ): - a_batch_chunks.extend(a_chunk) - return dict(a_batch=np.asarray(a_batch_chunks)) + if a_batch is not None: + # Just to silence mypy warnings + return None + + a_batch_chunks = [] + for a_chunk in self.generate_explanations( + model, x_batch, y_batch, self.batch_size + ): + a_batch_chunks.extend(a_chunk) + return dict(a_batch=np.asarray(a_batch_chunks)) def generate_explanations( self, @@ -405,3 +412,8 @@ def generate_explanations( y = y_batch[i.start : i.stop] a = self.explain_batch(model, x, y) yield a + + def evaluate_batch(self, *args, **kwargs): + raise RuntimeError( + "`evaluate_batch` must never be called for `ModelParameterRandomisation`." + ) diff --git a/quantus/metrics/robustness/avg_sensitivity.py b/quantus/metrics/robustness/avg_sensitivity.py index 5d4b6195a..13b71cc7d 100644 --- a/quantus/metrics/robustness/avg_sensitivity.py +++ b/quantus/metrics/robustness/avg_sensitivity.py @@ -12,7 +12,6 @@ import numpy as np from quantus.functions import norm_func -from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import perturb_batch, uniform_noise from quantus.functions.similarity_func import difference from quantus.helpers import asserts, warn diff --git a/quantus/metrics/robustness/continuity.py b/quantus/metrics/robustness/continuity.py index 14e8fc1f1..cfb7fba34 100644 --- a/quantus/metrics/robustness/continuity.py +++ b/quantus/metrics/robustness/continuity.py @@ -11,7 +11,6 @@ import numpy as np -from quantus.functions.normalise_func import normalise_by_max from quantus.functions.perturb_func import translation_x_direction from quantus.functions.similarity_func import lipschitz_constant from quantus.helpers import asserts, utils, warn