Skip to content

Commit

Permalink
Added tqdm progress bar
Browse files Browse the repository at this point in the history
  • Loading branch information
annahedstroem committed Jul 13, 2022
1 parent 930e8f2 commit 9468691
Show file tree
Hide file tree
Showing 10 changed files with 148 additions and 41 deletions.
6 changes: 5 additions & 1 deletion quantus/helpers/pytorch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ def predict(self, x, **kwargs):
return pred.cpu().numpy()

def shape_input(
self, x: np.array, shape: Tuple[int, ...], channel_first: Optional[bool] = None, batch: bool = False
self,
x: np.array,
shape: Tuple[int, ...],
channel_first: Optional[bool] = None,
batch: bool = False,
):
"""
Reshape input into model expected input.
Expand Down
6 changes: 5 additions & 1 deletion quantus/helpers/tf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ def predict(self, x, **kwargs):
return new_model(x, training=False).numpy()

def shape_input(
self, x: np.array, shape: Tuple[int, ...], channel_first: Optional[bool] = None, batch: bool = False
self,
x: np.array,
shape: Tuple[int, ...],
channel_first: Optional[bool] = None,
batch: bool = False,
):
"""
Reshape input into model expected input.
Expand Down
5 changes: 3 additions & 2 deletions quantus/helpers/warn_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,6 @@ def deprecation_warnings(kwargs: dict = {}) -> None:
text = "argument 'nr_channels' is deprecated and will be removed in future versions.\n"
if "max_steps_per_input" in kwargs:
text = "argument 'max_steps_per_input' is deprecated and will be removed in future versions.\n"

if text != '\n': print(text)

