Bug in provided SGDLasso(), SGDRidge(), and SGDLinearRegression() classes? #910

Closed
th789 opened this issue Mar 23, 2022 · 2 comments

th789 commented Mar 23, 2022

🐛 Bug

In the LimeBase example at https://captum.ai/api/lime.html (with the updates from #908 applied), passing any of the provided SGDLasso(), SGDRidge(), or SGDLinearRegression() classes as the interpretable_model argument of LimeBase() raises "RuntimeError: expected scalar type Float but found Double" (details below).

To Reproduce

```python
import torch
import torch.nn as nn
from captum.attr import LimeBase
from captum._utils.models.linear_model import SkLearnLinearModel, SGDLasso, SGDRidge, SGDLinearRegression

class SimpleClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.linear(x)
        out = self.sigmoid(out)
        return out

net = SimpleClassifier()

def similarity_kernel(original_input, perturbed_input, perturbed_interpretable_input, **kwargs):
    # kernel_width will be provided to attribute as a kwarg
    kernel_width = kwargs["kernel_width"]
    l2_dist = torch.norm(original_input - perturbed_input)
    return torch.exp(- (l2_dist**2) / (kernel_width**2))

def perturb_func(original_input, **kwargs):
    return original_input + torch.randn_like(original_input)

def to_interp_rep_transform(curr_sample, original_inp, **kwargs):
    return curr_sample

input = torch.randn(1, 5)

lime_attr = LimeBase(net,
                     interpretable_model=SGDLinearRegression(),
                     similarity_func=similarity_kernel,
                     perturb_func=perturb_func,
                     perturb_interpretable_space=False,
                     from_interp_rep_transform=None,
                     to_interp_rep_transform=to_interp_rep_transform)

attr_coefs = lime_attr.attribute(input, target=1, kernel_width=1.1)
```
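
For reference, the `similarity_kernel` in the reproduction is an exponential kernel on the squared L2 distance between the original input $x$ and a perturbed sample $x'$, with $w$ the `kernel_width` passed to `attribute` (1.1 here). LimeBase uses these values to weight each perturbed sample when fitting the interpretable model. The notation below is ours, not Captum's:

$$\pi_w(x, x') = \exp\!\left(-\frac{\lVert x - x' \rVert_2^2}{w^2}\right)$$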

The reproduction script above runs with `interpretable_model=SkLearnLinearModel("linear_model.Ridge")`, but raises the error when `interpretable_model` is `SGDLasso()`, `SGDRidge()`, or `SGDLinearRegression()`.
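
The error message points to a dtype mismatch. The following is a minimal sketch in plain PyTorch (no Captum), assuming, as the fix in #938 suggests, that Lime handed float64 tensors to a default float32 `nn.Linear` inside the SGD models. The exact error wording varies across PyTorch versions, so the sketch only checks whether a RuntimeError is raised. `forward_raises` is a hypothetical helper for illustration:

```python
import torch
import torch.nn as nn

# Assumption: the SGD surrogate wraps a default (float32) nn.Linear.
linear = nn.Linear(5, 1)

# float64 data, like the double-precision inputs Lime passed before #938.
x_double = torch.randn(8, 5, dtype=torch.float64)


def forward_raises(model, x):
    """Return True if model(x) raises a RuntimeError, e.g. on a dtype mismatch."""
    try:
        model(x)
    except RuntimeError:
        return True
    return False


double_input_fails = forward_raises(linear, x_double)        # float64 input, float32 weights
float_input_fails = forward_raises(linear, x_double.float())  # cast to float32 first

print(double_input_fails, float_input_fails)  # True False
```

Casting the inputs to the model's dtype is enough to make the forward pass succeed, which matches the direction the eventual fix took.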

Expected behavior

The call should run without error, and attr_coefs should contain the feature attributions.

Environment

 - Captum version: 0.5.0
 - Pytorch version: 1.10.0+cu111
 - OS (e.g., Linux): macOS
 - How you installed Captum (`conda`, `pip`, source): conda and pip; the error occurs whether Captum is installed with `conda install captum -c pytorch` or `pip install captum`
 - Python version: 3.7.12
vivekmig (Contributor) commented

Thanks for identifying this issue, @th789! We have made the fixes in #938; once this change is merged, the SGD linear models should work appropriately with Lime.


th789 commented May 12, 2022

Thank you @vivekmig!

facebook-github-bot pushed a commit that referenced this issue May 18, 2022
Summary:
This updates SGD linear models to work appropriately with Lime, addressing #910. In particular, this switches Lime interpretable model inputs/outputs from double to float and enables gradients when necessary. It also adds a unit test to Lime for testing with SGD linear models.

Pull Request resolved: #938

Reviewed By: NarineK

Differential Revision: D36331146

Pulled By: vivekmig

fbshipit-source-id: 84d7aecf293404f9ba0b14c48e8723e0e489b392
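
The two parts of the fix described in the commit message (float inputs/outputs, and gradients enabled when necessary) can be sketched in plain PyTorch. `fit_weighted_sgd_linear` is a hypothetical helper, not Captum's implementation. It assumes the surrogate is an `nn.Linear` trained by SGD on similarity-weighted squared error, and that the caller may be running under `torch.no_grad()`. Under that assumption, `loss.backward()` would fail unless autograd is re-enabled for training:

```python
import torch
import torch.nn as nn


def fit_weighted_sgd_linear(x, y, sample_weights=None, epochs=200, lr=0.1):
    """Hypothetical helper (not Captum's code): fit a weighted linear model with SGD.

    Inputs are cast to float32 to match nn.Linear's default parameter dtype,
    and autograd is re-enabled in case the caller runs under torch.no_grad().
    """
    x = x.float()
    y = y.float().reshape(-1)
    if sample_weights is None:
        sample_weights = torch.ones(x.shape[0])
    sample_weights = sample_weights.float()

    model = nn.Linear(x.shape[1], 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    # Without enable_grad(), loss.backward() would fail under an outer no_grad().
    with torch.enable_grad():
        for _ in range(epochs):
            optimizer.zero_grad()
            residuals = model(x).squeeze(1) - y
            loss = (sample_weights * residuals ** 2).mean()
            loss.backward()
            optimizer.step()
    return model


# Usage: double-precision data fitted inside an outer no_grad() block.
torch.manual_seed(0)
true_coefs = torch.tensor([1.0, -2.0, 0.5], dtype=torch.float64)
x = torch.randn(64, 3, dtype=torch.float64)
y = x @ true_coefs

with torch.no_grad():
    surrogate = fit_weighted_sgd_linear(x, y)

learned = surrogate.weight.detach().squeeze(0)
print(learned)
```

The sketch casts at the boundary of the surrogate rather than changing the model's dtype. That mirrors the commit's choice of converting Lime's inputs/outputs to float instead of making the SGD models double precision.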