
KernelSHAP / Lime Improvements #619

Closed · wants to merge 8 commits

Conversation

vivekmig (Contributor)

  • Adds support for generators as the perturb function for Lime, with corresponding tests.
  • Modifies KernelSHAP to first sample the number of selected features from a categorical distribution and then uniformly sample binary vectors with that number of selected features (see the sketch just below). This is theoretically equivalent to the previous approach of weighting randomly selected vectors, but it scales better computationally to larger numbers of features, since the weights for large feature counts lead to arithmetic underflow.
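
A minimal sketch of the sampling process described above (illustrative only, not the actual Captum implementation; the helper name and shapes are assumptions), where num_features plays the role of M:

import torch

def sample_kernel_shap_vector(num_features: int) -> torch.Tensor:
    # Sample k (number of selected features) from p(k) ~ (M - 1) / (k * (M - k)),
    # for k = 1, ..., M - 1.
    ks = torch.arange(1, num_features)
    probs = (num_features - 1) / (ks * (num_features - ks)).float()
    k = ks[torch.multinomial(probs / probs.sum(), 1)].item()
    # Given k, each of the (M choose k) binary vectors should be equally likely:
    # draw i.i.d. random values and keep the k largest positions.
    rand_vals = torch.randn(num_features)
    threshold = torch.kthvalue(rand_vals, num_features - k).values.item()
    return (rand_vals > threshold).long()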

facebook-github-bot (Contributor) left a comment:

@vivekmig has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

miguelmartin75 (Contributor) left a comment:

First pass - mostly nits

    return torch.tensor([similarities])


def kernel_shap_perturb_generator(
    original_inp, **kwargs
Contributor:

typehint? I assume original_inp is Union[Tensor, Tuple[Tensor, ...]]
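
A small sketch of the annotated signature the comment is suggesting (the exact types are an assumption, mirroring Lime's input types):

from typing import Tuple, Union
from torch import Tensor

def kernel_shap_perturb_generator(
    original_inp: Union[Tensor, Tuple[Tensor, ...]], **kwargs
):
    ...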

# weight to 100 (all other weights are < 1).
similarities = 100.0
# weight to 1000000 (all other weights are 1).
similarities = 1000000.0
Contributor:

I doubt this would be a concern, but just in case, we could add this as a default parameter to this method. With this, users can do a functools.partial to change the value in case it is not sufficient.
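
A hedged sketch of that suggestion (the function and parameter names here are hypothetical, not Captum's actual API): exposing the large weight as a keyword argument with a default would let users override it via functools.partial without any other changes.

import functools

def similarity_kernel(original_inp, perturbed_inp, interpretable_sample,
                      large_weight: float = 1000000.0, **kwargs):
    ...  # hypothetical: use large_weight for the all-selected / none-selected cases

custom_kernel = functools.partial(similarity_kernel, large_weight=1e12)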

Contributor (Author):

Yes, it's a good point that this could need to be customized. For now, to avoid having too many parameters, I can make this an instance variable that advanced users can override on the object after creation, but we can make it a parameter later if necessary.

@@ -72,7 +73,7 @@ def __init__(
         forward_func: Callable,
         interpretable_model: Model,
         similarity_func: Callable,
-        perturb_func: Callable,
+        perturb_func: Union[Callable],
Contributor:

Missing typehint in the Union?

Contributor (Author):

Good catch, thanks! Forgot to revert this change.

vivekmig (Author):

Thanks for the review, @miguelmartin75! Addressed comments.

NarineK (Contributor) left a comment:

Thank you for working on this PR, @vivekmig!
I left a couple of nits. I think it would be good to describe this trick a little bit in the code, since the original approach in the paper is a bit different in terms of the kernel similarity function.

            threshold = torch.kthvalue(
                rand_vals, num_features - num_selected_features
            ).values.item()
            yield (rand_vals > threshold).to(device=device).long()
Contributor:

nit: Can we please describe a little bit why we are following this logic instead of the default behavior in default_perturb_func:

def default_perturb_func(original_inp, **kwargs):

I think default_perturb_func is missing typehints too.

Contributor (Author):

Sure, will add more documentation on this. There are a few other helper methods in Lime without type hints, so will add them together in a separate PR.

vivekmig (Author):

Thanks for the review, @NarineK! Addressed comments.

facebook-github-bot (Contributor) left a comment:

@vivekmig has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Perturbations are sampled by the following process:
- Choose k (number of selected features), based on the distribution
p(k) = (M - 1) / (k * (M - k))
where M is the total number of features
Contributor:

nit: total number of features in the interpretable space?

NarineK self-requested a review (February 25, 2021, 00:24)
NarineK (Contributor) left a comment:

Thank you for the explanation! Looks great! Maybe you could add in the description that each of the (M choose k) samples has an equal probability of being chosen, which is why we do this:

rand_vals = torch.randn(1, num_features)
threshold = torch.kthvalue( ...

If I remember your explanation correctly.
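
A small sanity check of that property (illustrative only; 4 features and k = 2 are arbitrary assumptions): thresholding i.i.d. random values at the (M - k)-th smallest value should pick each of the (M choose k) subsets with roughly equal frequency.

from collections import Counter
import torch

num_features, k, counts = 4, 2, Counter()
for _ in range(20000):
    rand_vals = torch.randn(num_features)
    threshold = torch.kthvalue(rand_vals, num_features - k).values.item()
    counts[tuple((rand_vals > threshold).long().tolist())] += 1
print(counts)  # each of the 6 size-2 subsets should appear about 1/6 of the time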

facebook-github-bot (Contributor) left a comment:

@vivekmig has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

facebook-github-bot (Contributor):

@vivekmig merged this pull request in c825327.