if text != "\n":
print(text)
18 changes: 15 additions & 3 deletions quantus/metrics/axiomatic_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,11 @@ def __call__(
if not self.display_progressbar:
iterator = zip(x_batch_s, y_batch, a_batch)
else:
iterator = tqdm(zip(x_batch_s, y_batch, a_batch), total=len(x_batch_s))
iterator = tqdm(
zip(x_batch_s, y_batch, a_batch),
total=len(x_batch_s),
desc=f"Evaluation of {self.__class__.__name__} metric.",
)

for x, y, a in iterator:

Expand Down Expand Up @@ -422,7 +426,11 @@ def __call__(
if not self.display_progressbar:
iterator = zip(x_batch_s, y_batch, a_batch)
else:
iterator = tqdm(zip(x_batch_s, y_batch, a_batch), total=len(x_batch_s))
iterator = tqdm(
zip(x_batch_s, y_batch, a_batch),
total=len(x_batch_s),
desc=f"Evaluation of {self.__class__.__name__} metric.",
)

for x, y, a in iterator:

Expand Down Expand Up @@ -640,7 +648,11 @@ def __call__(
if not self.display_progressbar:
iterator = zip(x_batch_s, y_batch, a_batch)
else:
iterator = tqdm(zip(x_batch_s, y_batch, a_batch), total=len(x_batch_s))
iterator = tqdm(
zip(x_batch_s, y_batch, a_batch),
total=len(x_batch_s),
desc=f"Evaluation of {self.__class__.__name__} metric.",
)

for x, y, a in iterator:

Expand Down
1 change: 0 additions & 1 deletion quantus/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def interpret_scores(self) -> None:
"""
print(self.__init__.__doc__.split(".")[1].split("References")[0])
# print(self.__call__.__doc__.split("callable.")[1].split("Parameters")[0])

@property
def get_params(self) -> dict:
Expand Down
18 changes: 15 additions & 3 deletions quantus/metrics/complexity_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,11 @@ def __call__(
if not self.display_progressbar:
iterator = zip(x_batch_s, y_batch, a_batch)
else:
iterator = tqdm(zip(x_batch_s, y_batch, a_batch), total=len(x_batch_s))
iterator = tqdm(
zip(x_batch_s, y_batch, a_batch),
total=len(x_batch_s),
desc=f"Evaluation of {self.__class__.__name__} metric.",
)

for x, y, a in iterator:

Expand Down Expand Up @@ -383,7 +387,11 @@ def __call__(
if not self.display_progressbar:
iterator = zip(x_batch_s, y_batch, a_batch)
else:
iterator = tqdm(zip(x_batch_s, y_batch, a_batch), total=len(x_batch_s))
iterator = tqdm(
zip(x_batch_s, y_batch, a_batch),
total=len(x_batch_s),
desc=f"Evaluation of {self.__class__.__name__} metric.",
)

for x, y, a in iterator:

Expand Down Expand Up @@ -575,7 +583,11 @@ def __call__(
if not self.display_progressbar:
iterator = zip(x_batch_s, y_batch, a_batch)
else:
iterator = tqdm(zip(x_batch_s, y_batch, a_batch), total=len(x_batch_s))
iterator = tqdm(
zip(x_batch_s, y_batch, a_batch),
total=len(x_batch_s),
desc=f"Evaluation of {self.__class__.__name__} metric.",
)

for x, y, a in iterator:

Expand Down
77 changes: 61 additions & 16 deletions quantus/metrics/faithfulness_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,11 @@ def __call__(
if not self.display_progressbar:
iterator = zip(x_batch_s, y_batch, a_batch)
else:
iterator = tqdm(zip(x_batch_s, y_batch, a_batch), total=len(x_batch_s))
iterator = tqdm(
zip(x_batch_s, y_batch, a_batch),
total=len(x_batch_s),
desc=f"Evaluation of {self.__class__.__name__} metric.",
)

for x, y, a in iterator:

Expand Down Expand Up @@ -441,7 +445,11 @@ def __call__(
if not self.display_progressbar:
iterator = zip(x_batch_s, y_batch, a_batch)
else:
iterator = tqdm(zip(x_batch_s, y_batch, a_batch), total=len(x_batch_s))
iterator = tqdm(
zip(x_batch_s, y_batch, a_batch),
total=len(x_batch_s),
desc=f"Evaluation of {self.__class__.__name__} metric.",
)

for x, y, a in iterator:

Expand Down Expand Up @@ -672,7 +680,9 @@ def __call__(
iterator = enumerate(zip(x_batch_s, y_batch, a_batch))
else:
iterator = tqdm(
enumerate(zip(x_batch_s, y_batch, a_batch)), total=len(x_batch_s)
enumerate(zip(x_batch_s, y_batch, a_batch)),
total=len(x_batch_s),
desc=f"Evaluation of {self.__class__.__name__} metric.",
)

for ix, (x, y, a) in iterator:
Expand Down Expand Up @@ -920,7 +930,11 @@ def __call__(
if not self.display_progressbar:
iterator = zip(x_batch_s, y_batch, a_batch)
else:
iterator = tqdm(zip(x_batch_s, y_batch, a_batch), total=len(x_batch_s))
iterator = tqdm(
zip(x_batch_s, y_batch, a_batch),
total=len(x_batch_s),
desc=f"Evaluation of {self.__class__.__name__} metric.",
)

for x, y, a in iterator:

Expand Down Expand Up @@ -1157,7 +1171,11 @@ def __call__(
if not self.display_progressbar:
iterator = zip(x_batch_s, y_batch, a_batch)
else:
iterator = tqdm(zip(x_batch_s, y_batch, a_batch), total=len(x_batch_s))
iterator = tqdm(
zip(x_batch_s, y_batch, a_batch),
total=len(x_batch_s),
desc=f"Evaluation of {self.__class__.__name__} metric.",
)

for x, y, a in iterator:

Expand Down Expand Up @@ -1404,7 +1422,11 @@ def __call__(
if not self.display_progressbar:
iterator = zip(x_batch_s, y_batch, a_batch)
else:
iterator = tqdm(zip(x_batch_s, y_batch, a_batch), total=len(x_batch_s))
iterator = tqdm(
zip(x_batch_s, y_batch, a_batch),
total=len(x_batch_s),
desc=f"Evaluation of {self.__class__.__name__} metric.",
)

for x, y, a in iterator:

Expand Down Expand Up @@ -1650,7 +1672,9 @@ def __call__(
iterator = enumerate(zip(x_batch_s, y_batch, a_batch))
else:
iterator = tqdm(
enumerate(zip(x_batch_s, y_batch, a_batch)), total=len(x_batch_s)
enumerate(zip(x_batch_s, y_batch, a_batch)),
total=len(x_batch_s),
desc=f"Evaluation of {self.__class__.__name__} metric.",
)

for sample, (x, y, a) in iterator:
Expand Down Expand Up @@ -1753,7 +1777,9 @@ def __call__(
self.last_results[sample] = sub_results

if self.return_aggregate:
print("A 'return_aggregate' functionality is not implemented for this metric.")
print(
"A 'return_aggregate' functionality is not implemented for this metric."
)

self.all_results.append(self.last_results)

Expand Down Expand Up @@ -1946,7 +1972,9 @@ def __call__(
iterator = enumerate(zip(x_batch_s, y_batch, a_batch))
else:
iterator = tqdm(
enumerate(zip(x_batch_s, y_batch, a_batch)), total=len(x_batch_s)
enumerate(zip(x_batch_s, y_batch, a_batch)),
total=len(x_batch_s),
desc=f"Evaluation of {self.__class__.__name__} metric.",
)

for sample, (x, y, a) in iterator:
Expand Down Expand Up @@ -2035,7 +2063,9 @@ def __call__(
self.last_results[sample] = sub_results

if self.return_aggregate:
print("A 'return_aggregate' functionality is not implemented for this metric.")
print(
"A 'return_aggregate' functionality is not implemented for this metric."
)

self.all_results.append(self.last_results)

Expand Down Expand Up @@ -2256,7 +2286,9 @@ def __call__(
iterator = enumerate(zip(x_batch_s, y_batch, a_batch))
else:
iterator = tqdm(
enumerate(zip(x_batch_s, y_batch, a_batch)), total=len(x_batch_s)
enumerate(zip(x_batch_s, y_batch, a_batch)),
total=len(x_batch_s),
desc=f"Evaluation of {self.__class__.__name__} metric.",
)

for sample, (x, y, a) in iterator:
Expand Down Expand Up @@ -2522,7 +2554,11 @@ def __call__(
if not self.display_progressbar:
iterator = zip(x_batch_s, y_batch, a_batch)
else:
iterator = tqdm(zip(x_batch_s, y_batch, a_batch), total=len(x_batch_s))
iterator = tqdm(
zip(x_batch_s, y_batch, a_batch),
total=len(x_batch_s),
desc=f"Evaluation of {self.__class__.__name__} metric.",
)

for x, y, a in iterator:

Expand Down Expand Up @@ -2779,7 +2815,9 @@ def __call__(
iterator = enumerate(zip(x_batch_s, y_batch, a_batch))
else:
iterator = tqdm(
enumerate(zip(x_batch_s, y_batch, a_batch)), total=len(x_batch_s)
enumerate(zip(x_batch_s, y_batch, a_batch)),
total=len(x_batch_s),
desc=f"Evaluation of {self.__class__.__name__} metric.",
)

self.last_results = {str(k): 0 for k in self.percentages}
Expand Down Expand Up @@ -2818,7 +2856,9 @@ def __call__(

# Calculate accuracy for every number of most important pixels removed.
if self.return_aggregate:
print("A 'return_aggregate' functionality is not implemented for this metric.")
print(
"A 'return_aggregate' functionality is not implemented for this metric."
)
for k in self.last_results:
self.last_results[k] = self.last_results[k] / len(x_batch_s)

Expand Down Expand Up @@ -2974,7 +3014,9 @@ def __call__(
a_sim_matrix[dist_matrix <= self.threshold] = 1

# Predict on input.
x_input = model.shape_input(x_batch, x_batch[0].shape, channel_first=True, batch=True)
x_input = model.shape_input(
x_batch, x_batch[0].shape, channel_first=True, batch=True
)
y_pred_classes = np.argmax(
model.predict(x_input, softmax_act=True, **self.kwargs), axis=1
).flatten()
Expand All @@ -2986,6 +3028,7 @@ def __call__(
iterator = tqdm(
enumerate(zip(x_batch_s, y_batch, a_batch, a_sim_matrix)),
total=len(x_batch_s),
desc=f"Evaluation of {self.__class__.__name__} metric.",
)

for ix, (x, y, a, a_sim) in iterator:
Expand All @@ -3003,7 +3046,9 @@ def __call__(
)

if self.return_aggregate:
self.last_results = [self.aggregate_func(self.last_results) / len(self.last_results)]
self.last_results = [
self.aggregate_func(self.last_results) / len(self.last_results)
]

self.all_results.append(self.last_results)

Expand Down
10 changes: 9 additions & 1 deletion quantus/metrics/localisation_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def __call__(
iterator = tqdm(
enumerate(zip(x_batch_s, y_batch, a_batch, s_batch)),
total=len(x_batch_s),
desc=f"Evaluation of {self.__class__.__name__} metric.",
)

for ix, (x, y, a, s) in iterator:
Expand All @@ -214,7 +215,9 @@ def __call__(
else:
hit = bool(s[max_index])

self.last_results.append(hit) # ratio = np.sum(binary_mask) / float(binary_mask.shape[0] * binary_mask.shape[1])
self.last_results.append(
hit
) # ratio = np.sum(binary_mask) / float(binary_mask.shape[0] * binary_mask.shape[1])

if self.return_aggregate:
self.last_results = [self.aggregate_func(self.last_results)]
Expand Down Expand Up @@ -403,6 +406,7 @@ def __call__(
iterator = tqdm(
enumerate(zip(x_batch_s, y_batch, a_batch, s_batch)),
total=len(x_batch_s),
desc=f"Evaluation of {self.__class__.__name__} metric.",
)

for ix, (x, y, a, s) in iterator:
Expand Down Expand Up @@ -638,6 +642,7 @@ def __call__(
iterator = tqdm(
enumerate(zip(x_batch_s, y_batch, a_batch, s_batch)),
total=len(x_batch_s),
desc=f"Evaluation of {self.__class__.__name__} metric.",
)

for ix, (x, y, a, s) in iterator:
Expand Down Expand Up @@ -846,6 +851,7 @@ def __call__(
iterator = tqdm(
enumerate(zip(x_batch_s, y_batch, a_batch, s_batch)),
total=len(x_batch_s),
desc=f"Evaluation of {self.__class__.__name__} metric.",
)

for ix, (x, y, a, s) in iterator:
Expand Down Expand Up @@ -1052,6 +1058,7 @@ def __call__(
iterator = tqdm(
enumerate(zip(x_batch_s, y_batch, a_batch, s_batch)),
total=len(x_batch_s),
desc=f"Evaluation of {self.__class__.__name__} metric.",
)

for ix, (x, y, a, s) in iterator:
Expand Down Expand Up @@ -1249,6 +1256,7 @@ def __call__(
iterator = tqdm(
enumerate(zip(x_batch_s, y_batch, a_batch, s_batch)),
total=len(x_batch_s),
desc=f"Evaluation of {self.__class__.__name__} metric.",
)

for ix, (x, y, a, s) in iterator:
Expand Down
Loading

0 comments on commit 9468691

Please sign in to comment.