Adding LIME and Kernel SHAP to Captum #467

Open
vivekmig opened this issue Sep 15, 2020 · 5 comments

Comments

@vivekmig
Contributor

Captum LIME API Design

Background

LIME is an algorithm which interprets a black-box model by training an interpretable model, such as a linear model, based on local perturbations around a particular input point to be explained. Coefficients or some other representation of the interpretable model can then be used to understand the behavior of the original model.

The interpretable model takes a binary vector as input which corresponds to presence or absence of interpretable components, such as super-pixels in an input image. Random binary vectors are chosen and the model output is computed on the corresponding inputs. The interpretable model is trained using these input-output pairs, weighted by the similarity between the perturbed input and original input.

Requires:

  • Model
  • X = Input
  • K = Interpretable Input Vector Length
  • f = Function mapping interpretable input vector {0, 1}^K to model input space
  • 𝜋 = Similarity Kernel, quantifies similarity between modified input and original input

Pseudocode:

Repeat N times:
    Randomly and uniformly sample a vector v from {0, 1}^K
    Compute Model(f(v))
    Compute Similarity 𝜋(X, f(v))
    Add v, Model(f(v)) and 𝜋(X, f(v)) to interpretable model training set

Train interpretable regression model with:
    Features v, Expected Output Model(f(v)) and Similarity (Weight) 𝜋(X, f(v))
    
Return representation of interpretable model
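
A minimal runnable sketch of this loop (not Captum code), assuming the caller supplies the black-box forward function, the mapping f from {0, 1}^K to the input space, and the similarity kernel 𝜋, with scikit-learn's Ridge standing in for the interpretable model:

import torch
from sklearn.linear_model import Ridge

def lime_explain(forward_func, f, similarity_kernel, x, K, n_samples=100):
    # Collect (interpretable vector, model output, similarity weight) triples.
    interp_vectors, outputs, weights = [], [], []
    for _ in range(n_samples):
        v = torch.randint(0, 2, (K,)).float()    # uniform sample from {0, 1}^K
        perturbed = f(v)                          # map to the original input space
        out = forward_func(perturbed)             # black-box model output (assumed scalar)
        interp_vectors.append(v)
        outputs.append(float(out))
        weights.append(float(similarity_kernel(x, perturbed)))

    # Train a weighted linear surrogate; its coefficients serve as the explanation.
    surrogate = Ridge(alpha=1.0)
    surrogate.fit(torch.stack(interp_vectors).numpy(), outputs, sample_weight=weights)
    return surrogate.coef_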

A generalization of the LIME framework proposed here suggests some pieces can be made more customizable such as allowing non-binary interpretable embeddings, allowing sampling to either be in the interpretable input space or original input space, and allowing the interpretable model to also be trained with labels.

Design Considerations:

  • The LIME framework is very generic, allowing training any interpretable regression model based on the obtained samples. In addition, the similarity kernel and function mapping an interpretable input to main input can be generic. Ideally, an implementation of LIME in Captum should support all flexibilities in setting these parameters while also being easy to use for common use cases, such as regularized linear models.
  • Unlike other attribution methods, the return value is essentially some representation of a surrogate interpretable model, and not necessarily the shape of the input tensor. This breaks an assumption that holds true for all existing attribution methods.
  • An existing functionality in Captum methods such as Feature Ablation and Feature Permutation is a feature mask, which allows grouping input features. This relates closely to the function mapping a binary vector to an input; in many cases, this function will simply correspond to a grouping of input features (such as superpixels) and each element in the binary index will correspond to whether to include or exclude (set to a baseline) the particular feature group. It would be good to support a consistent structure with a feature mask and baseline since this use case is often desirable, while also leaving the option to apply a different function or structure.
  • KernelSHAP is a particular instantiation of LIME, setting a specific similarity kernel and a linear regression model with no regularization, which allows more efficient computation of Shapley values. Ideally, the implementation of LIME should allow for a small extension to easily implement KernelSHAP.
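
For reference, the Shapley kernel that KernelSHAP would plug into this framework weights a sampled binary vector only by how many of the K interpretable features are present. A standalone sketch of that standard formula (shapley_kernel is an illustrative name, not a Captum function):

from math import comb

def shapley_kernel(num_features: int, num_present: int) -> float:
    # Weight for a sampled binary vector with num_present of num_features features switched on.
    # Fully empty and fully full coalitions get (in theory) infinite weight; a large constant
    # is a common practical stand-in.
    if num_present == 0 or num_present == num_features:
        return 1e6
    return (num_features - 1) / (
        comb(num_features, num_present) * num_present * (num_features - num_present)
    )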

Proposed Captum API Design:

The Captum API includes a base class, which is completely generic and allows for implementations of generalized versions of surrogate interpretable model training.

The LIME implementation builds upon this generalized version with a design that closely mimics other attribution methods for users to easily try LIME and compare to existing attribution methods under certain assumptions on the function and interpretable model structure.

LimeBase

This is a generic class which allows training any surrogate interpretable model based on sample evaluations around a desired input. The constructor takes the model forward function, a function which trains the interpretable model, a perturbation function, which can return samples either in an interpretable representation or in the original input space, a transformation function, which defines the transformation between the input and interpretable spaces, and a similarity function, which defines the weight on a perturbed input for training.

Constructor:

LimeBase(
     forward_func: Callable, 
     interpretable_model: Callable[[Tensor], Any],
     similarity_func: Callable[[TensorOrTupleOfTensors, TensorOrTupleOfTensors, Tensor, **kwargs], float],
     perturb_func: Callable[[TensorOrTupleOfTensors, **kwargs], TensorOrTupleOfTensors],
     sample_input_space: bool = False,
     transform_func: Callable)

Argument Descriptions:

  • forward_func - The forward function of the model (or a torch.nn.Module) for which attributions are desired.
  • interpretable_model - Function which trains an interpretable model and returns any representation of the trained interpretable model, which is returned when calling attribute. The function signature should be as follows:
train_interpretable_model(interpretable_inputs, weights, outputs, **kwargs) 
    → Returns some representation of trained interpretable model
  • similarity_func - Function which computes similarity between original input and perturbed input. This function takes the original input, the perturbed input and the interpretable representation of the perturbed input and returns a float quantifying the similarity. The function signature should be as follows:
similarity_func(original_inputs, pert_inputs, interpretable_pert_inputs, **kwargs)
    → Returns float corresponding to similarity
  • perturb_func - Function which samples perturbations to train interpretable surrogate model. Sampling can be done in either the interpretable input space or original input space, determined by the sample_input_space flag argument.
perturb_func(original_inputs, **kwargs) 
    → Returns a sample of the perturbed input. If sample_input_space is False,
    this should be in the interpretable input space; if sample_input_space is True,
    this should be in the original input space.
  • sample_input_space - This boolean argument defines whether perturb_func returns samples in the original input space (True) or in the interpretable input space (False). This also determines the type of transform_func necessary.
  • transform_func - Function defining transformation between interpretable input and original input space. If sample_input_space is True, since samples are in the original input space, this function should define the transformation from the original input space to the interpretable input space. If sample_input_space is False, since samples are in the interpretable input space, this function should define the transformation from the interpretable input space to the original input space.

attribute:

attribute(inputs, 
          target: TargetType,
          additional_forward_args: Any,
          n_samples: int,
          perturbations_per_eval: int,
          **kwargs) 

These arguments follow the standard definitions of existing Captum methods. kwargs are passed to all functions as shown in the signatures above, allowing flexibility in passing additional arguments to each step of the process. The return value matches the return type of interpretable_model and can be any representation of the interpretable model.
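
If the proposed LimeBase interface were implemented exactly as sketched above, usage with custom functions might look like the following (the toy model, input, and helper functions are hypothetical placeholders; argument names follow this proposal rather than any released API):

import torch
import torch.nn as nn
from sklearn.linear_model import Ridge

model = nn.Linear(20, 1)        # toy stand-in for the real model
inputs = torch.randn(1, 20)     # input point to be explained

def perturb_func(original_inputs, **kwargs):
    # Sample a binary vector in the interpretable space, one entry per feature.
    return torch.randint(0, 2, (1, original_inputs.shape[1])).float()

def transform_func(interp_sample, original_inputs=inputs, **kwargs):
    # Interpretable -> original input space: keep a feature where the bit is 1, zero it otherwise.
    return interp_sample * original_inputs

def similarity_func(original_inputs, pert_inputs, interp_pert_inputs, **kwargs):
    # Exponential kernel over the Euclidean distance between original and perturbed inputs.
    return torch.exp(-torch.norm(original_inputs - pert_inputs) ** 2).item()

def train_interpretable_model(interp_inputs, weights, outputs, **kwargs):
    # Weighted ridge regression over the collected samples; returns the coefficients.
    reg = Ridge(alpha=1.0)
    reg.fit(interp_inputs, outputs, sample_weight=weights)
    return torch.tensor(reg.coef_)

lime_base = LimeBase(            # LimeBase as proposed above, not an existing import
    model,
    interpretable_model=train_interpretable_model,
    similarity_func=similarity_func,
    perturb_func=perturb_func,
    sample_input_space=False,
    transform_func=transform_func,
)
coefficients = lime_base.attribute(inputs, target=None, n_samples=200, perturbations_per_eval=1)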

LIME

The LIME class adds certain assumptions on top of the generic LimeBase class in order to match the structure of other attribution methods and allow easier switching between methods. transform_func is fixed to be defined by feature_mask and baselines, which is very similar to other perturbation based methods in Captum. Particularly, the transformation from an interpretable binary vector of features to the original input space is defined by a mask which groups input features and maps them to indices; a 1 in the corresponding vector index means these features take the values from inputs, while a 0 means they take the baseline values. This transformation works nicely for grouping pixels in images, words in a text model, etc., but may be limiting in some cases. Users can always override this by directly using the LimeBase class.

Also, as defined in the LIME paper and pseudocode, the perturb_func above is set to uniformly sample binary vectors in the interpretable input space, with length defined by the number of groups in the feature mask.
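
As a concrete illustration of this mask-and-baseline transform (toy tensors, not Captum code): with a 1 x 4 input, a feature mask that groups the features into two groups, and a zero baseline, a binary interpretable vector maps back to the input space roughly as follows.

import torch

inputs = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
feature_mask = torch.tensor([[0, 0, 1, 1]])    # two groups: features 0-1 and features 2-3
baselines = torch.zeros_like(inputs)

def mask_transform(binary_vec):
    # binary_vec has one entry per group; 1 keeps the input values, 0 substitutes the baseline.
    keep = binary_vec[feature_mask]             # broadcast group decisions to feature positions
    return torch.where(keep.bool(), inputs, baselines)

print(mask_transform(torch.tensor([1, 0])))    # tensor([[1., 2., 0., 0.]])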

Constructor:

LIME(forward_func: Callable, 
     interpretable_model: Callable,
     similarity_func: Callable) 

Argument Descriptions:

  • forward_func - The forward function of the model (or a torch.nn.Module) for which attributions are desired. This is consistent with all other attribution constructors in Captum.

  • interpretable_model - Function which trains an interpretable model and returns any representation of the trained interpretable model, which is returned when calling attribute. Note that the original LIME algorithm applies regularization (k-LASSO) in the training, which should be incorporated in this function. The function signature should be as follows:

    train_interpretable_model(interpretable_inputs, weights, outputs, **kwargs)
    → Returns some representation of trained interpretable model

    A default model applying regularized linear regression will be provided.

  • similarity_func - Function which computes similarity between the original input and the perturbed input. This function takes the original input, the perturbed input and the interpretable representation of the perturbed input and returns a float quantifying the similarity. The function signature should be as follows:

    similarity_func(original_inputs, pert_inputs, interpretable_pert_inputs, **kwargs)
    → Returns float corresponding to similarity
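
For instance, a similarity function in the spirit of the exponential kernel used in the LIME paper might look like the sketch below (exp_cosine_similarity and the kernel_width kwarg are illustrative, not the promised default):

import torch

def exp_cosine_similarity(original_inputs, pert_inputs, interp_pert_inputs, **kwargs):
    # Exponential kernel over cosine distance between the flattened original and perturbed inputs.
    cos = torch.nn.functional.cosine_similarity(
        original_inputs.flatten(), pert_inputs.flatten(), dim=0
    )
    distance = 1.0 - cos
    kernel_width = kwargs.get("kernel_width", 1.0)
    return torch.exp(-(distance ** 2) / kernel_width ** 2).item()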

attribute:

attribute(inputs, 
          target: TargetType, 
          additional_forward_args: Any,
          n_samples: int,
          perturbations_per_eval: int,
          feature_mask: Union[None, Tensor, Tuple[Tensor, ...]],
          baselines: BaselineType, 
          return_input_shape: bool,
          **kwargs) 

Argument Descriptions:

These arguments follow the standard definitions of existing Captum methods. kwargs are passed to all functions as shown in the signatures above, allowing flexibility in passing additional arguments to the custom functions. If return_input_shape is True, the interpretable model must return a tensor with a single value per input feature group; these values are scattered to the appropriate indices to return attributions matching the original input shape, consistent with other Captum methods. If return_input_shape is False, the return value matches the return type of interpretable_model and can be any representation of the interpretable model.
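
Putting the pieces together, a hypothetical call against this proposed interface for an image classifier might look like the sketch below (the trained model, the quadrant feature mask, and the helper functions from the earlier sketches are illustrative placeholders):

import torch

image = torch.randn(1, 3, 32, 32)   # input to explain; `model` assumed to be a trained classifier

# Group pixels into four coarse quadrants as a stand-in for a real super-pixel segmentation.
feature_mask = torch.zeros(1, 3, 32, 32, dtype=torch.long)
feature_mask[..., 16:, :16] = 1
feature_mask[..., :16, 16:] = 2
feature_mask[..., 16:, 16:] = 3

lime = LIME(model,
            interpretable_model=train_interpretable_model,   # e.g. the ridge sketch above
            similarity_func=exp_cosine_similarity)
attributions = lime.attribute(
    image,
    target=0,                  # class index to explain
    n_samples=200,
    perturbations_per_eval=4,
    feature_mask=feature_mask,
    baselines=0,
    return_input_shape=True,   # scatter per-group coefficients back to the image shape
)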

@NarineK added the design doc, enhancement, and new algorithm labels and removed the enhancement label on Sep 15, 2020
@vivekmig linked a pull request on Sep 16, 2020 that will close this issue
@CleonWong

CleonWong commented Feb 3, 2021

Is there a tutorial that covers how to use Captum's lime API for image classification?

Found this in the documentation, but it does not fully cover an example use case of the lime API.

@vivekmig
Contributor Author

vivekmig commented Feb 3, 2021

Hi @CleonWong, yes, we are working on a Lime tutorial, it will be released soon :)

@CleonWong

@vivekmig Perfect! Keep up the great work ◡̈

@caesar-one

Any news about the LIME tutorial? Anyways, thanks for the great work! :)

@aobo-y
Contributor

aobo-y commented Jul 23, 2021

@CleonWong @caesar-one we have prepared a tutorial. You can find it at https://captum.ai/tutorials/Image_and_Text_Classification_LIME
