Skip to content

Commit

Permalink
Update ablation_cam.py (#440)
Browse files Browse the repository at this point in the history
Updated the comments for better understanding and suggested some syntax modification.
  • Loading branch information
naveentnj committed Dec 9, 2023
1 parent 09ac162 commit 25067d1
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions pytorch_grad_cam/ablation_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
""" Implementation of AblationCAM
https://openaccess.thecvf.com/content_WACV_2020/papers/Desai_Ablation-CAM_Visual_Explanations_for_Deep_Convolutional_Network_via_Gradient-free_Localization_WACV_2020_paper.pdf
Ablate individual activations, and then measure the drop in the target score.
Ablate individual activations, and then measure the drop in the target scores.
In the current implementation, the target layer activations is cached, so it won't be re-computed.
However layers before it, if any, will not be cached.
Expand Down Expand Up @@ -88,8 +88,8 @@ def get_cam_weights(self,
[target(output).cpu().item() for target, output in zip(targets, outputs)])

# Replace the layer with the ablation layer.
# When we finish, we will replace it back, so the original model is
# unchanged.
# When we finish, we will replace it back, so the
# original model is unchanged.
ablation_layer = self.ablation_layer
replace_layer_recursive(self.model, target_layer, ablation_layer)

Expand Down Expand Up @@ -122,9 +122,9 @@ def get_cam_weights(self,
# Change the state of the ablation layer so it ablates the next channels.
# TBD: Move this into the ablation layer forward pass.
ablation_layer.set_next_batch(
input_batch_index=batch_index,
activations=self.activations,
num_channels_to_ablate=batch_tensor.size(0))
input_batch_index = batch_index,
activations = self.activations,
num_channels_to_ablate = batch_tensor.size(0))
score = [target(o).cpu().item()
for o in self.model(batch_tensor)]
new_scores.extend(score)
Expand All @@ -145,4 +145,5 @@ def get_cam_weights(self,

# Replace the model back to the original state
replace_layer_recursive(self.model, ablation_layer, target_layer)
# Returning the weights from new_scores
return weights

0 comments on commit 25067d1

Please sign in to comment.