🐛 Bug

When using a model on the GPU, running `tcav.interpret(...)` throws a wrong-device error.
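To Reproduce

Steps to reproduce the behavior:

1. Run `tcav.compute_cavs()` on the CPU, which saves the CAV vectors to `./cav`.
2. Run the following code, which places the model, the concept data and the inputs on the GPU: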
```python
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from captum.concept import TCAV, Concept

DEVICE = 'cuda'


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 10, 10)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.classifier = nn.Linear(10, 1)

    def forward(self, images):
        # images shape: batch_size, 3, height, width
        x = self.conv(images)
        # shape: batch_size, 10, features-height, features-width
        x = self.pool(x)
        # shape: batch_size, 10, 1, 1
        x = self.flatten(x)
        # shape: batch_size, 10
        x = self.classifier(x)
        # shape: batch_size, 1
        return x


class DummyDataset(Dataset):
    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device

    def __getitem__(self, idx):
        image = torch.zeros(3, 256, 256)
        return image.to(self.device)

    def __len__(self):
        return 10


model = MyModel().to(DEVICE)

concept0 = Concept(0, 'concept0', DataLoader(DummyDataset(device=DEVICE), batch_size=10))
concept1 = Concept(1, 'concept1', DataLoader(DummyDataset(device=DEVICE), batch_size=10))

tcav = TCAV(model, layers='conv')

inputs = torch.rand(7, 3, 256, 256).to(DEVICE)
scores = tcav.interpret(inputs, [[concept0, concept1]])
```
The `tcav.interpret(...)` line throws `RuntimeError: Input, output and indices must be on the current device`. The full stack trace:
```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-1-6775c2780feb> in <module>
     39
     40 inputs = torch.rand(7, 3, 256, 256).to(DEVICE)
---> 41 scores = tcav.interpret(inputs, [[concept0, concept1]])

~/software/captum/captum/concept/_core/tcav.py in interpret(self, inputs, experimental_sets, target, additional_forward_args, processes, **kwargs)
    597                 cav_subset,
    598                 classes_subset,
--> 599                 experimental_subset_sorted,
    600             )
    601             i += 1

~/software/captum/captum/concept/_core/tcav.py in _tcav_sub_computation(self, scores, layer, attribs, cavs, classes, experimental_sets)
    646             scores[concepts_key][layer] = {
    647                 "sign_count": torch.index_select(
--> 648                     sign_count_score[i, :], dim=0, index=new_ord
    649                 ),
    650                 "magnitude": torch.index_select(

RuntimeError: Input, output and indices must be on the current device
```
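A plausible reading of the trace, assuming default PyTorch settings: `torch.tensor(...)` called without a `device` argument allocates on the CPU, while `sign_count_score` is computed from GPU activations, so `torch.index_select` receives a CUDA source and a CPU index. The sketch below only inspects where each tensor lives; the `scores` tensor and its values are illustrative stand-ins, not taken from Captum. Without a GPU everything falls back to the CPU, where both index tensors happen to agree.

```python
import torch

# Use the GPU when one is available; otherwise everything stays on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Illustrative stand-in for a score tensor computed on `device`.
scores = torch.rand(2, 3, device=device)

# No `device` argument: allocated on the default device (the CPU unless changed).
index_default = torch.tensor([2, 0, 1])

# Explicit `device`: allocated wherever `scores` lives.
index_matched = torch.tensor([2, 0, 1], device=scores.device)

print("scores:", scores.device)
print("index without device=:", index_default.device)
print("index with device=scores.device:", index_matched.device)
```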
Expected behavior

The `interpret()` method should run without errors on either the CPU or the GPU (or else the docs should state that only the CPU is supported).
Environment

Installation method (`conda`, `pip`, source): source, via `pip install -e ~/software/captum`
Additional context

Suggested change to the `new_ord` line in the `TCAV()._tcav_sub_computation()` method, i.e. specifying the `device` param:

```python
new_ord = torch.tensor([concept_ord[cls] for cls in cls_set], device=sign_count_score.device)
```
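The suggested `new_ord` change can be exercised in isolation with the sketch below. `reorder_scores` is a hypothetical helper, not part of Captum, and the `concept_ord` mapping and score values are made up for illustration; the only point is that building the index with `device=scores.device` keeps `torch.index_select` working whether the scores live on the CPU or on a GPU.

```python
import torch


def reorder_scores(scores, concept_ord, cls_set):
    """Hypothetical helper mirroring the `new_ord` / `torch.index_select` step.

    `concept_ord` maps each class id in `cls_set` to its position in `scores`.
    The index tensor is built on `scores.device`, so the lookup works whether
    `scores` lives on the CPU or on a GPU.
    """
    new_ord = torch.tensor(
        [concept_ord[cls] for cls in cls_set], device=scores.device
    )
    return torch.index_select(scores, dim=0, index=new_ord)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One row of illustrative sign-count scores for two concepts.
sign_count_score = torch.tensor([[0.25, 0.75]], device=device)
concept_ord = {0: 1, 1: 0}  # hypothetical: class 0 sits in column 1, class 1 in column 0
cls_set = [0, 1]

out = reorder_scores(sign_count_score[0, :], concept_ord, cls_set)
print(out)  # values [0.75, 0.25], on the same device as sign_count_score
```

Because the index follows the source tensor, the same call needs no changes when the model and inputs move between devices.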
Commit 3a697e3: Support TCAV on cuda (#725)

Summary: Addresses the issues: #721 #719 #720
Pull Request resolved: #725
Reviewed By: bilalsal
Differential Revision: D30356015
Pulled By: NarineK
fbshipit-source-id: 010a5263bdfc33e8c4d3f9de523d9d3ba3969f49
fixed with #725