
Make SmoothMPRT faster for torch (see implementation) #311

Open
annahedstroem opened this issue Nov 24, 2023 · 0 comments

Description of the problem

  • Make SmoothMPRT faster with torch

Description of a solution

  • test and call this function within SmoothMPRT metric
    def explain_smooth_batch_torch(
        self,
        model: ModelInterface,
        x_batch: np.ndarray,
        y_batch: np.ndarray,
        std: float,
        **kwargs,
    ) -> np.ndarray:
        """
        Compute explanations, then normalise and take the absolute value
        (if the metric was configured to do so at initialisation).
        This method should primarily be used if you need to generate additional
        explanations in the metric's body. It encapsulates the pre- and
        post-processing approach typical for Quantus. It does a few things:
            - call model.shape_input (if a ModelInterface instance was provided)
            - unwrap the model (if a ModelInterface instance was provided)
            - call explain_func
            - expand the attribution channel

        Parameters
        ----------
        model:
            A model that is subject to explanation.
        x_batch:
            A np.ndarray which contains the input data that are explained.
        y_batch:
            A np.ndarray which contains the output labels that are explained.
        std : float
            Standard deviation of the Gaussian noise.
        kwargs: optional
            Keyword arguments passed on to the explanation function.

        Returns
        -------
        a_batch:
            Batch of explanations ready to be evaluated.
        """
        if not isinstance(x_batch, torch.Tensor):
            x_batch = torch.as_tensor(x_batch, device=self.device)

        if not isinstance(y_batch, torch.Tensor):
            y_batch = torch.as_tensor(y_batch, device=self.device)

        a_batch_smooth = torch.zeros_like(x_batch)
        for n in range(self.nr_samples):
            # the last epsilon is defined as zero to compute the true output,
            # and have SmoothGrad w/ n_iter = 1 === gradient
            if n == self.nr_samples - 1:
                epsilon = torch.zeros_like(x_batch)
            else:
                epsilon = torch.randn_like(x_batch) * std

            a_batch = quantus.explain(model, x_batch + epsilon, y_batch, **kwargs)

            # quantus.explain returns a np.ndarray; convert it so it can be
            # accumulated into the torch running average.
            a_batch = torch.as_tensor(a_batch, device=self.device)

            # a_batch_smooth is initialised to zeros above, so the former
            # `if a_batch_smooth is None` branch was dead code; accumulate directly.
            a_batch_smooth += a_batch / self.nr_samples

        # Convert back to np.ndarray to match the declared return type.
        return a_batch_smooth.detach().cpu().numpy()
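    A minimal, self-contained sketch of the averaging loop above, using plain autograd gradients in place of quantus.explain (the tiny linear model, the gradient-based attribution, and the score selection via gather are illustrative assumptions, not part of the proposed implementation). It also checks the property the loop's comment promises: with nr_samples = 1 the last (zero-noise) sample makes SmoothGrad collapse to the plain gradient.

    ```python
    import torch

    def smoothgrad(model, x, y, nr_samples, std):
        """Average input gradients over Gaussian-perturbed copies of x.

        Mirrors the loop in the snippet above: the last sample uses zero
        noise, so nr_samples=1 reduces to the plain gradient.
        """
        a_smooth = torch.zeros_like(x)
        for n in range(nr_samples):
            if n == nr_samples - 1:
                epsilon = torch.zeros_like(x)
            else:
                epsilon = torch.randn_like(x) * std
            x_in = (x + epsilon).requires_grad_(True)
            # Select the logit of the target class for each sample and sum,
            # so one backward pass yields per-input gradients for the batch.
            score = model(x_in).gather(1, y.unsqueeze(1)).sum()
            grad = torch.autograd.grad(score, x_in)[0]
            a_smooth += grad / nr_samples
        return a_smooth

    torch.manual_seed(0)
    model = torch.nn.Linear(4, 3)
    x = torch.randn(2, 4)
    y = torch.tensor([0, 2])

    a1 = smoothgrad(model, x, y, nr_samples=1, std=0.5)

    # Plain gradient for comparison: with one (zero-noise) sample the two agree.
    x_req = x.clone().requires_grad_(True)
    score = model(x_req).gather(1, y.unsqueeze(1)).sum()
    plain = torch.autograd.grad(score, x_req)[0]
    assert torch.allclose(a1, plain)
    ```
    
    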

Minimum acceptance criteria

  • Specify what is necessary for the issue to be closed.
  • @mentions of the person that is apt to review these changes e.g., @annahedstroem
@annahedstroem annahedstroem changed the title Add torch implementation to SmoothMPRT Make SmoothMPRT faster for torch (see implementation) Nov 24, 2023