Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Problems with EfficientNet #32

Open
MyLtYkRiTiK opened this issue Oct 19, 2020 · 5 comments
Open

Problems with EfficientNet #32

MyLtYkRiTiK opened this issue Oct 19, 2020 · 5 comments

Comments

@MyLtYkRiTiK
Copy link

Hello!
Have you tested this code with EfficientNet (https://github.com/lukemelas/EfficientNet-PyTorch)?
I tried it, but I get very unrealistic visualizations: one huge spot of attention in the same place, nearly identical for every image.
I do not understand what the problem is.

@kazuto1011
Copy link
Owner

kazuto1011 commented Oct 23, 2020

I think I did it :)

bull mastiff tiger cat
result_1st result_5th

code:

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import torch
from efficientnet_pytorch import EfficientNet
from PIL import Image
from torchvision import transforms

from grad_cam import GradCAM

def show_gradcam_overlay(gcam, ids, target_layer, raw_image):
    """Backpropagate ``ids`` through ``gcam`` and display the Grad-CAM map.

    The map generated for ``target_layer`` is colorized and blended 50/50
    with ``raw_image``. Because the two arrays are added element-wise,
    ``raw_image`` must be an (H, W, 3) array whose H x W equals the size of
    the heatmap returned by ``gcam.generate``.
    """
    gcam.backward(ids=ids)
    heatmap = gcam.generate(target_layer=target_layer)
    heatmap = heatmap.cpu().numpy().squeeze()
    # The "turbo" colormap only exists in matplotlib >= 3.3; use "jet" before that.
    colormap = getattr(cm, "turbo", cm.jet)
    heatmap = colormap(heatmap)[..., :3] * 255.0
    # np.float was deprecated in NumPy 1.20 and removed in 1.24; it produced
    # float64, so use that explicitly.
    blended = (heatmap + raw_image.astype(np.float64)) / 2
    plt.imshow(np.uint8(blended))
    plt.show()


if __name__ == "__main__":

    image_path = "cat_dog.png"
    device = "cpu"

    # Pretrained EfficientNet-B0 from efficientnet_pytorch (not torchvision)
    target_layer = "_blocks.15"
    model = EfficientNet.from_pretrained("efficientnet-b0")
    model.to(device)
    model.eval()

    # Images
    # Convert to RGB (drops any PNG alpha channel) and resize *before* taking
    # raw_image, so the displayed image has the same H x W x 3 shape as the
    # colorized heatmap it is blended with.
    # NOTE(review): assumes gcam.generate returns maps at the input tensor's
    # spatial size -- confirm in grad_cam.py.
    image = Image.open(image_path).convert("RGB")
    image = transforms.Resize(224)(image)
    raw_image = np.asarray(image)
    image = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )(image)
    images = torch.stack([image]).to(device)

    gcam = GradCAM(model=model)
    _, sorted_ids = gcam.forward(images)
    ids_1st = sorted_ids[:, [0]]  # 'bull_mastiff'
    ids_6th = sorted_ids[:, [5]]  # 'tiger_cat'

    # 1st round for the dog
    show_gradcam_overlay(gcam, ids_1st, target_layer, raw_image)

    # 2nd round for the cat
    show_gradcam_overlay(gcam, ids_6th, target_layer, raw_image)

@MyLtYkRiTiK
Copy link
Author

MyLtYkRiTiK commented Oct 23, 2020

Hello! Thank you for your code!
I tried it with my own models and with the pretrained models, and I found the same strange behavior that I saw when I tried to run the code on my own.

For example, here are all the block layers for a tench image (tench is the first ImageNet class): tench image
You can see strange red area in left top corner.
Not all images have the same area, but lots of them have.
The problem is most severe for the b4 model and milder for the others.
Do you know what it could be? Does it come from your visualization code, or from something else?

