show only one progress bar for feature ablation methods
aobo-y committed Mar 15, 2021
1 parent a1a3098 commit ccddf10
Showing 6 changed files with 202 additions and 119 deletions.
84 changes: 55 additions & 29 deletions captum/_utils/progress.py

@@ -39,43 +39,69 @@ def flush(self, *args, **kwargs):
         return self._wrapped_run(self._wrapped.flush, *args, **kwargs)
 
 
-def _simple_progress_out(
-    iterable: Iterable, desc: str = None, total: int = None, file: TextIO = None
-):
-    """
-    Simple progress output used when tqdm is unavailable.
-    Same as tqdm, output to stderr channel
-    """
-    cur = 0
-
-    if total is None and hasattr(iterable, "__len__"):
-        total = len(cast(Sized, iterable))
-
-    desc = desc + ": " if desc else ""
-
-    def _progress_str(cur):
-        if total:
-            # e.g., progress: 60% 3/5
-            return f"{desc}{100 * cur // total}% {cur}/{total}"
-        else:
-            # e.g., progress: .....
-            return f"{desc}{'.' * cur}"
-
-    if not file:
-        file = sys.stderr
-    file = DisableErrorIOWrapper(file)
-
-    print("\r" + _progress_str(cur), end="", file=file)
-    for it in iterable:
-        yield it
-        cur += 1
-        print("\r" + _progress_str(cur), end="", file=file)
-
-    print(file=file)  # end with new line
+class SimpleProgress:
+    def __init__(
+        self,
+        iterable: Iterable = None,
+        desc: str = None,
+        total: int = None,
+        file: TextIO = None,
+    ):
+        """
+        Simple progress output used when tqdm is unavailable.
+        Like tqdm, it writes to the stderr channel.
+        """
+        self.cur = 0
+
+        self.iterable = iterable
+        self.total = total
+        if total is None and hasattr(iterable, "__len__"):
+            self.total = len(cast(Sized, iterable))
+
+        self.desc = desc
+
+        file = DisableErrorIOWrapper(file if file else sys.stderr)
+        self.file = cast(TextIO, file)
+        self.closed = False
+
+    def __iter__(self):
+        if self.closed or not self.iterable:
+            return
+        self._refresh()
+        for it in self.iterable:
+            yield it
+            self.cur += 1
+            self._refresh()
+        self.close()
+
+    def _refresh(self):
+        progress_str = self.desc + ": " if self.desc else ""
+        if self.total:
+            # e.g., progress: 60% 3/5
+            progress_str += f"{100 * self.cur // self.total}% {self.cur}/{self.total}"
+        else:
+            # e.g., progress: .....
+            progress_str += "." * self.cur
+
+        print("\r" + progress_str, end="", file=self.file)
+
+    def update(self, amount: int = 1):
+        if self.closed:
+            return
+        self.cur += amount
+        self._refresh()
+        if self.cur == self.total:
+            self.close()
+
+    def close(self):
+        if not self.closed:
+            print(file=self.file)  # end with new line
+            self.closed = True
 
 
 def progress(
-    iterable: Iterable,
+    iterable: Iterable = None,
     desc: str = None,
     total: int = None,
     use_tqdm=True,
@@ -92,4 +118,4 @@ def progress(
             "but tqdm is not installed. "
             "Falling back to simply printing out the progress."
         )
-        return _simple_progress_out(iterable, desc=desc, total=total, file=file)
+        return SimpleProgress(iterable, desc=desc, total=total, file=file)
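
For orientation, a minimal usage sketch of the reworked helper (not part of the commit, and assuming tqdm is absent so that progress() returns the SimpleProgress fallback). Iteration behaves as before; the new update()/close() methods let a caller drive one bar by hand, which is what the feature ablation change below relies on:

    from captum._utils.progress import progress

    # Iterable mode, unchanged: the bar redraws after every yielded item.
    for item in progress(["a", "b", "c"], desc="eval"):
        pass  # draws "eval: 33% 1/3", "eval: 66% 2/3", "eval: 100% 3/3"

    # Manual mode, new in this commit: no iterable, only a known total.
    bar = progress(desc="eval", total=5)
    bar.update(0)      # draw "eval: 0% 0/5" immediately
    for _ in range(5):
        bar.update()   # advances by one; the bar closes itself at 5/5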
184 changes: 96 additions & 88 deletions captum/attr/_core/feature_ablation.py

@@ -310,13 +310,32 @@ def attribute(
             for input in inputs
         ]
 
+        if show_progress:
+            feature_counts = self._get_feature_counts(
+                inputs, feature_mask, **kwargs
+            )
+            total_forwards = sum(
+                math.ceil(count / perturbations_per_eval)
+                for count in feature_counts
+            )
+            attr_progress = progress(
+                desc=f"{self.get_name()} attribution", total=total_forwards
+            )
+            attr_progress.update(0)
+
         # Iterate through each feature tensor for ablation
         for i in range(len(inputs)):
             # Skip any empty input tensors
             if torch.numel(inputs[i]) == 0:
                 continue
 
-            ablation_generator, ablation_meta = self._ablation_generator_with_meta(
+            for (
+                current_inputs,
+                current_add_args,
+                current_target,
+                current_mask,
+            ) in self._ith_input_ablation_generator(
                 i,
                 inputs,
                 additional_forward_args,
@@ -325,26 +344,7 @@ def attribute(
                 feature_mask,
                 perturbations_per_eval,
                 **kwargs,
-            )
-
-            # optionally show progress of each input
-            if show_progress:
-                iter_steps = (
-                    ablation_meta["num_features"] - ablation_meta["min_feature"]
-                )
-                iter_steps = math.ceil(iter_steps / perturbations_per_eval)
-                ablation_generator = progress(
-                    ablation_generator,
-                    desc=f"{self.get_name()} attribution of Inputs[{i}]",
-                    total=iter_steps,
-                )
-
-            for (
-                current_inputs,
-                current_add_args,
-                current_target,
-                current_mask,
-            ) in ablation_generator:
+            ):
                 # modified_eval dimensions: 1D tensor with length
                 # equal to #num_examples * #features in batch
                 modified_eval = _run_forward(
@@ -353,6 +353,10 @@ def attribute(
                     current_target,
                     current_add_args,
                 )
+
+                if show_progress:
+                    attr_progress.update()
+
                 # (contains 1 more dimension than inputs). This adds extra
                 # dimensions of 1 to make the tensor broadcastable with the inputs
                 # tensor.
@@ -376,6 +380,9 @@ def attribute(
                     dim=0
                 )
 
+        if show_progress:
+            attr_progress.close()
+
         # Divide total attributions by counts and return formatted attributions
         if self.use_weights:
             attrib = tuple(
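
The bar's total is the exact number of forward passes for the whole call: each input contributes ceil(feature_count / perturbations_per_eval) batched evaluations, and every per-input iteration now updates the one shared bar. A worked sketch of the arithmetic (feature counts invented for illustration):

    import math

    # Hypothetical example: two input tensors with 7 and 4 ablatable features,
    # batching up to 3 perturbations into each forward pass.
    feature_counts = [7, 4]
    perturbations_per_eval = 3

    total_forwards = sum(
        math.ceil(count / perturbations_per_eval) for count in feature_counts
    )
    print(total_forwards)  # ceil(7/3) + ceil(4/3) = 3 + 2 = 5, so the bar runs 0/5 .. 5/5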
@@ -387,7 +394,7 @@ def attribute(
         _result = _format_output(is_inputs_tuple, attrib)
         return _result
 
-    def _ablation_generator_with_meta(
+    def _ith_input_ablation_generator(
         self,
         i,
         inputs,
@@ -399,12 +406,11 @@ def _ablation_generator_with_meta(
         **kwargs,
     ):
         """
-        This method return an generator of ablation perturbations with its related meta
+        This method returns a generator of ablation perturbations of the i-th input.
 
         Returns:
             ablation_iter (generator): yields each perturbation to be evaluated
                 as a tuple (inputs, additional_forward_args, targets, mask).
-            meta (dict): meta data of this ablation {min_feature: int, feature_num: int}
         """
         extra_args = {}
         for key, value in kwargs.items():
@@ -441,75 +447,65 @@ def _ablation_generator_with_meta(
             additional_args_repeated = additional_args
             target_repeated = target
 
-        def _create_ablation_generator():
-            """nested generator function to iterate perturbation features"""
-            num_features_processed = min_feature
-            while num_features_processed < num_features:
-                current_num_ablated_features = min(
-                    perturbations_per_eval, num_features - num_features_processed
-                )
+        num_features_processed = min_feature
+        while num_features_processed < num_features:
+            current_num_ablated_features = min(
+                perturbations_per_eval, num_features - num_features_processed
+            )
 
-                # Store appropriate inputs and additional args based on batch size.
-                if current_num_ablated_features != perturbations_per_eval:
-                    current_features = [
-                        feature_repeated[
-                            0 : current_num_ablated_features * num_examples
-                        ]
-                        for feature_repeated in all_features_repeated
-                    ]
-                    current_additional_args = (
-                        _expand_additional_forward_args(
-                            additional_args, current_num_ablated_features
-                        )
-                        if additional_args is not None
-                        else None
-                    )
-                    current_target = _expand_target(
-                        target, current_num_ablated_features
-                    )
-                else:
-                    current_features = all_features_repeated
-                    current_additional_args = additional_args_repeated
-                    current_target = target_repeated
+            # Store appropriate inputs and additional args based on batch size.
+            if current_num_ablated_features != perturbations_per_eval:
+                current_features = [
+                    feature_repeated[0 : current_num_ablated_features * num_examples]
+                    for feature_repeated in all_features_repeated
+                ]
+                current_additional_args = (
+                    _expand_additional_forward_args(
+                        additional_args, current_num_ablated_features
+                    )
+                    if additional_args is not None
+                    else None
+                )
+                current_target = _expand_target(target, current_num_ablated_features)
+            else:
+                current_features = all_features_repeated
+                current_additional_args = additional_args_repeated
+                current_target = target_repeated
 
-                # Store existing tensor before modifying
-                original_tensor = current_features[i]
-                # Construct ablated batch for features in range num_features_processed
-                # to num_features_processed + current_num_ablated_features and return
-                # mask with same size as ablated batch. ablated_features has dimension
-                # (current_num_ablated_features, num_examples, inputs[i].shape[1:])
-                # Note that in the case of sparse tensors, the second dimension
-                # may not necessarily be num_examples and will match the first
-                # dimension of this tensor.
-                current_reshaped = current_features[i].reshape(
-                    (current_num_ablated_features, -1) + current_features[i].shape[1:]
-                )
+            # Store existing tensor before modifying
+            original_tensor = current_features[i]
+            # Construct ablated batch for features in range num_features_processed
+            # to num_features_processed + current_num_ablated_features and return
+            # mask with same size as ablated batch. ablated_features has dimension
+            # (current_num_ablated_features, num_examples, inputs[i].shape[1:])
+            # Note that in the case of sparse tensors, the second dimension
+            # may not necessarily be num_examples and will match the first
+            # dimension of this tensor.
+            current_reshaped = current_features[i].reshape(
+                (current_num_ablated_features, -1) + current_features[i].shape[1:]
+            )
 
-                ablated_features, current_mask = self._construct_ablated_input(
-                    current_reshaped,
-                    input_mask,
-                    baseline,
-                    num_features_processed,
-                    num_features_processed + current_num_ablated_features,
-                    **extra_args,
-                )
+            ablated_features, current_mask = self._construct_ablated_input(
+                current_reshaped,
+                input_mask,
+                baseline,
+                num_features_processed,
+                num_features_processed + current_num_ablated_features,
+                **extra_args,
+            )
 
-                # current_features[i] has dimension
-                # (current_num_ablated_features * num_examples, inputs[i].shape[1:]),
-                # which can be provided to the model as input.
-                current_features[i] = ablated_features.reshape(
-                    (-1,) + ablated_features.shape[2:]
-                )
-                yield tuple(
-                    current_features
-                ), current_additional_args, current_target, current_mask
-                # Replace existing tensor at index i.
-                current_features[i] = original_tensor
-                num_features_processed += current_num_ablated_features
-
-        return _create_ablation_generator(), dict(
-            min_feature=min_feature, num_features=num_features
-        )
+            # current_features[i] has dimension
+            # (current_num_ablated_features * num_examples, inputs[i].shape[1:]),
+            # which can be provided to the model as input.
+            current_features[i] = ablated_features.reshape(
+                (-1,) + ablated_features.shape[2:]
+            )
+            yield tuple(
+                current_features
+            ), current_additional_args, current_target, current_mask
+            # Replace existing tensor at index i.
+            current_features[i] = original_tensor
+            num_features_processed += current_num_ablated_features
 
     def _construct_ablated_input(
         self, expanded_input, input_mask, baseline, start_feature, end_feature, **kwargs
@@ -552,6 +548,18 @@ def _get_feature_range_and_mask(self, input, input_mask, **kwargs):
             input_mask,
         )
 
+    def _get_feature_counts(self, inputs, feature_mask, **kwargs):
+        """Return the number of input features for each input tensor."""
+        if not feature_mask:
+            return tuple(inp[0].numel() if inp.numel() else 0 for inp in inputs)
+
+        return tuple(
+            (mask.max() - mask.min()).item() + 1
+            if mask is not None
+            else (inp[0].numel() if inp.numel() else 0)
+            for inp, mask in zip(inputs, feature_mask)
+        )
+
     @staticmethod
     def _find_output_mode(
         perturbations_per_eval: int,
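
A hedged sketch of what _get_feature_counts returns (tensor values invented): with a feature mask, a tensor's count is the span of its group ids, max - min + 1; without one, each element of a single example is its own feature:

    import torch

    # Hypothetical 1x4 input whose elements fall into 3 feature groups.
    inp = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
    mask = torch.tensor([[0, 0, 1, 2]])

    count_with_mask = (mask.max() - mask.min()).item() + 1  # 2 - 0 + 1 = 3
    count_without_mask = inp[0].numel()                     # 4, one feature per element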
4 changes: 4 additions & 0 deletions captum/attr/_core/occlusion.py

@@ -373,3 +373,7 @@ def _get_feature_range_and_mask(
     ) -> Tuple[int, int, None]:
         feature_max = np.prod(kwargs["shift_counts"])
         return 0, feature_max, None
+
+    def _get_feature_counts(self, inputs, feature_mask, **kwargs):
+        """Return the number of possible input features for each input."""
+        return tuple(np.prod(counts).astype(int) for counts in kwargs["shift_counts"])
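
The user-visible effect, as a hedged end-to-end sketch (model and shapes invented): ablation-family methods now draw a single bar for the whole attribute call rather than one bar per input tensor. For Occlusion, the bar total is the product of the per-dimension shift counts, i.e. the number of sliding-window positions:

    import torch
    from captum.attr import FeatureAblation, Occlusion

    model = torch.nn.Linear(10, 2)   # toy model: 10 features -> 2 outputs
    inputs = torch.rand(4, 10)

    fa = FeatureAblation(model)
    # One bar for the whole call, e.g. "Feature Ablation attribution: 30% 3/10"
    attr = fa.attribute(inputs, target=0, show_progress=True)

    occ = Occlusion(model)
    # Same single bar; its total comes from np.prod(shift_counts), the number
    # of positions of the sliding window along the feature dimension.
    attr = occ.attribute(
        inputs, sliding_window_shapes=(3,), target=0, show_progress=True
    )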