diff --git a/captum/_utils/common.py b/captum/_utils/common.py index 3a80760c91..a904ce90c0 100644 --- a/captum/_utils/common.py +++ b/captum/_utils/common.py @@ -684,20 +684,48 @@ def _get_module_from_name(model: Module, layer_name: str) -> Any: def _register_backward_hook( module: Module, hook: Callable, attr_obj: Any -) -> torch.utils.hooks.RemovableHandle: - # Special case for supporting output attributions for neuron methods - # This can be removed after deprecation of neuron output attributions - # for NeuronDeepLift, NeuronDeconvolution, and NeuronGuidedBackprop - # in v0.6.0 - if ( - hasattr(attr_obj, "skip_new_hook_layer") - and attr_obj.skip_new_hook_layer == module - ): - return module.register_backward_hook(hook) +) -> List[torch.utils.hooks.RemovableHandle]: + grad_out: Dict[device, Tensor] = {} - if _parse_version(torch.__version__) >= (1, 9, 0): - # Only supported for torch >= 1.9 - return module.register_full_backward_hook(hook) - else: - # Fallback for previous versions of PyTorch - return module.register_backward_hook(hook) + def forward_hook( + module: Module, + inp: Union[Tensor, Tuple[Tensor, ...]], + out: Union[Tensor, Tuple[Tensor, ...]], + ) -> None: + nonlocal grad_out + grad_out = {} + + def output_tensor_hook(output_grad: Tensor) -> None: + grad_out[output_grad.device] = output_grad + + if isinstance(out, tuple): + assert ( + len(out) == 1 + ), "Backward hooks not supported for module with >1 output" + out[0].register_hook(output_tensor_hook) + else: + out.register_hook(output_tensor_hook) + + def pre_hook(module, inp): + def input_tensor_hook(input_grad: Tensor): + if len(grad_out) == 0: + return + hook_out = hook(module, input_grad, grad_out[input_grad.device]) + + if hook_out is not None: + return hook_out[0] if isinstance(hook_out, tuple) else hook_out + + if isinstance(inp, tuple): + assert ( + len(inp) == 1 + ), "Backward hooks not supported for module with >1 input" + inp[0].register_hook(input_tensor_hook) + return inp[0].clone() + else: + inp.register_hook(input_tensor_hook) + return inp.clone() + + return [ + module.register_forward_pre_hook(pre_hook), + module.register_forward_hook(forward_hook), + ] diff --git a/captum/_utils/gradient.py b/captum/_utils/gradient.py index 4d885ff749..09fe827c78 100644 --- a/captum/_utils/gradient.py +++ b/captum/_utils/gradient.py @@ -837,6 +837,8 @@ def _compute_jacobian_wrt_params_with_sample_wise_trick( parameters of the i-th layer, for the j-th member of the minibatch. 
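
The hunk above replaces register_full_backward_hook with a pair of forward hooks plus per-tensor hooks. A minimal, self-contained sketch of that scheme, assuming a single-input/single-output module (attach_backward_hook and the toy hook below are illustrative names, not Captum's API):

import torch
import torch.nn as nn

def attach_backward_hook(module, user_hook):
    # Emulate a module backward hook: a forward hook captures the gradient of
    # the output via a tensor hook; a forward pre-hook attaches a tensor hook
    # to the input and calls user_hook once both gradients are available.
    grad_out = {}

    def forward_hook(mod, inp, out):
        grad_out.clear()

        def output_tensor_hook(grad):
            grad_out[grad.device] = grad

        out.register_hook(output_tensor_hook)

    def pre_hook(mod, inp):
        def input_tensor_hook(grad):
            if grad_out:
                return user_hook(mod, grad, grad_out[grad.device])

        inp[0].register_hook(input_tensor_hook)
        # Returning a clone means in-place modules modify the clone, not the
        # tensor the hook is registered on.
        return inp[0].clone()

    return [
        module.register_forward_pre_hook(pre_hook),
        module.register_forward_hook(forward_hook),
    ]

relu = nn.ReLU(inplace=True)  # in-place modules no longer need special handling
handles = attach_backward_hook(
    relu, lambda mod, g_in, g_out: print("grad_output:", g_out)
)
x = torch.randn(4, requires_grad=True)
relu(x.clone()).sum().backward()
for handle in handles:
    handle.remove()
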
""" with torch.autograd.set_grad_enabled(True): + inputs = tuple(inp.clone() for inp in inputs) + apply_gradient_requirements(inputs) sample_grad_wrapper = SampleGradientWrapper(model, layer_modules) try: sample_grad_wrapper.add_hooks() diff --git a/captum/_utils/sample_gradient.py b/captum/_utils/sample_gradient.py index 1c1b957ea9..660c0030a7 100644 --- a/captum/_utils/sample_gradient.py +++ b/captum/_utils/sample_gradient.py @@ -120,7 +120,7 @@ def _register_module_hooks(self, module: torch.nn.Module) -> None: self.forward_hooks.append( module.register_forward_hook(self._forward_hook_fn) ) - self.backward_hooks.append( + self.backward_hooks.extend( _register_backward_hook(module, self._backward_hook_fn, None) ) diff --git a/captum/attr/_core/deep_lift.py b/captum/attr/_core/deep_lift.py index eaa67b10cd..89430e2c62 100644 --- a/captum/attr/_core/deep_lift.py +++ b/captum/attr/_core/deep_lift.py @@ -43,34 +43,6 @@ from torch.utils.hooks import RemovableHandle -# Check if module backward hook can safely be used for the module that produced -# this inputs / outputs mapping -def _check_valid_module(inputs_grad_fn, outputs) -> bool: - def is_output_cloned(output_fn, input_grad_fn) -> bool: - """ - Checks if the output has been cloned. This happens especially in case of - layer deeplift. - """ - return ( - output_fn[0].next_functions is not None - and output_fn[0].next_functions[0][0] == input_grad_fn - ) - - curr_fn = outputs.grad_fn - first_next = curr_fn.next_functions[0] - try: - # if `inputs` in the input to the network then the grad_fn is None and - # for that input backward_hook isn't computed. That's the reason why we - # need to check on `inputs_grad_fns[first_next[1]]` being None. - return ( - inputs_grad_fn is None - or first_next[0] == inputs_grad_fn - or is_output_cloned(first_next, inputs_grad_fn) - ) - except IndexError: - return False - - class DeepLift(GradientAttribution): r""" Implements DeepLIFT algorithm based on the following paper: @@ -112,10 +84,7 @@ def __init__( r""" Args: - model (nn.Module): The reference to PyTorch model instance. Model cannot - contain any in-place nonlinear submodules; these are not - supported by the register_full_backward_hook PyTorch API - starting from PyTorch v1.9. + model (nn.Module): The reference to PyTorch model instance. multiply_by_inputs (bool, optional): Indicates whether to factor model inputs' multiplier in the final attribution scores. In the literature this is also known as local vs global @@ -430,25 +399,6 @@ def _forward_pre_hook( """ inputs = _format_tensor_into_tuples(inputs) module.input = inputs[0].clone().detach() - module.input_grad_fns = inputs[0].grad_fn # type: ignore - - def tensor_backward_hook(grad): - if module.saved_grad is None: - raise RuntimeError( - """Module {} was detected as not supporting correctly module - backward hook. You should modify your hook to ignore the given - grad_inputs (recompute them by hand if needed) and save the - newly computed grad_inputs in module.saved_grad. 
See MaxPool1d - as an example.""".format( - module - ) - ) - return module.saved_grad - - # the hook is set by default but it will be used only for - # failure cases and will be removed otherwise - handle = inputs[0].register_hook(tensor_backward_hook) - module.input_hook = handle def _forward_hook( self, @@ -462,30 +412,13 @@ def _forward_hook( """ outputs = _format_tensor_into_tuples(outputs) module.output = outputs[0].clone().detach() - if not _check_valid_module(module.input_grad_fns, outputs[0]): - warnings.warn( - """An invalid module {} is detected. Saved gradients will - be used as the gradients of the module's input tensor. - See MaxPool1d as an example.""".format( - module - ) - ) - module.is_invalid = True # type: ignore - module.saved_grad = None # type: ignore - self.forward_handles.append(cast(RemovableHandle, module.input_hook)) - else: - module.is_invalid = False # type: ignore - # removing the hook if there is no failure case - cast(RemovableHandle, module.input_hook).remove() - del module.input_hook - del module.input_grad_fns def _backward_hook( self, module: Module, - grad_input: Union[Tensor, Tuple[Tensor, ...]], - grad_output: Union[Tensor, Tuple[Tensor, ...]], - ): + grad_input: Tensor, + grad_output: Tensor, + ) -> Tensor: r""" `grad_input` is the gradient of the neuron with respect to its input `grad_output` is the gradient of the neuron with respect to its output @@ -506,15 +439,14 @@ def _backward_hook( "Please, ensure that module is being used only once in the " "network.".format(module) ) - multipliers = tuple( - SUPPORTED_NON_LINEAR[type(module)]( - module, - module.input, - module.output, - grad_input, - grad_output, - eps=self.eps, - ) + + multipliers = SUPPORTED_NON_LINEAR[type(module)]( + module, + module.input, + module.output, + grad_input, + grad_output, + eps=self.eps, ) # remove all the properies that we set for the inputs and output del module.input @@ -545,10 +477,10 @@ def _register_hooks( # adds forward hook to leaf nodes that are non-linear forward_handle = module.register_forward_hook(self._forward_hook) pre_forward_handle = module.register_forward_pre_hook(self._forward_pre_hook) - backward_handle = _register_backward_hook(module, self._backward_hook, self) + backward_handles = _register_backward_hook(module, self._backward_hook, self) self.forward_handles.append(forward_handle) self.forward_handles.append(pre_forward_handle) - self.backward_handles.append(backward_handle) + self.backward_handles.extend(backward_handles) def _remove_hooks(self, extra_hooks_to_remove: List[RemovableHandle]) -> None: for handle in extra_hooks_to_remove: @@ -627,9 +559,7 @@ def __init__(self, model: Module, multiply_by_inputs: bool = True) -> None: r""" Args: - model (nn.Module): The reference to PyTorch model instance. Model cannot - contain any in-place nonlinear submodules; these are not - supported by the register_full_backward_hook PyTorch API. + model (nn.Module): The reference to PyTorch model instance. multiply_by_inputs (bool, optional): Indicates whether to factor model inputs' multiplier in the final attribution scores. 
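
The rewritten nonlinear rule above takes and returns single tensors because each emulated backward hook now deals with exactly one input and one output tensor. A stripped-down version of the rescale rule in that form, with a tiny numeric check (illustrative names, not the module-aware signature Captum uses):

import torch

def rescale_rule(delta_in, delta_out, grad_input, grad_output, eps=1e-10):
    # Use grad_output * delta_out / delta_in, falling back to the unmodified
    # gradient wherever delta_in is (near) zero.
    return torch.where(
        delta_in.abs() < eps, grad_input, grad_output * delta_out / delta_in
    )

delta_in = torch.tensor([0.0, 2.0])
delta_out = torch.tensor([0.0, 4.0])
grad_in = torch.tensor([1.0, 1.0])
grad_out = torch.tensor([1.0, 1.0])
print(rescale_rule(delta_in, delta_out, grad_in, grad_out))  # tensor([1., 2.])
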
In the literature this is also known as local vs global @@ -941,7 +871,7 @@ def nonlinear( grad_input: Tensor, grad_output: Tensor, eps: float = 1e-10, -): +) -> Tensor: r""" grad_input: (dLoss / dprev_layer_out, dLoss / wij, dLoss / bij) grad_output: (dLoss / dlayer_out) @@ -949,18 +879,10 @@ def nonlinear( """ delta_in, delta_out = _compute_diffs(inputs, outputs) - new_grad_inp = list(grad_input) - - # supported non-linear modules take only single tensor as input hence accessing - # only the first element in `grad_input` and `grad_output` - new_grad_inp[0] = torch.where( - abs(delta_in) < eps, new_grad_inp[0], grad_output[0] * delta_out / delta_in + new_grad_inp = torch.where( + abs(delta_in) < eps, grad_input, grad_output * delta_out / delta_in ) - # If the module is invalid, save the newly computed gradients - # The original_grad_input will be overridden later in the Tensor hook - if module.is_invalid: - module.saved_grad = new_grad_inp[0] return new_grad_inp @@ -974,15 +896,14 @@ def softmax( ): delta_in, delta_out = _compute_diffs(inputs, outputs) - new_grad_inp = list(grad_input) grad_input_unnorm = torch.where( - abs(delta_in) < eps, new_grad_inp[0], grad_output[0] * delta_out / delta_in + abs(delta_in) < eps, grad_input, grad_output * delta_out / delta_in ) # normalizing - n = grad_input[0].numel() + n = grad_input.numel() # updating only the first half - new_grad_inp[0] = grad_input_unnorm - grad_input_unnorm.sum() * 1 / n + new_grad_inp = grad_input_unnorm - grad_input_unnorm.sum() * 1 / n return new_grad_inp @@ -1073,7 +994,7 @@ def maxpool( module.ceil_mode, True, ) - grad_output_updated = grad_output[0] + grad_output_updated = grad_output unpool_grad_out_delta, unpool_grad_out_ref_delta = torch.chunk( unpool_func( grad_output_updated * delta_out, @@ -1089,20 +1010,7 @@ def maxpool( unpool_grad_out_delta = unpool_grad_out_delta + unpool_grad_out_ref_delta unpool_grad_out_delta = torch.cat(2 * [unpool_grad_out_delta]) - # If the module is invalid, we need to recompute the grad_input - if module.is_invalid: - original_grad_input = grad_input - grad_input = ( - unpool_func( - grad_output_updated, - indices, - module.kernel_size, - module.stride, - module.padding, - list(cast(torch.Size, module.input.shape)), - ), - ) - if grad_input[0].shape != inputs.shape: + if grad_input.shape != inputs.shape: raise AssertionError( "A problem occurred during maxpool modul's backward pass. 
" "The gradients with respect to inputs include only a " @@ -1118,13 +1026,7 @@ def maxpool( new_grad_inp = torch.where( abs(delta_in) < eps, grad_input[0], unpool_grad_out_delta / delta_in ) - # If the module is invalid, save the newly computed gradients - # The original_grad_input will be overridden later in the Tensor hook - if module.is_invalid: - module.saved_grad = new_grad_inp - return original_grad_input - else: - return (new_grad_inp,) + return new_grad_inp def _compute_diffs(inputs: Tensor, outputs: Tensor) -> Tuple[Tensor, Tensor]: diff --git a/captum/attr/_core/guided_backprop_deconvnet.py b/captum/attr/_core/guided_backprop_deconvnet.py index e4b7a7e628..f8334e946d 100644 --- a/captum/attr/_core/guided_backprop_deconvnet.py +++ b/captum/attr/_core/guided_backprop_deconvnet.py @@ -79,8 +79,8 @@ def attribute( def _register_hooks(self, module: Module): if isinstance(module, torch.nn.ReLU): - hook = _register_backward_hook(module, self._backward_hook, self) - self.backward_hooks.append(hook) + hooks = _register_backward_hook(module, self._backward_hook, self) + self.backward_hooks.extend(hooks) def _backward_hook( self, @@ -121,9 +121,7 @@ def __init__(self, model: Module) -> None: r""" Args: - model (nn.Module): The reference to PyTorch model instance. Model cannot - contain any in-place ReLU submodules; these are not - supported by the register_full_backward_hook PyTorch API. + model (nn.Module): The reference to PyTorch model instance. """ ModifiedReluGradientAttribution.__init__( self, model, use_relu_grad_output=False @@ -234,9 +232,7 @@ def __init__(self, model: Module) -> None: r""" Args: - model (nn.Module): The reference to PyTorch model instance. Model cannot - contain any in-place ReLU submodules; these are not - supported by the register_full_backward_hook PyTorch API. + model (nn.Module): The reference to PyTorch model instance. """ ModifiedReluGradientAttribution.__init__(self, model, use_relu_grad_output=True) diff --git a/captum/attr/_core/guided_grad_cam.py b/captum/attr/_core/guided_grad_cam.py index 113fc7379e..5c4f34cd86 100644 --- a/captum/attr/_core/guided_grad_cam.py +++ b/captum/attr/_core/guided_grad_cam.py @@ -51,10 +51,7 @@ def __init__( r""" Args: - model (nn.Module): The reference to PyTorch model instance. Model cannot - contain any in-place ReLU submodules; these are not - supported by the register_full_backward_hook PyTorch API - starting from PyTorch v1.9. + model (nn.Module): The reference to PyTorch model instance. layer (torch.nn.Module): Layer for which GradCAM attributions are computed. Currently, only layers with a single tensor output are supported. diff --git a/captum/attr/_core/layer/layer_deep_lift.py b/captum/attr/_core/layer/layer_deep_lift.py index 9b8349e20c..9be0d9d989 100644 --- a/captum/attr/_core/layer/layer_deep_lift.py +++ b/captum/attr/_core/layer/layer_deep_lift.py @@ -69,10 +69,7 @@ def __init__( r""" Args: - model (nn.Module): The reference to PyTorch model instance. Model cannot - contain any in-place nonlinear submodules; these are not - supported by the register_full_backward_hook PyTorch API - starting from PyTorch v1.9. + model (nn.Module): The reference to PyTorch model instance. layer (torch.nn.Module): Layer for which attributions are computed. The size and dimensionality of the attributions corresponds to the size and dimensionality of the layer's @@ -403,10 +400,7 @@ def __init__( r""" Args: - model (nn.Module): The reference to PyTorch model instance. 
Model cannot - contain any in-place nonlinear submodules; these are not - supported by the register_full_backward_hook PyTorch API - starting from PyTorch v1.9. + model (nn.Module): The reference to PyTorch model instance. layer (torch.nn.Module): Layer for which attributions are computed. The size and dimensionality of the attributions corresponds to the size and dimensionality of the layer's diff --git a/captum/attr/_core/layer/layer_lrp.py b/captum/attr/_core/layer/layer_lrp.py index 0d813fda8c..0cc5beaa47 100644 --- a/captum/attr/_core/layer/layer_lrp.py +++ b/captum/attr/_core/layer/layer_lrp.py @@ -46,11 +46,8 @@ def __init__(self, model: Module, layer: ModuleOrModuleList) -> None: any modification of it. Custom rules for a given layer need to be defined as attribute `module.rule` and need to be of type PropagationRule. - Model cannot contain any in-place nonlinear submodules; - these are not supported by the register_full_backward_hook - PyTorch API starting from PyTorch v1.9. - layer (torch.nn.Module or list of torch.nn.Module): Layer or layers + layer (torch.nn.Module or list(torch.nn.Module)): Layer or layers for which attributions are computed. The size and dimensionality of the attributions corresponds to the size and dimensionality of the layer's diff --git a/captum/attr/_core/lrp.py b/captum/attr/_core/lrp.py index 433e87db01..6ec05b5b54 100644 --- a/captum/attr/_core/lrp.py +++ b/captum/attr/_core/lrp.py @@ -49,10 +49,7 @@ def __init__(self, model: Module) -> None: it. Custom rules for a given layer need to be defined as attribute `module.rule` and need to be of type PropagationRule. If no rule is specified for a layer, a pre-defined default rule for the module type - is used. Model cannot contain any in-place nonlinear submodules; - these are not supported by the register_full_backward_hook - PyTorch API starting from PyTorch v1.9. - + is used. """ GradientAttribution.__init__(self, model) self.model = model @@ -320,10 +317,10 @@ def _check_rules(self) -> None: def _register_forward_hooks(self) -> None: for layer in self.layers: if type(layer) in SUPPORTED_NON_LINEAR_LAYERS: - backward_handle = _register_backward_hook( + backward_handles = _register_backward_hook( layer, PropagationRule.backward_hook_activation, self ) - self.backward_handles.append(backward_handle) + self.backward_handles.extend(backward_handles) else: forward_handle = layer.register_forward_hook( layer.rule.forward_hook # type: ignore diff --git a/captum/attr/_core/neuron/neuron_deep_lift.py b/captum/attr/_core/neuron/neuron_deep_lift.py index a7605d62dd..0a1bbba0c2 100644 --- a/captum/attr/_core/neuron/neuron_deep_lift.py +++ b/captum/attr/_core/neuron/neuron_deep_lift.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -import warnings from typing import Any, Callable, cast, Tuple, Union from captum._utils.gradient import construct_neuron_grad_fn @@ -46,10 +45,7 @@ def __init__( r""" Args: - model (nn.Module): The reference to PyTorch model instance. Model cannot - contain any in-place nonlinear submodules; these are not - supported by the register_full_backward_hook PyTorch API - starting from PyTorch v1.9. + model (nn.Module): The reference to PyTorch model instance. layer (torch.nn.Module): Layer for which neuron attributions are computed. 
Attributions for a particular neuron for the input or output of this layer are computed using the argument neuron_selector @@ -231,17 +227,6 @@ def attribute( >>> attribution = dl.attribute(input, (4,1,2)) """ dl = DeepLift(cast(Module, self.forward_func), self.multiplies_by_inputs) - if not attribute_to_neuron_input: - warnings.warn( - "Attribution to neuron output is no longer supported for" - " NeuronDeepLift and will be deprecated in Captum" - " 0.6.0 due to changes in PyTorch's full backward hook" - " behavior. To obtain attributions for a neuron's" - " output, please attribute with respect to the next layer's input" - ) - dl.skip_new_hook_layer = self.layer # type: ignore - else: - dl.skip_new_hook_layer = None # type: ignore dl.gradient_func = construct_neuron_grad_fn( self.layer, neuron_selector, @@ -290,10 +275,7 @@ def __init__( r""" Args: - model (nn.Module): The reference to PyTorch model instance. Model cannot - contain any in-place nonlinear submodules; these are not - supported by the register_full_backward_hook PyTorch API - starting from PyTorch v1.9. + model (nn.Module): The reference to PyTorch model instance. layer (torch.nn.Module): Layer for which neuron attributions are computed. Attributions for a particular neuron for the input or output of this layer are computed using the argument neuron_selector @@ -470,17 +452,6 @@ def attribute( """ dl = DeepLiftShap(cast(Module, self.forward_func), self.multiplies_by_inputs) - if not attribute_to_neuron_input: - warnings.warn( - "Attribution to neuron output is no longer supported for" - " NeuronDeepLiftShap and will be deprecated in Captum" - " 0.6.0 due to changes in PyTorch's full backward hook" - " behavior. To obtain attributions for a neuron's" - " output, please attribute with respect to the next layer's input" - ) - dl.skip_new_hook_layer = self.layer # type: ignore - else: - dl.skip_new_hook_layer = None # type: ignore dl.gradient_func = construct_neuron_grad_fn( self.layer, neuron_selector, diff --git a/captum/attr/_core/neuron/neuron_guided_backprop_deconvnet.py b/captum/attr/_core/neuron/neuron_guided_backprop_deconvnet.py index 574e28a947..c92ab2fe35 100644 --- a/captum/attr/_core/neuron/neuron_guided_backprop_deconvnet.py +++ b/captum/attr/_core/neuron/neuron_guided_backprop_deconvnet.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -import warnings from typing import Any, Callable, List, Tuple, Union from captum._utils.gradient import construct_neuron_grad_fn @@ -35,10 +34,7 @@ def __init__( r""" Args: - model (nn.Module): The reference to PyTorch model instance. Model cannot - contain any in-place ReLU submodules; these are not - supported by the register_full_backward_hook PyTorch API - starting from PyTorch v1.9. + model (nn.Module): The reference to PyTorch model instance. layer (Module): Layer for which attributions are computed. Output size of attribute matches this layer's input or output dimensions, depending on whether we attribute to @@ -163,18 +159,6 @@ def attribute( >>> # index (4,1,2). >>> attribution = neuron_deconv.attribute(input, (4,1,2)) """ - if not attribute_to_neuron_input: - warnings.warn( - "Attribution to neuron output is no longer supported for" - " NeuronDeconvolution and will be deprecated in Captum" - " 0.6.0 due to changes in PyTorch's full backward hook" - " behavior. 
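
With skip_new_hook_layer and its deprecation warning gone, attributing to a neuron's output takes the same path as attributing to its input. A runnable sketch of both calls on a small stand-in model (TinyModel below is illustrative; the tests use LinearMaxPoolLinearModel):

import torch
import torch.nn as nn
from captum.attr import NeuronDeepLift

class TinyModel(nn.Module):
    # Small stand-in for the LinearMaxPoolLinearModel used in the tests.
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(4, 4, bias=False)
        self.pool1 = nn.MaxPool1d(4)
        self.lin2 = nn.Linear(1, 1, bias=False)

    def forward(self, x):
        return self.lin2(self.pool1(self.lin1(x).unsqueeze(1)).squeeze(1))

model = TinyModel()
inputs = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
baselines = torch.zeros_like(inputs)

# Neuron *output* attribution: no model copy, no warning, no skip_new_hook_layer.
attr_out = NeuronDeepLift(model, model.pool1).attribute(
    inputs, neuron_selector=0, baselines=baselines
)
# Neuron *input* attribution on the following layer gives the comparable result.
attr_in = NeuronDeepLift(model, model.lin2).attribute(
    inputs, neuron_selector=0, baselines=baselines, attribute_to_neuron_input=True
)
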
To obtain attributions for a neuron's" - " output, please attribute with respect to the next layer's input" - ) - self.deconv.skip_new_hook_layer = self.layer # type: ignore - else: - self.deconv.skip_new_hook_layer = None # type: ignore - self.deconv.gradient_func = construct_neuron_grad_fn( self.layer, neuron_selector, self.device_ids, attribute_to_neuron_input ) @@ -207,10 +191,7 @@ def __init__( r""" Args: - model (nn.Module): The reference to PyTorch model instance. Model cannot - contain any in-place ReLU submodules; these are not - supported by the register_full_backward_hook PyTorch API - starting from PyTorch v1.9. + model (nn.Module): The reference to PyTorch model instance. layer (Module): Layer for which neuron attributions are computed. Attributions for a particular neuron in the output of this layer are computed using the argument neuron_selector @@ -332,18 +313,6 @@ def attribute( >>> # index (4,1,2). >>> attribution = neuron_gb.attribute(input, (4,1,2)) """ - if not attribute_to_neuron_input: - warnings.warn( - "Attribution to neuron output is no longer supported for" - " NeuronGuidedBackprop and will be deprecated in Captum" - " 0.6.0 due to changes in PyTorch's full backward hook" - " behavior. To obtain attributions for a neuron's" - " output, please attribute with respect to the next layer's input" - ) - self.guided_backprop.skip_new_hook_layer = self.layer # type: ignore - else: - self.guided_backprop.skip_new_hook_layer = None # type: ignore - self.guided_backprop.gradient_func = construct_neuron_grad_fn( self.layer, neuron_selector, self.device_ids, attribute_to_neuron_input ) diff --git a/captum/attr/_utils/lrp_rules.py b/captum/attr/_utils/lrp_rules.py index edacdef004..2e01b0afac 100644 --- a/captum/attr/_utils/lrp_rules.py +++ b/captum/attr/_utils/lrp_rules.py @@ -33,16 +33,12 @@ def forward_hook(self, module, inputs, outputs): @staticmethod def backward_hook_activation(module, grad_input, grad_output): """Backward hook to propagate relevance over non-linear activations.""" - if ( - isinstance(grad_input, tuple) - and isinstance(grad_output, tuple) - and len(grad_input) > len(grad_output) - ): - # Adds any additional elements of grad_input if applicable - # This occurs when registering a backward hook on nn.Dropout - # modules, which has an additional element of None in - # grad_input - return grad_output + grad_input[len(grad_output) :] + # replace_out is set in _backward_hook_input, this is necessary + # due to 2 tensor hooks on the same tensor + if hasattr(grad_output, "replace_out"): + hook_out = grad_output.replace_out + del grad_output.replace_out + return hook_out return grad_output def _create_backward_hook_input(self, inputs): @@ -53,6 +49,10 @@ def _backward_hook_input(grad): self.relevance_input[device] = relevance.data else: self.relevance_input[device].append(relevance.data) + + # replace_out is needed since two hooks are set on the same tensor + # The output of this hook is needed in backward_hook_activation + grad.replace_out = relevance return relevance return _backward_hook_input diff --git a/tests/attr/layer/test_layer_deeplift.py b/tests/attr/layer/test_layer_deeplift.py index ce64de2f3b..9ca49cc880 100644 --- a/tests/attr/layer/test_layer_deeplift.py +++ b/tests/attr/layer/test_layer_deeplift.py @@ -25,7 +25,7 @@ class TestDeepLift(BaseTest): def test_relu_layer_deeplift(self) -> None: - model = ReLULinearModel(inplace=False) + model = ReLULinearModel(inplace=True) inputs, baselines = _create_inps_and_base_for_deeplift_neuron_layer_testing() 
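
The lrp_rules.py change above leans on the fact that two tensor hooks can end up on the same tensor (an activation's output is also the next layer's input): the rule's input hook stashes its rewritten gradient as an attribute on the gradient tensor, and backward_hook_activation later returns that stash instead of the raw gradient. A minimal demonstration of that handoff, with illustrative names:

import torch

captured = {}

def capture_hook(grad):
    # Stands in for the output_tensor_hook that records grad_output per device.
    captured["grad_out"] = grad

def rule_hook(grad):
    # Stands in for the rule's _backward_hook_input: rewrite the gradient and
    # leave the result on the shared tensor object for the later module hook.
    relevance = grad * 0.5  # placeholder for the actual relevance computation
    grad.replace_out = relevance
    return relevance

x = torch.randn(3, requires_grad=True)
shared = x.clone()            # one tensor, two hooks
shared.register_hook(capture_hook)
shared.register_hook(rule_hook)
shared.sum().backward()

grad_out = captured["grad_out"]
final = grad_out.replace_out if hasattr(grad_out, "replace_out") else grad_out
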
layer_dl = LayerDeepLift(model, model.relu) @@ -39,7 +39,7 @@ def test_relu_layer_deeplift(self) -> None: assert_delta(self, delta) def test_relu_layer_deeplift_wo_mutliplying_by_inputs(self) -> None: - model = ReLULinearModel(inplace=False) + model = ReLULinearModel(inplace=True) inputs, baselines = _create_inps_and_base_for_deeplift_neuron_layer_testing() layer_dl = LayerDeepLift(model, model.relu, multiply_by_inputs=False) @@ -83,7 +83,7 @@ def test_relu_layer_deeplift_add_args(self) -> None: assert_delta(self, delta) def test_linear_layer_deeplift(self) -> None: - model = ReLULinearModel(inplace=False) + model = ReLULinearModel(inplace=True) inputs, baselines = _create_inps_and_base_for_deeplift_neuron_layer_testing() layer_dl = LayerDeepLift(model, model.l3) @@ -103,7 +103,7 @@ def test_relu_deeplift_with_custom_attr_func(self) -> None: self._relu_custom_attr_func_assert(attr_method, inputs, baselines, [[2.0]]) def test_inplace_maxpool_relu_with_custom_attr_func(self) -> None: - model = BasicModel_MaxPool_ReLU(inplace=False) + model = BasicModel_MaxPool_ReLU(inplace=True) inp = torch.tensor([[[1.0, 2.0, -4.0], [-3.0, -2.0, -1.0]]]) dl = LayerDeepLift(model, model.maxpool) @@ -116,7 +116,7 @@ def custom_att_func(mult, inp, baseline): dl.attribute(inp, custom_attribution_func=custom_att_func) def test_linear_layer_deeplift_batch(self) -> None: - model = ReLULinearModel(inplace=False) + model = ReLULinearModel(inplace=True) _, baselines = _create_inps_and_base_for_deeplift_neuron_layer_testing() x1 = torch.tensor( [[-10.0, 1.0, -5.0], [-10.0, 1.0, -5.0], [-10.0, 1.0, -5.0]], @@ -197,7 +197,7 @@ def test_relu_layer_deepliftshap_multiple_output(self) -> None: assert_delta(self, delta) def test_linear_layer_deepliftshap(self) -> None: - model = ReLULinearModel(inplace=False) + model = ReLULinearModel(inplace=True) ( inputs, baselines, diff --git a/tests/attr/layer/test_layer_lrp.py b/tests/attr/layer/test_layer_lrp.py index e4ad951ace..8628a0f4ab 100644 --- a/tests/attr/layer/test_layer_lrp.py +++ b/tests/attr/layer/test_layer_lrp.py @@ -88,6 +88,19 @@ def test_lrp_simple_repeat_attributions(self) -> None: output_after = model(inputs) assertTensorAlmostEqual(self, output, output_after) + def test_lrp_simple_inplaceReLU(self) -> None: + model_default, inputs = _get_simple_model() + model_inplace, _ = _get_simple_model(inplace=True) + for model in [model_default, model_inplace]: + model.eval() + model.linear.rule = EpsilonRule() + model.linear2.rule = EpsilonRule() + lrp_default = LayerLRP(model_default, model_default.linear2) + lrp_inplace = LayerLRP(model_inplace, model_inplace.linear2) + relevance_default = lrp_default.attribute(inputs, attribute_to_layer_input=True) + relevance_inplace = lrp_inplace.attribute(inputs, attribute_to_layer_input=True) + assertTensorAlmostEqual(self, relevance_default[0], relevance_inplace[0]) + def test_lrp_simple_tanh(self) -> None: class Model(nn.Module): def __init__(self) -> None: diff --git a/tests/attr/neuron/test_neuron_deeplift.py b/tests/attr/neuron/test_neuron_deeplift.py index bfe7b55d0e..8d1435847c 100644 --- a/tests/attr/neuron/test_neuron_deeplift.py +++ b/tests/attr/neuron/test_neuron_deeplift.py @@ -2,7 +2,6 @@ from __future__ import print_function -import copy from typing import Tuple, Union import torch @@ -181,11 +180,10 @@ def test_lin_maxpool_lin_classification(self) -> None: baselines = torch.tensor([[1, 2, 3, 9], [4, 8, 6, 7]]).float() model = LinearMaxPoolLinearModel() - model_copy = copy.deepcopy(model) ndl = NeuronDeepLift(model, 
model.pool1) attr = ndl.attribute(inputs, neuron_selector=(0), baselines=baselines) - ndl2 = NeuronDeepLift(model_copy, model_copy.lin2) + ndl2 = NeuronDeepLift(model, model.lin2) attr2 = ndl2.attribute( inputs, neuron_selector=(0), @@ -197,12 +195,11 @@ def test_lin_maxpool_lin_classification(self) -> None: def test_convnet_maxpool2d_classification(self) -> None: inputs = 100 * torch.randn(2, 1, 10, 10) model = BasicModel_ConvNet() - model_copy = copy.deepcopy(model) ndl = NeuronDeepLift(model, model.pool1) attr = ndl.attribute(inputs, neuron_selector=(0, 0, 0)) - ndl2 = NeuronDeepLift(model_copy, model_copy.conv2) + ndl2 = NeuronDeepLift(model, model.conv2) attr2 = ndl2.attribute( inputs, neuron_selector=(0, 0, 0), attribute_to_neuron_input=True ) @@ -212,12 +209,11 @@ def test_convnet_maxpool2d_classification(self) -> None: def test_convnet_maxpool3d_classification(self) -> None: inputs = 100 * torch.randn(2, 1, 10, 10, 10) model = BasicModel_ConvNet_MaxPool3d() - model_copy = copy.deepcopy(model) ndl = NeuronDeepLift(model, model.pool1) attr = ndl.attribute(inputs, neuron_selector=(0, 0, 0, 0)) - ndl2 = NeuronDeepLift(model_copy, model_copy.conv2) + ndl2 = NeuronDeepLift(model, model.conv2) attr2 = ndl2.attribute( inputs, neuron_selector=(0, 0, 0, 0), attribute_to_neuron_input=True ) diff --git a/tests/attr/test_deconvolution.py b/tests/attr/test_deconvolution.py index 8b991c6e54..9b54b7b9d4 100644 --- a/tests/attr/test_deconvolution.py +++ b/tests/attr/test_deconvolution.py @@ -2,7 +2,6 @@ from __future__ import print_function -import copy import unittest from typing import Any, Tuple, Union @@ -126,8 +125,7 @@ def _deconv_matching_assert( test_input: TensorOrTupleOfTensorsGeneric, ) -> None: out = model(test_input) - model_copy = copy.deepcopy(model) - attrib = Deconvolution(model_copy) + attrib = Deconvolution(model) self.assertFalse(attrib.multiplies_by_inputs) neuron_attrib = NeuronDeconvolution(model, output_layer) for i in range(out.shape[1]): diff --git a/tests/attr/test_deeplift_basic.py b/tests/attr/test_deeplift_basic.py index 70e2c82510..09499f4348 100644 --- a/tests/attr/test_deeplift_basic.py +++ b/tests/attr/test_deeplift_basic.py @@ -103,8 +103,36 @@ def test_relu_linear_deeplift(self) -> None: # expected = [[[0.0, 0.0]], [[6.0, 2.0]]] self._deeplift_assert(model, DeepLift(model), inputs, baselines) + def test_relu_linear_deeplift_compare_inplace(self) -> None: + model1 = ReLULinearModel(inplace=True) + x1 = torch.tensor([[-10.0, 1.0, -5.0], [2.0, 3.0, 4.0]], requires_grad=True) + x2 = torch.tensor([[3.0, 3.0, 1.0], [2.3, 5.0, 4.0]], requires_grad=True) + inputs = (x1, x2) + attributions1 = DeepLift(model1).attribute(inputs) + + model2 = ReLULinearModel() + attributions2 = DeepLift(model2).attribute(inputs) + assertTensorAlmostEqual(self, attributions1[0], attributions2[0]) + assertTensorAlmostEqual(self, attributions1[1], attributions2[1]) + + def test_relu_linear_deepliftshap_compare_inplace(self) -> None: + model1 = ReLULinearModel(inplace=True) + x1 = torch.tensor([[-10.0, 1.0, -5.0], [2.0, 3.0, 4.0]], requires_grad=True) + x2 = torch.tensor([[3.0, 3.0, 1.0], [2.3, 5.0, 4.0]], requires_grad=True) + inputs = (x1, x2) + b1 = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]) + b2 = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]) + baselines = (b1, b2) + + attributions1 = DeepLiftShap(model1).attribute(inputs, baselines) + + model2 = ReLULinearModel() + attributions2 = DeepLiftShap(model2).attribute(inputs, baselines) + assertTensorAlmostEqual(self, attributions1[0], 
attributions2[0]) + assertTensorAlmostEqual(self, attributions1[1], attributions2[1]) + def test_relu_linear_deeplift_batch(self) -> None: - model = ReLULinearModel(inplace=False) + model = ReLULinearModel(inplace=True) x1 = torch.tensor([[-10.0, 1.0, -5.0], [2.0, 3.0, 4.0]], requires_grad=True) x2 = torch.tensor([[3.0, 3.0, 1.0], [2.3, 5.0, 4.0]], requires_grad=True) @@ -170,7 +198,7 @@ def test_relu_deepliftshap_multi_ref(self) -> None: self._deeplift_assert(model, DeepLiftShap(model), inputs, baselines) def test_relu_deepliftshap_baselines_as_func(self) -> None: - model = ReLULinearModel(inplace=False) + model = ReLULinearModel(inplace=True) x1 = torch.tensor([[-10.0, 1.0, -5.0]]) x2 = torch.tensor([[3.0, 3.0, 1.0]]) @@ -218,7 +246,7 @@ def custom_attr_func( ) -> Tuple[Tensor, ...]: return tuple(multiplier * 0.0 for multiplier in multipliers) - model = ReLULinearModel(inplace=False) + model = ReLULinearModel(inplace=True) x1 = torch.tensor([[-10.0, 1.0, -5.0]]) x2 = torch.tensor([[3.0, 3.0, 1.0]]) b1 = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]) diff --git a/tests/attr/test_guided_backprop.py b/tests/attr/test_guided_backprop.py index 46703c0184..dcadee662d 100644 --- a/tests/attr/test_guided_backprop.py +++ b/tests/attr/test_guided_backprop.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -import copy import unittest from typing import Any, List, Tuple, Union @@ -151,8 +150,7 @@ def _guided_backprop_matching_assert( test_input: TensorOrTupleOfTensorsGeneric, ): out = model(test_input) - model_copy = copy.deepcopy(model) - attrib = GuidedBackprop(model_copy) + attrib = GuidedBackprop(model) self.assertFalse(attrib.multiplies_by_inputs) neuron_attrib = NeuronGuidedBackprop(model, output_layer) for i in range(out.shape[1]): diff --git a/tests/attr/test_guided_grad_cam.py b/tests/attr/test_guided_grad_cam.py index 11db183459..002fee5d86 100644 --- a/tests/attr/test_guided_grad_cam.py +++ b/tests/attr/test_guided_grad_cam.py @@ -44,7 +44,7 @@ def test_simple_multi_input_conv(self) -> None: self._guided_grad_cam_test_assert(net, net.conv1, (inp, inp2), (ex, ex)) def test_simple_multi_input_relu_input(self) -> None: - net = BasicModel_ConvNet_One_Conv(inplace=False) + net = BasicModel_ConvNet_One_Conv(inplace=True) inp = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4) inp2 = torch.ones((1, 1, 4, 4)) ex = [ @@ -61,6 +61,22 @@ def test_simple_multi_input_relu_input(self) -> None: net, net.relu1, (inp, inp2), (ex, ex), attribute_to_layer_input=True ) + def test_simple_multi_input_conv_inplace(self) -> None: + net = BasicModel_ConvNet_One_Conv(inplace=True) + inp = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4) + inp2 = torch.ones((1, 1, 4, 4)) + ex = [ + [ + [ + [14.5, 29.0, 38.0, 19.0], + [29.0, 58.0, 76.0, 38.0], + [65.0, 130.0, 148.0, 74.0], + [32.5, 65.0, 74.0, 37.0], + ] + ] + ] + self._guided_grad_cam_test_assert(net, net.conv1, (inp, inp2), (ex, ex)) + def test_improper_dims_multi_input_conv(self) -> None: net = BasicModel_ConvNet_One_Conv() inp = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4) diff --git a/tests/attr/test_lrp.py b/tests/attr/test_lrp.py index e946144613..f42493b6bb 100644 --- a/tests/attr/test_lrp.py +++ b/tests/attr/test_lrp.py @@ -125,6 +125,19 @@ def test_lrp_simple_repeat_attributions(self) -> None: output_after = model(inputs) assertTensorAlmostEqual(self, output, output_after) + def test_lrp_simple_inplaceReLU(self) -> None: + model_default, inputs = _get_simple_model() + model_inplace, _ = _get_simple_model(inplace=True) + for model in 
[model_default, model_inplace]: + model.eval() + model.linear.rule = EpsilonRule() # type: ignore + model.linear2.rule = EpsilonRule() # type: ignore + lrp_default = LRP(model_default) + lrp_inplace = LRP(model_inplace) + relevance_default = lrp_default.attribute(inputs) + relevance_inplace = lrp_inplace.attribute(inputs) + assertTensorAlmostEqual(self, relevance_default, relevance_inplace) + def test_lrp_simple_tanh(self) -> None: class Model(nn.Module): def __init__(self) -> None: diff --git a/tests/utils/test_sample_gradient.py b/tests/utils/test_sample_gradient.py index 39801b47f1..1f63a07f3f 100644 --- a/tests/utils/test_sample_gradient.py +++ b/tests/utils/test_sample_gradient.py @@ -4,6 +4,7 @@ from typing import Callable, Tuple import torch +from captum._utils.gradient import apply_gradient_requirements from captum._utils.sample_gradient import ( _reset_sample_grads, SampleGradientWrapper, @@ -74,6 +75,7 @@ def _compare_sample_grads_per_sample( ): wrapper = SampleGradientWrapper(model) wrapper.add_hooks() + apply_gradient_requirements(inputs) out = model(*inputs) wrapper.compute_param_sample_gradients(loss_fn(out), loss_type) @@ -120,6 +122,7 @@ def test_sample_grads_layer_modules(self): # compute sample grads wrapper = SampleGradientWrapper(model, layer_modules) wrapper.add_hooks() + apply_gradient_requirements(inp) out = model(*inp) wrapper.compute_param_sample_gradients(torch.sum(out), "sum")
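
All of the new tests follow one pattern: run the same attribution on an in-place and a default variant of a model and check the results agree. A condensed sketch of that pattern, assuming ReLULinearModel is importable from tests.helpers.basic_models (path assumed) and using torch.allclose in place of assertTensorAlmostEqual:

import torch
from captum.attr import DeepLift
from tests.helpers.basic_models import ReLULinearModel  # import path assumed

x1 = torch.tensor([[-10.0, 1.0, -5.0], [2.0, 3.0, 4.0]], requires_grad=True)
x2 = torch.tensor([[3.0, 3.0, 1.0], [2.3, 5.0, 4.0]], requires_grad=True)
inputs = (x1, x2)

attr_inplace = DeepLift(ReLULinearModel(inplace=True)).attribute(inputs)
attr_default = DeepLift(ReLULinearModel()).attribute(inputs)

assert torch.allclose(attr_inplace[0], attr_default[0])
assert torch.allclose(attr_inplace[1], attr_default[1])
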