Possibly inverted heatmaps for Score-CAM for YOLOv5 #364

Open
semihcanturk opened this issue Nov 17, 2022 · 3 comments

Comments

@semihcanturk

Hi @jacobgil, I'm working on applying Score-CAM to YOLOv5 by implementing a YOLOBoxScoreTarget class (so this issue is closely related to #242). There are several issues I want to iron out before making a PR (e.g. YOLOv5 returns parseable Detection objects only when the input is not a Torch tensor, see ultralytics/yolov5#6726, and my workaround is ad hoc at the moment), but in my current implementation I find that the resulting heatmaps are inverted, e.g.:

Outputs for dog & cat example, unnormalized (L)/normalized (R):

[Images: scorecam (unnormalized), scorecam_norm (normalized)]

I also observed similar trends with other images, such as the 5-dogs example. My YOLOBoxScoreTarget is as follows. It is essentially identical to FasterRCNNBoxScoreTarget, except that it uses parse_detections from the YOLOv5 notebook:

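import torch
import torchvision

# parse_detections is the helper from the YOLOv5 notebook mentioned above.


class YOLOBoxScoreTarget: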
    """ For every original detected bounding box specified in "bounding boxes",
        assign a score on how the current bounding boxes match it,
            1. In IOU
            2. In the classification score.
        If there is not a large enough overlap, or the category changed,
        assign a score of 0.

        The total score is the sum of all the box scores.
    """

    def __init__(self, labels, bounding_boxes, iou_threshold=0.5):
        self.labels = labels
        self.bounding_boxes = bounding_boxes
        self.iou_threshold = iou_threshold

    def __call__(self, model_outputs):
        boxes, colors, categories, names, confidences = parse_detections(model_outputs)
        boxes = torch.Tensor(boxes)
        output = torch.Tensor([0])
        if torch.cuda.is_available():
            output = output.cuda()
            boxes = boxes.cuda()

        if len(boxes) == 0:
            return output

        for box, label in zip(self.bounding_boxes, self.labels):
            box = torch.Tensor(box[None, :])
            if torch.cuda.is_available():
                box = box.cuda()

            ious = torchvision.ops.box_iou(box, boxes)
            index = ious.argmax()
            if ious[0, index] > self.iou_threshold and categories[index] == label:
                score = ious[0, index] + confidences[index]
                output = output + score
        return output
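
The scoring rule in the target above can be checked by hand on toy boxes. The following is a minimal pure-Python sketch, not the library's implementation: `box_iou` and `box_score` here are hypothetical helpers written for illustration, and the boxes, labels and confidences are made up. It mirrors the logic of `__call__`: pick the detection with the highest IoU, and add IoU plus confidence only if the IoU strictly exceeds the threshold and the label matches.

```python
def box_iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def box_score(ref_box, ref_label, detections, iou_threshold=0.5):
    """Score one reference box against (box, label, confidence) detections."""
    if not detections:
        return 0.0
    ious = [box_iou(ref_box, box) for box, _, _ in detections]
    best = max(range(len(ious)), key=ious.__getitem__)  # argmax over IoUs
    _, label, confidence = detections[best]
    if ious[best] > iou_threshold and label == ref_label:
        return ious[best] + confidence
    return 0.0


# Made-up data: the first detection overlaps the reference with IoU 0.5.
reference = (0, 0, 10, 10)
detections = [((0, 0, 10, 5), 16, 0.9), ((20, 20, 30, 30), 15, 0.8)]

score_match = box_score(reference, 16, detections, iou_threshold=0.4)
score_wrong_label = box_score(reference, 15, detections, iou_threshold=0.4)
score_at_threshold = box_score(reference, 16, detections, iou_threshold=0.5)
```

Because the comparison is strict, an IoU exactly equal to the threshold contributes nothing, and any perturbed input whose best box changes category scores 0.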

I also slightly altered get_cam_weights in score_cam.py to make it play nice with numpy inputs instead of torch tensors.

import torch
import tqdm
from .base_cam import BaseCAM
import numpy as np


class ScoreCAM(BaseCAM):
    def __init__(
            self,
            model,
            target_layers,
            use_cuda=False,
            reshape_transform=None):
        super(ScoreCAM, self).__init__(model,
                                       target_layers,
                                       use_cuda,
                                       reshape_transform=reshape_transform,
                                       uses_gradients=False)


        if len(target_layers) > 0:
            print("Warning: You are using ScoreCAM with target layers, "
                  "however ScoreCAM will ignore them.")

    def get_cam_weights(self,
                        input_tensor,
                        target_layer,
                        targets,
                        activations,
                        grads):
        with torch.no_grad():
            upsample = torch.nn.UpsamplingBilinear2d(
                size=input_tensor.shape[-2:])
            activation_tensor = torch.from_numpy(activations)
            if self.cuda:
                activation_tensor = activation_tensor.cuda()

            upsampled = upsample(activation_tensor)

            maxs = upsampled.view(upsampled.size(0),
                                  upsampled.size(1), -1).max(dim=-1)[0]
            mins = upsampled.view(upsampled.size(0),
                                  upsampled.size(1), -1).min(dim=-1)[0]
            maxs, mins = maxs[:, :, None, None], mins[:, :, None, None]
            # Small epsilon avoids NaNs when a channel is constant (maxs == mins).
            upsampled = (upsampled - mins) / (maxs - mins + 1e-8)

            input_tensors = input_tensor[:, None,
                                         :, :] * upsampled[:, :, None, :, :]

            if hasattr(self, "batch_size"):
                BATCH_SIZE = self.batch_size
            else:
                BATCH_SIZE = 16

            scores = []
            for target, tensor in zip(targets, input_tensors):
                for i in tqdm.tqdm(range(0, tensor.size(0), BATCH_SIZE)):
                    batch = tensor[i: i + BATCH_SIZE, :]

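                    # TODO: current solution to handle the issues with torch inputs, improve
                    # Convert each CHW float tensor (assumed to be in [0, 1]) to an
                    # HWC uint8 image. np.moveaxis keeps the orientation, whereas
                    # np.swapaxes(..., 0, -1) would give WHC, a transposed image.
                    batch = list(batch.cpu().numpy())
                    batch = [np.moveaxis((elt * 255).astype(np.uint8), 0, -1) for elt in batch]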
                    outs = [self.model(b) for b in batch]
                    outputs = [target(o).cpu().item() for o in outs]

                    scores.extend(outputs)
            scores = torch.Tensor(scores)
            scores = scores.view(activations.shape[0], activations.shape[1])

            weights = torch.nn.Softmax(dim=-1)(scores).numpy()
            return weights
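
The step that turns each CHW tensor into an HWC image for YOLOv5 is one place where orientation can silently go wrong. As a small NumPy sketch (the array is a made-up stand-in for an image), `np.swapaxes(x, 0, -1)` and `np.moveaxis(x, 0, -1)` both put the channels last, but only `moveaxis` keeps height and width in place:

```python
import numpy as np

# A tiny fake image in CHW layout: 3 channels, height 2, width 4.
chw = np.arange(3 * 2 * 4).reshape(3, 2, 4)

swapped = np.swapaxes(chw, 0, -1)  # (W, H, C): rows and columns exchanged
moved = np.moveaxis(chw, 0, -1)    # (H, W, C): orientation preserved

print(swapped.shape)  # (4, 2, 3)
print(moved.shape)    # (2, 4, 3)

# swapaxes yields the spatial transpose of the HWC image.
is_transpose = np.array_equal(swapped, moved.transpose(1, 0, 2))
```

With a transposed input, any boxes the detector returns are in transposed coordinates, so an IoU comparison against boxes from the original image would mostly fail.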

To me the implementations seem correct, so I can't explain why the resulting heatmaps seem inverted. Any ideas? I'd be happy to make a branch with a reproducible example to debug and potentially extend it to a PR. Please let me know.
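
For debugging, it may help to look at how the per-channel scores become weights. The sketch below is pure Python with hypothetical helper names, and the `eps` term is an assumption added here to guard against constant channels. It mirrors the two numeric steps of `get_cam_weights`: min-max normalization of each activation map, and a softmax over the target scores. If the target returns 0 for every masked input (no box matched), the softmax is uniform and the CAM becomes a plain average of the activations.

```python
import math


def min_max_normalize(values, eps=1e-8):
    """Scale values to [0, 1]; eps keeps a constant map from dividing by zero."""
    lo, hi = min(values), max(values)
    return [(v - lo) / (hi - lo + eps) for v in values]


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


flat_map = min_max_normalize([0.3, 0.3, 0.3])  # constant channel: zeros, no NaN
uniform = softmax([0.0, 0.0, 0.0, 0.0])        # no detection matched anywhere
peaked = softmax([0.0, 1.4, 0.0, 0.0])         # one channel preserved the box
```

Printing the raw `scores` before the softmax is a cheap way to tell whether the target is matching any boxes at all.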

@mahendra-gehlot

Hi @semihcanturk, any progress on this issue so far?

@syncsyncsync

I ran into the same issue, but in my case, I simply converted the RGB format to the BGR format and it worked.
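
For reference, an RGB to BGR conversion on an HWC NumPy image amounts to reversing the last axis. This is a minimal sketch with a made-up 2x2 image. Where in the pipeline the swap was applied is not stated above, so that part is left open.

```python
import numpy as np

# A 2x2 image that is pure red in RGB channel order.
rgb = np.zeros((2, 2, 3), dtype=np.uint8)
rgb[..., 0] = 255

# Reverse the channel axis; copy() gives a contiguous array instead of a view.
bgr = rgb[..., ::-1].copy()

print(bgr[0, 0])  # [  0   0 255]
```

`cv2.cvtColor(img, cv2.COLOR_RGB2BGR)` does the same thing if OpenCV is already in use.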

@semihcanturk
Author

@mahendra-gehlot unfortunately I haven't dug deep into the code to figure out the underlying issue, but I have noticed something interesting. It seems the heatmaps are correct for yolov5s, but inverted for yolov5n/m/l. Furthermore, this exclusively affects EigenCAM; ScoreCAM does not suffer from this for any model I've tried. Here are the outputs for n/s/m + ScoreCAM (n), for example:

[Images: EigenCAM on YOLOv5n (eigencam_yolov5n), YOLOv5s (eigencam_yolov5s) and YOLOv5m (eigencam_yolov5m); ScoreCAM on YOLOv5n (scorecam_yolov5n)]

This also relates to @syncsyncsync's solution, which doesn't really explain why things work for yolov5s and for ScoreCAM, both of which use the same image load/save pipeline. I have no intuition yet as to why this is happening, but it may prove a good starting point for debugging (for which I may not find sufficient time in the near future, I'm afraid).
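
One candidate explanation for an EigenCAM-only inversion, offered as an assumption and not as a confirmed diagnosis: EigenCAM projects the activations onto their first principal component, and the sign of a singular vector is arbitrary, so a map and its negation are equally valid outputs of the decomposition. The NumPy sketch below uses random stand-in activations and a plain min-max scaling (a simplification of the library's post-processing) to show that flipping the sign inverts the normalized heatmap.

```python
import numpy as np

rng = np.random.default_rng(0)
activations = rng.normal(size=(8, 4, 4))  # hypothetical (C, H, W) activations

# Flatten to (H*W, C), center, and take the first right singular vector.
flat = activations.reshape(8, -1).T
flat = flat - flat.mean(axis=0)
_, _, vt = np.linalg.svd(flat, full_matrices=False)

cam = (flat @ vt[0]).reshape(4, 4)
cam_flipped = (flat @ -vt[0]).reshape(4, 4)  # equally valid first component


def to_unit(x):
    """Min-max scale an array to [0, 1]."""
    return (x - x.min()) / (x.max() - x.min())


# The flipped component gives exactly the inverted heatmap.
inverted = np.allclose(to_unit(cam_flipped), 1.0 - to_unit(cam))
```

If this is what is happening, which sign comes out would depend on the activations of the particular model, which would be consistent with yolov5s looking correct while n/m/l look inverted.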
