From 9468691db300aa672788abfa58432dc620c5da8a Mon Sep 17 00:00:00 2001 From: annahedstroem Date: Wed, 13 Jul 2022 18:16:41 +0200 Subject: [PATCH] Added progressbar to tdqm --- quantus/helpers/pytorch_model.py | 6 +- quantus/helpers/tf_model.py | 6 +- quantus/helpers/warn_func.py | 5 +- quantus/metrics/axiomatic_metrics.py | 18 +++++- quantus/metrics/base.py | 1 - quantus/metrics/complexity_metrics.py | 18 +++++- quantus/metrics/faithfulness_metrics.py | 77 +++++++++++++++++++----- quantus/metrics/localisation_metrics.py | 10 ++- quantus/metrics/randomisation_metrics.py | 17 ++++-- quantus/metrics/robustness_metrics.py | 31 +++++++--- 10 files changed, 148 insertions(+), 41 deletions(-) diff --git a/quantus/helpers/pytorch_model.py b/quantus/helpers/pytorch_model.py index 718a313b1..0ad03c07b 100644 --- a/quantus/helpers/pytorch_model.py +++ b/quantus/helpers/pytorch_model.py @@ -35,7 +35,11 @@ def predict(self, x, **kwargs): return pred.cpu().numpy() def shape_input( - self, x: np.array, shape: Tuple[int, ...], channel_first: Optional[bool] = None, batch: bool = False + self, + x: np.array, + shape: Tuple[int, ...], + channel_first: Optional[bool] = None, + batch: bool = False, ): """ Reshape input into model expected input. diff --git a/quantus/helpers/tf_model.py b/quantus/helpers/tf_model.py index dae0d45d6..4815bdc1e 100644 --- a/quantus/helpers/tf_model.py +++ b/quantus/helpers/tf_model.py @@ -40,7 +40,11 @@ def predict(self, x, **kwargs): return new_model(x, training=False).numpy() def shape_input( - self, x: np.array, shape: Tuple[int, ...], channel_first: Optional[bool] = None, batch: bool = False + self, + x: np.array, + shape: Tuple[int, ...], + channel_first: Optional[bool] = None, + batch: bool = False, ): """ Reshape input into model expected input. diff --git a/quantus/helpers/warn_func.py b/quantus/helpers/warn_func.py index 5f36d8473..6ce2edfc1 100644 --- a/quantus/helpers/warn_func.py +++ b/quantus/helpers/warn_func.py @@ -70,5 +70,6 @@ def deprecation_warnings(kwargs: dict = {}) -> None: text = "argument 'nr_channels' is deprecated and will be removed in future versions.\n" if "max_steps_per_input" in kwargs: text = "argument 'max_steps_per_input' is deprecated and will be removed in future versions.\n" - - if text != '\n': print(text) + + if text != "\n": + print(text) diff --git a/quantus/metrics/axiomatic_metrics.py b/quantus/metrics/axiomatic_metrics.py index 8cfe83666..80d7fcdb5 100644 --- a/quantus/metrics/axiomatic_metrics.py +++ b/quantus/metrics/axiomatic_metrics.py @@ -200,7 +200,11 @@ def __call__( if not self.display_progressbar: iterator = zip(x_batch_s, y_batch, a_batch) else: - iterator = tqdm(zip(x_batch_s, y_batch, a_batch), total=len(x_batch_s)) + iterator = tqdm( + zip(x_batch_s, y_batch, a_batch), + total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", + ) for x, y, a in iterator: @@ -422,7 +426,11 @@ def __call__( if not self.display_progressbar: iterator = zip(x_batch_s, y_batch, a_batch) else: - iterator = tqdm(zip(x_batch_s, y_batch, a_batch), total=len(x_batch_s)) + iterator = tqdm( + zip(x_batch_s, y_batch, a_batch), + total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", + ) for x, y, a in iterator: @@ -640,7 +648,11 @@ def __call__( if not self.display_progressbar: iterator = zip(x_batch_s, y_batch, a_batch) else: - iterator = tqdm(zip(x_batch_s, y_batch, a_batch), total=len(x_batch_s)) + iterator = tqdm( + zip(x_batch_s, y_batch, a_batch), + total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", + ) for x, y, a in iterator: diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index 56b244e49..c01cadb83 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -115,7 +115,6 @@ def interpret_scores(self) -> None: """ print(self.__init__.__doc__.split(".")[1].split("References")[0]) - # print(self.__call__.__doc__.split("callable.")[1].split("Parameters")[0]) @property def get_params(self) -> dict: diff --git a/quantus/metrics/complexity_metrics.py b/quantus/metrics/complexity_metrics.py index 6c93d6920..581be6340 100644 --- a/quantus/metrics/complexity_metrics.py +++ b/quantus/metrics/complexity_metrics.py @@ -184,7 +184,11 @@ def __call__( if not self.display_progressbar: iterator = zip(x_batch_s, y_batch, a_batch) else: - iterator = tqdm(zip(x_batch_s, y_batch, a_batch), total=len(x_batch_s)) + iterator = tqdm( + zip(x_batch_s, y_batch, a_batch), + total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", + ) for x, y, a in iterator: @@ -383,7 +387,11 @@ def __call__( if not self.display_progressbar: iterator = zip(x_batch_s, y_batch, a_batch) else: - iterator = tqdm(zip(x_batch_s, y_batch, a_batch), total=len(x_batch_s)) + iterator = tqdm( + zip(x_batch_s, y_batch, a_batch), + total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", + ) for x, y, a in iterator: @@ -575,7 +583,11 @@ def __call__( if not self.display_progressbar: iterator = zip(x_batch_s, y_batch, a_batch) else: - iterator = tqdm(zip(x_batch_s, y_batch, a_batch), total=len(x_batch_s)) + iterator = tqdm( + zip(x_batch_s, y_batch, a_batch), + total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", + ) for x, y, a in iterator: diff --git a/quantus/metrics/faithfulness_metrics.py b/quantus/metrics/faithfulness_metrics.py index 5a16080b4..714881c78 100644 --- a/quantus/metrics/faithfulness_metrics.py +++ b/quantus/metrics/faithfulness_metrics.py @@ -211,7 +211,11 @@ def __call__( if not self.display_progressbar: iterator = zip(x_batch_s, y_batch, a_batch) else: - iterator = tqdm(zip(x_batch_s, y_batch, a_batch), total=len(x_batch_s)) + iterator = tqdm( + zip(x_batch_s, y_batch, a_batch), + total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", + ) for x, y, a in iterator: @@ -441,7 +445,11 @@ def __call__( if not self.display_progressbar: iterator = zip(x_batch_s, y_batch, a_batch) else: - iterator = tqdm(zip(x_batch_s, y_batch, a_batch), total=len(x_batch_s)) + iterator = tqdm( + zip(x_batch_s, y_batch, a_batch), + total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", + ) for x, y, a in iterator: @@ -672,7 +680,9 @@ def __call__( iterator = enumerate(zip(x_batch_s, y_batch, a_batch)) else: iterator = tqdm( - enumerate(zip(x_batch_s, y_batch, a_batch)), total=len(x_batch_s) + enumerate(zip(x_batch_s, y_batch, a_batch)), + total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", ) for ix, (x, y, a) in iterator: @@ -920,7 +930,11 @@ def __call__( if not self.display_progressbar: iterator = zip(x_batch_s, y_batch, a_batch) else: - iterator = tqdm(zip(x_batch_s, y_batch, a_batch), total=len(x_batch_s)) + iterator = tqdm( + zip(x_batch_s, y_batch, a_batch), + total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", + ) for x, y, a in iterator: @@ -1157,7 +1171,11 @@ def __call__( if not self.display_progressbar: iterator = zip(x_batch_s, y_batch, a_batch) else: - iterator = tqdm(zip(x_batch_s, y_batch, a_batch), total=len(x_batch_s)) + iterator = tqdm( + zip(x_batch_s, y_batch, a_batch), + total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", + ) for x, y, a in iterator: @@ -1404,7 +1422,11 @@ def __call__( if not self.display_progressbar: iterator = zip(x_batch_s, y_batch, a_batch) else: - iterator = tqdm(zip(x_batch_s, y_batch, a_batch), total=len(x_batch_s)) + iterator = tqdm( + zip(x_batch_s, y_batch, a_batch), + total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", + ) for x, y, a in iterator: @@ -1650,7 +1672,9 @@ def __call__( iterator = enumerate(zip(x_batch_s, y_batch, a_batch)) else: iterator = tqdm( - enumerate(zip(x_batch_s, y_batch, a_batch)), total=len(x_batch_s) + enumerate(zip(x_batch_s, y_batch, a_batch)), + total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", ) for sample, (x, y, a) in iterator: @@ -1753,7 +1777,9 @@ def __call__( self.last_results[sample] = sub_results if self.return_aggregate: - print("A 'return_aggregate' functionality is not implemented for this metric.") + print( + "A 'return_aggregate' functionality is not implemented for this metric." + ) self.all_results.append(self.last_results) @@ -1946,7 +1972,9 @@ def __call__( iterator = enumerate(zip(x_batch_s, y_batch, a_batch)) else: iterator = tqdm( - enumerate(zip(x_batch_s, y_batch, a_batch)), total=len(x_batch_s) + enumerate(zip(x_batch_s, y_batch, a_batch)), + total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", ) for sample, (x, y, a) in iterator: @@ -2035,7 +2063,9 @@ def __call__( self.last_results[sample] = sub_results if self.return_aggregate: - print("A 'return_aggregate' functionality is not implemented for this metric.") + print( + "A 'return_aggregate' functionality is not implemented for this metric." + ) self.all_results.append(self.last_results) @@ -2256,7 +2286,9 @@ def __call__( iterator = enumerate(zip(x_batch_s, y_batch, a_batch)) else: iterator = tqdm( - enumerate(zip(x_batch_s, y_batch, a_batch)), total=len(x_batch_s) + enumerate(zip(x_batch_s, y_batch, a_batch)), + total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", ) for sample, (x, y, a) in iterator: @@ -2522,7 +2554,11 @@ def __call__( if not self.display_progressbar: iterator = zip(x_batch_s, y_batch, a_batch) else: - iterator = tqdm(zip(x_batch_s, y_batch, a_batch), total=len(x_batch_s)) + iterator = tqdm( + zip(x_batch_s, y_batch, a_batch), + total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", + ) for x, y, a in iterator: @@ -2779,7 +2815,9 @@ def __call__( iterator = enumerate(zip(x_batch_s, y_batch, a_batch)) else: iterator = tqdm( - enumerate(zip(x_batch_s, y_batch, a_batch)), total=len(x_batch_s) + enumerate(zip(x_batch_s, y_batch, a_batch)), + total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", ) self.last_results = {str(k): 0 for k in self.percentages} @@ -2818,7 +2856,9 @@ def __call__( # Calculate accuracy for every number of most important pixels removed. if self.return_aggregate: - print("A 'return_aggregate' functionality is not implemented for this metric.") + print( + "A 'return_aggregate' functionality is not implemented for this metric." + ) for k in self.last_results: self.last_results[k] = self.last_results[k] / len(x_batch_s) @@ -2974,7 +3014,9 @@ def __call__( a_sim_matrix[dist_matrix <= self.threshold] = 1 # Predict on input. - x_input = model.shape_input(x_batch, x_batch[0].shape, channel_first=True, batch=True) + x_input = model.shape_input( + x_batch, x_batch[0].shape, channel_first=True, batch=True + ) y_pred_classes = np.argmax( model.predict(x_input, softmax_act=True, **self.kwargs), axis=1 ).flatten() @@ -2986,6 +3028,7 @@ def __call__( iterator = tqdm( enumerate(zip(x_batch_s, y_batch, a_batch, a_sim_matrix)), total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", ) for ix, (x, y, a, a_sim) in iterator: @@ -3003,7 +3046,9 @@ def __call__( ) if self.return_aggregate: - self.last_results = [self.aggregate_func(self.last_results) / len(self.last_results)] + self.last_results = [ + self.aggregate_func(self.last_results) / len(self.last_results) + ] self.all_results.append(self.last_results) diff --git a/quantus/metrics/localisation_metrics.py b/quantus/metrics/localisation_metrics.py index fed7b79f7..70bd4e429 100644 --- a/quantus/metrics/localisation_metrics.py +++ b/quantus/metrics/localisation_metrics.py @@ -189,6 +189,7 @@ def __call__( iterator = tqdm( enumerate(zip(x_batch_s, y_batch, a_batch, s_batch)), total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", ) for ix, (x, y, a, s) in iterator: @@ -214,7 +215,9 @@ def __call__( else: hit = bool(s[max_index]) - self.last_results.append(hit) # ratio = np.sum(binary_mask) / float(binary_mask.shape[0] * binary_mask.shape[1]) + self.last_results.append( + hit + ) # ratio = np.sum(binary_mask) / float(binary_mask.shape[0] * binary_mask.shape[1]) if self.return_aggregate: self.last_results = [self.aggregate_func(self.last_results)] @@ -403,6 +406,7 @@ def __call__( iterator = tqdm( enumerate(zip(x_batch_s, y_batch, a_batch, s_batch)), total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", ) for ix, (x, y, a, s) in iterator: @@ -638,6 +642,7 @@ def __call__( iterator = tqdm( enumerate(zip(x_batch_s, y_batch, a_batch, s_batch)), total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", ) for ix, (x, y, a, s) in iterator: @@ -846,6 +851,7 @@ def __call__( iterator = tqdm( enumerate(zip(x_batch_s, y_batch, a_batch, s_batch)), total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", ) for ix, (x, y, a, s) in iterator: @@ -1052,6 +1058,7 @@ def __call__( iterator = tqdm( enumerate(zip(x_batch_s, y_batch, a_batch, s_batch)), total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", ) for ix, (x, y, a, s) in iterator: @@ -1249,6 +1256,7 @@ def __call__( iterator = tqdm( enumerate(zip(x_batch_s, y_batch, a_batch, s_batch)), total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", ) for ix, (x, y, a, s) in iterator: diff --git a/quantus/metrics/randomisation_metrics.py b/quantus/metrics/randomisation_metrics.py index 2e8678ea3..ae7455ea3 100644 --- a/quantus/metrics/randomisation_metrics.py +++ b/quantus/metrics/randomisation_metrics.py @@ -98,7 +98,7 @@ def __call__( a_batch: Union[np.array, None], s_batch: Union[np.array, None] = None, *args, - **kwargs + **kwargs, ) -> List[float]: """ This implementation represents the main logic of the metric and makes the class object callable. @@ -194,7 +194,10 @@ def __call__( list(model.get_random_layer_generator(order=self.layer_order)) ) n_iterations = n_layers * len(a_batch) - pbar = tqdm(total=n_iterations) + pbar = tqdm( + total=n_iterations, + desc=f"Evaluation of {self.__class__.__name__} metric.", + ) for layer_name, random_layer_model in model.get_random_layer_generator( order=self.layer_order, seed=self.seed @@ -234,7 +237,9 @@ def __call__( pbar.close() if self.return_aggregate: - print("A 'return_aggregate' functionality is not implemented for this metric.") + print( + "A 'return_aggregate' functionality is not implemented for this metric." + ) self.all_results.append(self.last_results) @@ -313,7 +318,7 @@ def __call__( a_batch: Union[np.array, None], s_batch: Union[np.array, None] = None, *args, - **kwargs + **kwargs, ) -> List[float]: """ This implementation represents the main logic of the metric and makes the class object callable. @@ -405,7 +410,9 @@ def __call__( iterator = enumerate(zip(x_batch_s, y_batch, a_batch)) else: iterator = tqdm( - enumerate(zip(x_batch_s, y_batch, a_batch)), total=len(x_batch_s) + enumerate(zip(x_batch_s, y_batch, a_batch)), + total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", ) for ix, (x, y, a) in iterator: diff --git a/quantus/metrics/robustness_metrics.py b/quantus/metrics/robustness_metrics.py index 07ca0d49b..085e39d2e 100644 --- a/quantus/metrics/robustness_metrics.py +++ b/quantus/metrics/robustness_metrics.py @@ -216,7 +216,9 @@ def __call__( iterator = enumerate(zip(x_batch_s, y_batch, a_batch)) else: iterator = tqdm( - enumerate(zip(x_batch_s, y_batch, a_batch)), total=len(x_batch_s) + enumerate(zip(x_batch_s, y_batch, a_batch)), + total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", ) for ix, (x, y, a) in iterator: @@ -260,7 +262,7 @@ def __call__( b=a_perturbed.flatten(), c=x.flatten(), d=x_perturbed.flatten(), - **self.kwargs + **self.kwargs, ) if similarity > similarity_max: @@ -460,7 +462,9 @@ def __call__( iterator = enumerate(zip(x_batch_s, y_batch, a_batch)) else: iterator = tqdm( - enumerate(zip(x_batch_s, y_batch, a_batch)), total=len(x_batch_s) + enumerate(zip(x_batch_s, y_batch, a_batch)), + total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", ) for ix, (x, y, a) in iterator: @@ -707,7 +711,9 @@ def __call__( iterator = enumerate(zip(x_batch_s, y_batch, a_batch)) else: iterator = tqdm( - enumerate(zip(x_batch_s, y_batch, a_batch)), total=len(x_batch_s) + enumerate(zip(x_batch_s, y_batch, a_batch)), + total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", ) for ix, (x, y, a) in iterator: @@ -964,7 +970,9 @@ def __call__( iterator = enumerate(zip(x_batch_s, y_batch, a_batch)) else: iterator = tqdm( - enumerate(zip(x_batch_s, y_batch, a_batch)), total=len(x_batch_s) + enumerate(zip(x_batch_s, y_batch, a_batch)), + total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", ) self.dx = np.prod(x_batch_s.shape[2:]) // self.nr_steps @@ -1046,7 +1054,9 @@ def __call__( self.last_results[ix] = sub_results if self.return_aggregate: - print("A 'return_aggregate' functionality is not implemented for this metric.") + print( + "A 'return_aggregate' functionality is not implemented for this metric." + ) self.all_results.append(self.last_results) @@ -1214,7 +1224,9 @@ def __call__( a_labels = np.array(list(map(self.discretise_func, a_batch_flat))) # Predict on input. - x_input = model.shape_input(x_batch, x_batch[0].shape, channel_first=True, batch=True) + x_input = model.shape_input( + x_batch, x_batch[0].shape, channel_first=True, batch=True + ) y_pred_classes = np.argmax( model.predict(x_input, softmax_act=True, **self.kwargs), axis=1 ).flatten() @@ -1226,6 +1238,7 @@ def __call__( iterator = tqdm( enumerate(zip(x_batch_s, y_batch, a_batch, a_labels)), total=len(x_batch_s), + desc=f"Evaluation of {self.__class__.__name__} metric.", ) for ix, (x, y, a, a_label) in iterator: @@ -1241,7 +1254,9 @@ def __call__( self.last_results.append(np.sum(pred_same_a == pred_a) / len(same_a)) if self.return_aggregate: - self.last_results = [self.aggregate_func(self.last_results) / len(self.last_results)] + self.last_results = [ + self.aggregate_func(self.last_results) / len(self.last_results) + ] self.all_results.append(self.last_results)