Skip to content

Commit

Permalink
Use custom test-time augmentations (#201)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaushalya committed Dec 9, 2023
1 parent 25067d1 commit 82c88cc
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions pytorch_grad_cam/base_cam.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import torch
import ttach as tta
from typing import Callable, List, Tuple
from typing import Callable, List, Tuple, Optional
from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
from pytorch_grad_cam.utils.image import scale_cam_image
Expand All @@ -15,7 +15,8 @@ def __init__(self,
use_cuda: bool = False,
reshape_transform: Callable = None,
compute_input_gradient: bool = False,
uses_gradients: bool = True) -> None:
uses_gradients: bool = True,
tta_transforms: Optional[tta.Compose] = None) -> None:
self.model = model.eval()
self.target_layers = target_layers
self.cuda = use_cuda
Expand All @@ -24,6 +25,16 @@ def __init__(self,
self.reshape_transform = reshape_transform
self.compute_input_gradient = compute_input_gradient
self.uses_gradients = uses_gradients
if tta_transforms is None:
self.tta_transforms = tta.Compose(
[
tta.HorizontalFlip(),
tta.Multiply(factors=[0.9, 1, 1.1]),
]
)
else:
self.tta_transforms = tta_transforms

self.activations_and_grads = ActivationsAndGradients(
self.model, target_layers, reshape_transform)

Expand Down Expand Up @@ -148,14 +159,8 @@ def forward_augmentation_smoothing(self,
input_tensor: torch.Tensor,
targets: List[torch.nn.Module],
eigen_smooth: bool = False) -> np.ndarray:
transforms = tta.Compose(
[
tta.HorizontalFlip(),
tta.Multiply(factors=[0.9, 1, 1.1]),
]
)
cams = []
for transform in transforms:
for transform in self.tta_transforms:
augmented_tensor = transform.augment_image(input_tensor)
cam = self.forward(augmented_tensor,
targets,
Expand Down

0 comments on commit 82c88cc

Please sign in to comment.