For the LimeBase example provided at https://captum.ai/api/lime.html (after making the updates in #908), using the provided SGDLasso(), SGDRidge(), and SGDLinearRegression() classes for the interpretable_model argument in LimeBase() leads to the following error message: "RuntimeError: expected scalar type Float but found Double" (more info below).
To Reproduce
import torch
import torch.nn as nn
from captum.attr import LimeBase
from captum._utils.models.linear_model import SkLearnLinearModel, SGDLasso, SGDRidge, SGDLinearRegression

class SimpleClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.linear(x)
        out = self.sigmoid(out)
        return out

net = SimpleClassifier()

def similarity_kernel(original_input, perturbed_input, perturbed_interpretable_input, **kwargs):
    # kernel_width will be provided to attribute as a kwarg
    kernel_width = kwargs["kernel_width"]
    l2_dist = torch.norm(original_input - perturbed_input)
    return torch.exp(- (l2_dist**2) / (kernel_width**2))

def perturb_func(original_input, **kwargs):
    return original_input + torch.randn_like(original_input)

def to_interp_rep_transform(curr_sample, original_inp, **kwargs):
    return curr_sample

input = torch.randn(1, 5)
lime_attr = LimeBase(net,
                     interpretable_model=SGDLinearRegression(),
                     similarity_func=similarity_kernel,
                     perturb_func=perturb_func,
                     perturb_interpretable_space=False,
                     from_interp_rep_transform=None,
                     to_interp_rep_transform=to_interp_rep_transform)
attr_coefs = lime_attr.attribute(input, target=1, kernel_width=1.1)
The code above runs if interpretable_model=SkLearnLinearModel("linear_model.Ridge"), but does not run if interpretable_model=SGDLasso(), interpretable_model=SGDRidge(), or interpretable_model=SGDLinearRegression().
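The error message points to a dtype mismatch between float32 parameters and float64 data. The snippet below is a minimal sketch of that class of error in plain PyTorch, independent of Captum; the `nn.Linear` layer and the `x_double` tensor are stand-ins chosen for illustration, not the code path Lime actually takes, and the exact wording of the message may differ between PyTorch versions.

```python
import torch
import torch.nn as nn

# A layer with default float32 ("Float") parameters; an illustrative stand-in
# for an SGD-trained linear model.
linear = nn.Linear(5, 1)

# float64 ("Double") input; an illustrative stand-in for data produced upstream.
x_double = torch.randn(4, 5, dtype=torch.float64)

try:
    linear(x_double)
    mismatch_raised = False
except RuntimeError as err:
    # PyTorch does not promote dtypes automatically in a matrix multiply.
    mismatch_raised = True
    print(f"RuntimeError: {err}")

# Casting the input to float32 removes the mismatch.
out = linear(x_double.float())
```

If the same mismatch is what happens inside Lime, it would be consistent with the scikit-learn wrapper working while the torch-based SGD models fail.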
Expected behavior
There should be no error message and attr_coefs should return the feature attributions.
Environment
- Captum version: 0.5.0
- Pytorch version: 1.10.0+cu111
- OS (e.g., Linux): macOS
- How you installed Captum (`conda`, `pip`, source): 'conda' and 'pip' --> this error message arises whether I use `conda install captum -c pytorch` or `pip install captum` to install captum
- Python version: 3.7.12
Thanks for identifying this issue, @th789! We have made the fixes in #938; once this change is merged, the SGD linear models should work appropriately with Lime.
Summary:
This updates SGD linear models to work appropriately with Lime, addressing #910. In particular, it switches Lime interpretable model inputs and outputs from double to float and enables gradients when necessary. It also adds a unit test to Lime for testing with SGD linear models.
Pull Request resolved: #938
Reviewed By: NarineK
Differential Revision: D36331146
Pulled By: vivekmig
fbshipit-source-id: 84d7aecf293404f9ba0b14c48e8723e0e489b392
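The two changes named in the summary (casting to float and enabling gradients) can be sketched in plain PyTorch as follows. This is not the code from #938: `fit_linear_sgd` is a hypothetical helper, and the sketch assumes the caller runs under `torch.no_grad()`, which is why training is wrapped in `torch.enable_grad()`.

```python
import torch
import torch.nn as nn

def fit_linear_sgd(inputs, targets, epochs=200, lr=0.1):
    """Hypothetical helper: fit a linear model with full-batch SGD."""
    # Cast data to float32 so it matches the layer's default parameter dtype.
    inputs = inputs.float()
    targets = targets.float()
    model = nn.Linear(inputs.shape[1], 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    # Re-enable autograd in case the caller disabled it.
    with torch.enable_grad():
        for _ in range(epochs):
            optimizer.zero_grad()
            loss = nn.functional.mse_loss(model(inputs).squeeze(1), targets)
            loss.backward()
            optimizer.step()
    return model

torch.manual_seed(0)
# Synthetic float64 data with known coefficients (illustrative only).
X = torch.randn(64, 5, dtype=torch.float64)
true_w = torch.tensor([1.0, -2.0, 0.5, 0.0, 3.0], dtype=torch.float64)
y = X @ true_w

# Assumption: the surrounding attribution code has gradients disabled.
with torch.no_grad():
    model = fit_linear_sgd(X, y)
```

Without the `.float()` casts the forward pass would hit a dtype mismatch, and without `torch.enable_grad()` the call to `loss.backward()` would fail under `torch.no_grad()`.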