For my own trained model the problem is much worse and the area is bigger :(

@MyLtYkRiTiK
Copy link
Author

MyLtYkRiTiK commented Oct 23, 2020

Code for reproducing:

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import torch
from efficientnet_pytorch import EfficientNet
from torchvision import transforms
from PIL import Image
import urllib.request
import io

from grad_cam import GradCAM

def softmax(x):
    """Return the softmax of ``x``, normalized over *all* of its elements.

    The maximum is subtracted before exponentiating. This leaves the result
    mathematically unchanged but keeps ``np.exp`` from overflowing to ``inf``
    (which would turn the output into ``nan``) for large logits.
    """
    x = np.asarray(x)
    # np.max raises on an empty array; an empty input simply yields an empty result.
    shift = x.max() if x.size else 0
    e_x = np.exp(x - shift)
    return e_x / e_x.sum()

URL = 'https://cdn.shopify.com/s/files/1/0923/2396/files/1093844_10153046864430043_480555795_o_large.jpg?9023175636926334502'
with urllib.request.urlopen(URL) as url:
    f = io.BytesIO(url.read())

# Images
# Force 3-channel RGB so raw_image is always (H, W, 3), matching the RGB
# output of the colormap it is blended with below.
img = Image.open(f).convert("RGB")

device = 'cpu'
model_name = "efficientnet-b4"
model = EfficientNet.from_pretrained(model_name)
model.to(device)
model.eval()


params_dict = {
    # (width_coefficient, depth_coefficient, resolution, dropout_rate)
    'efficientnet-b0': (1.0, 1.0, 224, 0.2),
    'efficientnet-b1': (1.0, 1.1, 240, 0.2),
    'efficientnet-b2': (1.1, 1.2, 260, 0.3),
    'efficientnet-b3': (1.2, 1.4, 300, 0.3),
    'efficientnet-b4': (1.4, 1.8, 380, 0.4),
    'efficientnet-b5': (1.6, 2.2, 456, 0.4),
    'efficientnet-b6': (1.8, 2.6, 528, 0.5),
    'efficientnet-b7': (2.0, 3.1, 600, 0.5),
    'efficientnet-b8': (2.2, 3.6, 672, 0.5),
    'efficientnet-l2': (4.3, 5.3, 800, 0.5),
}
resolution = params_dict[model_name][2]

# Resize + center-crop once, then derive both the displayed image and the
# network input from that same crop so they are pixel-aligned.
cropped = transforms.Compose(
    [
        transforms.Resize(int(resolution * 1.05)),
        transforms.CenterCrop(resolution),
    ]
)(img)
raw_image = np.asarray(cropped)
image = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)(cropped)
images = torch.stack([image]).to(device)

gcam = GradCAM(model=model)
_, sorted_ids = gcam.forward(images)
# sorted_ids holds class indices ordered by score (column 0 = top-1, as used
# for 'bull_mastiff' in the earlier example). Indexing it *by* the predicted
# class index (sorted_ids[:, [predicted_class]]) would select an unrelated
# class, so take column 0. This also makes a second forward pass unnecessary.
ids_predicted = sorted_ids[:, [0]]
predicted_class = int(sorted_ids[0, 0])
print(f'Predicted class {predicted_class}')

n_blocks = len(model._blocks)
n_cols = 4
n_rows = (n_blocks + n_cols - 1) // n_cols  # ceiling division

# squeeze=False keeps `axes` 2-D even when there is only one row.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(40, 40), squeeze=False)
# The "turbo" colormap only exists in matplotlib >= 3.3; use "jet" before that.
colormap = getattr(cm, "turbo", cm.jet)
gcam.backward(ids=ids_predicted)
for layer, ax in enumerate(axes.flat):
    if layer >= n_blocks:
        ax.axis("off")  # hide unused trailing panels in the last row
        continue
    target_layer = f"_blocks.{layer}"
    heatmap = gcam.generate(target_layer=target_layer)
    heatmap = heatmap.cpu().numpy().squeeze()
    heatmap = colormap(heatmap)[..., :3] * 255.0
    # np.float was removed in NumPy 1.24; it produced float64.
    heatmap = (heatmap + raw_image.astype(np.float64)) / 2
    ax.imshow(np.uint8(heatmap))
    ax.set_title(target_layer)

plt.show()

@xiaohuaibaoguigui
Copy link

Code for reproducing:

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import torch
from efficientnet_pytorch import EfficientNet
from torchvision import transforms
from PIL import Image
import urllib.request
import io

from grad_cam import GradCAM

