
TCAV: cannot run interpret() in cuda #721

Closed
pdpino opened this issue Jul 14, 2021 · 1 comment

Comments


pdpino commented Jul 14, 2021

🐛 Bug

When the model is on the GPU, running tcav.interpret(...) throws a wrong-device error.

To Reproduce

Steps to reproduce the behavior:

  1. Work around issue #719 (TCAV: cannot run compute_cavs() in cuda) by running tcav.compute_cavs() on the CPU, which saves the CAV vectors to ./cav (a rough sketch of this work-around follows the reproduction code below)
  2. Run the following code:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from captum.concept import TCAV, Concept

DEVICE = 'cuda'

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 10, 10)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.classifier = nn.Linear(10, 1)
    def forward(self, images):
        # images shape: batch_size, 3, height, width
        x = self.conv(images) # shape: batch_size, 10, features-height, features-width
        x = self.pool(x) # shape: batch_size, 10, 1, 1
        x = self.flatten(x) # shape: batch_size, 10
        x = self.classifier(x) # shape: batch_size, 1
        return x

class DummyDataset(Dataset):
    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
    def __getitem__(self, idx):
        image = torch.zeros(3, 256, 256)
        return image.to(self.device)
    def __len__(self):
        return 10

model = MyModel().to(DEVICE)

concept0 = Concept(0, 'concept0', DataLoader(DummyDataset(device=DEVICE), batch_size=10))
concept1 = Concept(1, 'concept1', DataLoader(DummyDataset(device=DEVICE), batch_size=10))

tcav = TCAV(model, layers='conv')

inputs = torch.rand(7, 3, 256, 256).to(DEVICE)
scores = tcav.interpret(inputs, [[concept0, concept1]])
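
For reference, here is a rough sketch of the step-1 work-around (my own code, not part of the original report), assuming TCAV.compute_cavs() accepts the same nested list of concept sets as interpret() and caches the CAVs under the default ./cav path:

# Hypothetical work-around for #719 (not part of the reproduction itself):
# fit the CAVs with everything on the CPU first, so they get cached to ./cav.
cpu_concepts = [
    Concept(0, 'concept0', DataLoader(DummyDataset(device='cpu'), batch_size=10)),
    Concept(1, 'concept1', DataLoader(DummyDataset(device='cpu'), batch_size=10)),
]
tcav_cpu = TCAV(MyModel(), layers='conv')  # model kept on the CPU
tcav_cpu.compute_cavs([cpu_concepts])      # assumed signature: a list of concept sets

With those CAVs cached, running the GPU reproduction code above then hits the error below.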

The tcav.interpret(...) line throws: RuntimeError: Input, output and indices must be on the current device.
The full stack trace:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-1-6775c2780feb> in <module>
     39 
     40 inputs = torch.rand(7, 3, 256, 256).to(DEVICE)
---> 41 scores = tcav.interpret(inputs, [[concept0, concept1]])

~/software/captum/captum/concept/_core/tcav.py in interpret(self, inputs, experimental_sets, target, additional_forward_args, processes, **kwargs)
    597                     cav_subset,
    598                     classes_subset,
--> 599                     experimental_subset_sorted,
    600                 )
    601                 i += 1

~/software/captum/captum/concept/_core/tcav.py in _tcav_sub_computation(self, scores, layer, attribs, cavs, classes, experimental_sets)
    646             scores[concepts_key][layer] = {
    647                 "sign_count": torch.index_select(
--> 648                     sign_count_score[i, :], dim=0, index=new_ord
    649                 ),
    650                 "magnitude": torch.index_select(

RuntimeError: Input, output and indices must be on the current device
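
For context, the failing call can be reproduced outside captum with a bare torch.index_select whose index tensor lives on a different device than its input. A minimal sketch of my own (the exact error text may differ across PyTorch versions):

import torch

sign_count_score = torch.rand(2, 2, device='cuda')  # the scores end up on the GPU
new_ord = torch.tensor([1, 0])                       # index tensor created on the CPU by default

# On torch 1.7.1 this raises:
# RuntimeError: Input, output and indices must be on the current device
torch.index_select(sign_count_score[0, :], dim=0, index=new_ord)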

Expected behavior

The method interpret() should run without errors on either CPU or GPU (or the docs should state that only CPU is supported?)

Environment

  • Captum / PyTorch Version: captum 0.4.0, torch 1.7.1+cu110
  • OS (e.g., Linux): Ubuntu 18.04.5
  • How you installed Captum / PyTorch (conda, pip, source): source
  • Build command you used (if compiling from source): pip install -e ~/software/captum
  • Python version: 3.6
  • CUDA/cuDNN version: cu110
  • GPU models and configuration: RTX 3090
  • Any other relevant information: I'm running captum from the master branch; the latest commit is f658185

Additional context

  • I was able to hot-fix it by changing this line in the TCAV._tcav_sub_computation() method to:
new_ord = torch.tensor([concept_ord[cls] for cls in cls_set], device=sign_count_score.device)

i.e. specifying the device parameter.
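
The same bare index_select sketch from above runs cleanly once the index tensor is created on the input's device, which is what this one-line change does:

new_ord = torch.tensor([1, 0], device=sign_count_score.device)    # index on the same device as the input
torch.index_select(sign_count_score[0, :], dim=0, index=new_ord)  # no RuntimeError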

facebook-github-bot pushed a commit that referenced this issue Aug 17, 2021
Summary:
Addresses the issues: #721 #719 #720

Pull Request resolved: #725

Reviewed By: bilalsal

Differential Revision: D30356015

Pulled By: NarineK

fbshipit-source-id: 010a5263bdfc33e8c4d3f9de523d9d3ba3969f49

NarineK commented Aug 17, 2021

fixed with #725

NarineK closed this as completed Aug 17, 2021