def softmax(x):
    """Return the softmax of ``x``, normalized over *all* of its elements.

    The maximum is subtracted before exponentiating. This leaves the result
    mathematically unchanged but keeps ``np.exp`` from overflowing to ``inf``
    (which would turn the output into ``nan``) for large logits.
    """
    x = np.asarray(x)
    # np.max raises on an empty array; an empty input simply yields an empty result.
    shift = x.max() if x.size else 0
    e_x = np.exp(x - shift)
    return e_x / e_x.sum()

URL = 'https://cdn.shopify.com/s/files/1/0923/2396/files/1093844_10153046864430043_480555795_o_large.jpg?9023175636926334502'
with urllib.request.urlopen(URL) as url:
    f = io.BytesIO(url.read())

# Images
# Force 3-channel RGB so raw_image is always (H, W, 3), matching the RGB
# output of the colormap it is blended with below.
img = Image.open(f).convert("RGB")

device = 'cpu'
model_name = "efficientnet-b4"
model = EfficientNet.from_pretrained(model_name)
model.to(device)
model.eval()


params_dict = {
    # (width_coefficient, depth_coefficient, resolution, dropout_rate)
    'efficientnet-b0': (1.0, 1.0, 224, 0.2),
    'efficientnet-b1': (1.0, 1.1, 240, 0.2),
    'efficientnet-b2': (1.1, 1.2, 260, 0.3),
    'efficientnet-b3': (1.2, 1.4, 300, 0.3),
    'efficientnet-b4': (1.4, 1.8, 380, 0.4),
    'efficientnet-b5': (1.6, 2.2, 456, 0.4),
    'efficientnet-b6': (1.8, 2.6, 528, 0.5),
    'efficientnet-b7': (2.0, 3.1, 600, 0.5),
    'efficientnet-b8': (2.2, 3.6, 672, 0.5),
    'efficientnet-l2': (4.3, 5.3, 800, 0.5),
}
resolution = params_dict[model_name][2]

# Resize + center-crop once, then derive both the displayed image and the
# network input from that same crop so they are pixel-aligned.
cropped = transforms.Compose(
    [
        transforms.Resize(int(resolution * 1.05)),
        transforms.CenterCrop(resolution),
    ]
)(img)
raw_image = np.asarray(cropped)
image = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)(cropped)
images = torch.stack([image]).to(device)

gcam = GradCAM(model=model)
_, sorted_ids = gcam.forward(images)
# sorted_ids holds class indices ordered by score (column 0 = top-1, as used
# for 'bull_mastiff' in the earlier example). Indexing it *by* the predicted
# class index (sorted_ids[:, [predicted_class]]) would select an unrelated
# class, so take column 0. This also makes a second forward pass unnecessary.
ids_predicted = sorted_ids[:, [0]]
predicted_class = int(sorted_ids[0, 0])
print(f'Predicted class {predicted_class}')

n_blocks = len(model._blocks)
n_cols = 4
n_rows = (n_blocks + n_cols - 1) // n_cols  # ceiling division

# squeeze=False keeps `axes` 2-D even when there is only one row.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(40, 40), squeeze=False)
# The "turbo" colormap only exists in matplotlib >= 3.3; use "jet" before that.
colormap = getattr(cm, "turbo", cm.jet)
gcam.backward(ids=ids_predicted)
for layer, ax in enumerate(axes.flat):
    if layer >= n_blocks:
        ax.axis("off")  # hide unused trailing panels in the last row
        continue
    target_layer = f"_blocks.{layer}"
    heatmap = gcam.generate(target_layer=target_layer)
    heatmap = heatmap.cpu().numpy().squeeze()
    heatmap = colormap(heatmap)[..., :3] * 255.0
    # np.float was removed in NumPy 1.24; it produced float64.
    heatmap = (heatmap + raw_image.astype(np.float64)) / 2
    ax.imshow(np.uint8(heatmap))
    ax.set_title(target_layer)

plt.show()

There is an error when I run it:
File "effcietn_cam.py", line 80, in
heatmap = cm.turbo(heatmap)[..., :3] * 255.0
AttributeError: module 'matplotlib.cm' has no attribute 'turbo'

@kazuto1011
Copy link
Owner

The turbo colormap is only available in matplotlib 3.3.0 and later. Please upgrade matplotlib or use jet instead.
https://matplotlib.org/3.3.3/users/whats_new.html#turbo-colormap

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants