diff --git a/captum/influence/_core/tracincp_fast_rand_proj.py b/captum/influence/_core/tracincp_fast_rand_proj.py
index 61ced8fc0e..6b45ff1752 100644
--- a/captum/influence/_core/tracincp_fast_rand_proj.py
+++ b/captum/influence/_core/tracincp_fast_rand_proj.py
@@ -118,9 +118,13 @@ def __init__(
             loss_fn (Callable, optional): The loss function applied to model. `loss_fn`
                     must be a "reduction" loss function that reduces the per-example
                     losses in a batch, and returns a single scalar Tensor. Furthermore,
-                    the reduction must be the *sum* of the per-example losses. For
-                    instance, `nn.BCELoss(reduction="sum")` is acceptable, but
-                    `nn.BCELoss(reduction="mean")` is *not* acceptable.
+                    the reduction must be the *sum* or the *mean* of the per-example
+                    losses. For instance, `nn.BCELoss(reduction="sum")` is acceptable.
+                    Also note that if `loss_fn` has no "reduction" attribute,
+                    the implementation assumes that the reduction is the *sum* of the
+                    per-example losses. If this is not the case, i.e. the reduction
+                    is the *mean*, please set the "reduction" attribute of `loss_fn`
+                    to "mean", i.e. `loss_fn.reduction = "mean"`.
                     Default: None
             batch_size (int or None, optional): Batch size of the DataLoader created to
                     iterate through `influence_src_dataset`, if it is a Dataset.
@@ -156,12 +160,30 @@ def __init__(
             param.requires_grad = True
 
         assert loss_fn is not None, "loss function must not be none"
+
         # If we are able to access the reduction used by `loss_fn`, we check whether
-        # the reduction is "sum", as required.
-        # TODO: allow loss_fn to be Callable
-        if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"):
-            msg = "`loss_fn.reduction` must be `sum`."
-            assert loss_fn.reduction == "sum", msg
+        # the reduction is either 'sum' or 'mean', as required
+        if isinstance(loss_fn, Module) and hasattr(
+            loss_fn, "reduction"
+        ):  # TODO: allow loss_fn to be Callable
+            assert loss_fn.reduction in [
+                "sum",
+                "mean",
+            ], 'reduction for `loss_fn` must be "sum" or "mean"'
+            self.reduction_type = str(loss_fn.reduction)
+        else:
+            # if we are unable to access the reduction used by `loss_fn`, we warn
+            # the user about the assumptions we are making regarding the reduction
+            # used by `loss_fn`
+            warnings.warn(
+                'Since `loss_fn` has no "reduction" attribute, the implementation '
+                'assumes that `loss_fn` is a "reduction" loss function that '
+                "reduces the per-example losses by taking their *sum*. If "
+                "`loss_fn` instead reduces the per-example losses by taking their "
+                'mean, please set the reduction attribute of `loss_fn` to "mean", '
+                'i.e. `loss_fn.reduction = "mean"`.'
+            )
+            self.reduction_type = "sum"
 
     def _influence_batch_tracincp_fast(
         self,
@@ -347,6 +369,11 @@ def _basic_computation_tracincp_fast(
 
     Args:
         influence_instance (TracInCPFast): A instance of TracInCPFast or its children.
+                We assume `influence_instance` has a `loss_fn` attribute, i.e. the loss
+                function applied to the output of the last fully-connected layer, as
+                well as a `reduction_type` attribute, which indicates whether `loss_fn`
+                reduces the per-example losses by using their mean or sum. The
+                `reduction_type` attribute must either be "mean" or "sum".
         inputs (Tuple of Any): A batch of examples, which could be a training batch
                 or test batch, depending which method is the caller. Does not
                 represent labels, which are passed as `targets`. The assumption is
@@ -360,13 +387,23 @@ def _basic_computation_tracincp_fast(
     handle = influence_instance.final_fc_layer.register_forward_hook(_capture_inputs)
     out = influence_instance.model(*inputs)
 
-    assert influence_instance.loss_fn is not None
+    assert influence_instance.loss_fn is not None, "loss function is required"
+    assert influence_instance.reduction_type in [
+        "sum",
+        "mean",
+    ], 'reduction_type must be either "mean" or "sum"'
 
     input_jacobians = _jacobian_loss_wrt_inputs(
-        influence_instance.loss_fn, out, targets, influence_instance.vectorize
+        influence_instance.loss_fn,
+        out,
+        targets,
+        influence_instance.vectorize,
+        influence_instance.reduction_type,
     )
 
     handle.remove()
     _layer_inputs = layer_inputs[0]
 
+    assert len(input_jacobians.shape) == 2
+
     return input_jacobians, _layer_inputs
diff --git a/captum/influence/_utils/common.py b/captum/influence/_utils/common.py
index 8fb1204c87..b14e109753 100644
--- a/captum/influence/_utils/common.py
+++ b/captum/influence/_utils/common.py
@@ -59,22 +59,85 @@ def _gradient_dot_product(
 
 
 def _jacobian_loss_wrt_inputs(
-    loss_fn: Union[Module, Callable], out: Tensor, targets: Tensor, vectorize: bool
+    loss_fn: Union[Module, Callable],
+    out: Tensor,
+    targets: Tensor,
+    vectorize: bool,
+    reduction_type: str,
 ) -> Tensor:
     r"""
-    Helper function to handle dealing with pytorch version differences for vectorized
-    jacobian calculation of loss wrt inputs.
+    Often, we have a loss function that computes a per-sample loss given a 1D tensor
+    input, and we want to calculate the jacobian of the loss w.r.t. that input. For
+    example, the input could be a length K tensor specifying the probability a given
+    sample belongs to each of K possible classes, and the loss function could be
+    cross-entropy loss. This function performs that calculation, but does so for a
+    *batch* of inputs. We create this helper function for two reasons: 1) to handle
+    differences between PyTorch versions for vectorized jacobian calculations, and
+    2) this function does not accept the aforementioned per-sample loss function.
+    Instead, it accepts a "reduction" loss function that *reduces* the per-sample loss
+    for a batch into a single loss. Using a "reduction" loss improves speed.
+    We will allow this reduction to either be the mean or sum of the per-sample losses,
+    and this function provides a uniform way to handle different possible reductions,
+    and also checks if the reduction used is valid. Regardless of the reduction used,
+    this function returns the jacobian for the per-sample loss (for each sample in the
+    batch).
+
+    Args:
+        loss_fn (torch.nn.Module or Callable or None): The loss function. If a library
+                defined loss function is provided, it would be expected to be a
+                torch.nn.Module. If a custom loss is provided, it can be either type,
+                but must behave as a library loss function would if `reduction='sum'`
+                or `reduction='mean'`.
+        out (tensor): This is a tensor that represents the batch of inputs to
+                `loss_fn`. In practice, this will be the output of a model; this is
+                why this argument is named `out`. `out` is a 2D tensor of shape
+                (batch size, model output dimensionality). We will call `loss_fn` via
+                `loss_fn(out, targets)`.
+        targets (tensor): The labels for the batch of inputs.
+        vectorize (bool): Flag to use experimental vectorize functionality for
+                `torch.autograd.functional.jacobian`.
+        reduction_type (str): The type of reduction used by `loss_fn`. If `loss_fn`
+                has the "reduction" attribute, we will check that they match. Can
+                only be "mean" or "sum".
+
+    Returns:
+        jacobians (tensor): Returns the jacobian of the per-sample loss (implicitly
+                defined by `loss_fn` and `reduction_type`) w.r.t. each sample
+                in the batch represented by `out`. This is a 2D tensor, where the
+                first dimension is the batch dimension.
     """
+    # TODO: allow loss_fn to be Callable
+    if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"):
+        msg0 = "Please ensure that loss_fn.reduction is set to `sum` or `mean`"
+
+        assert loss_fn.reduction != "none", msg0
+        msg1 = (
+            f"loss_fn.reduction ({loss_fn.reduction}) does not match "
+            f"reduction type ({reduction_type}). Please ensure they are"
+            " matching."
+        )
+        assert loss_fn.reduction == reduction_type, msg1
+
+    if reduction_type != "sum" and reduction_type != "mean":
+        raise ValueError(
+            f"{reduction_type} is not a valid value for reduction_type. "
+            "Must be either 'sum' or 'mean'."
+        )
     if torch.__version__ >= "1.8":
-        return torch.autograd.functional.jacobian(
+        input_jacobians = torch.autograd.functional.jacobian(
             lambda out: loss_fn(out, targets), out, vectorize=vectorize
         )
     else:
-        return torch.autograd.functional.jacobian(
+        input_jacobians = torch.autograd.functional.jacobian(
             lambda out: loss_fn(out, targets), out
         )
 
+    if reduction_type == "mean":
+        input_jacobians = input_jacobians * len(input_jacobians)
+
+    return input_jacobians
+
 
 def _load_flexible_state_dict(
     model: Module, path: str, device_ids: str = "cpu", keyname: Optional[str] = None
diff --git a/tests/influence/_core/test_tracin_get_k_most_influential.py b/tests/influence/_core/test_tracin_get_k_most_influential.py
index 3a3d0ce3a2..5aca3c1095 100644
--- a/tests/influence/_core/test_tracin_get_k_most_influential.py
+++ b/tests/influence/_core/test_tracin_get_k_most_influential.py
@@ -48,6 +48,8 @@ class TestTracInGetKMostInfluential(BaseTest):
                 ),
                 ("sum", DataInfluenceConstructor(TracInCPFast)),
                 ("sum", DataInfluenceConstructor(TracInCPFastRandProj)),
+                ("mean", DataInfluenceConstructor(TracInCPFast)),
+                ("mean", DataInfluenceConstructor(TracInCPFastRandProj)),
             ]
         ],
         name_func=build_test_name_func(),
diff --git a/tests/influence/_core/test_tracin_regression.py b/tests/influence/_core/test_tracin_regression.py
index 83e06e44f6..929d9ea5b2 100644
--- a/tests/influence/_core/test_tracin_regression.py
+++ b/tests/influence/_core/test_tracin_regression.py
@@ -50,6 +50,8 @@ def _test_tracin_regression_setup(self, tmpdir: str, features: int):
             ("sample_wise_trick", None, DataInfluenceConstructor(TracInCP)),
             ("check_idx", "sum", DataInfluenceConstructor(TracInCPFast)),
             ("check_idx", "sum", DataInfluenceConstructor(TracInCPFastRandProj)),
+            ("check_idx", "mean", DataInfluenceConstructor(TracInCPFast)),
+            ("check_idx", "mean", DataInfluenceConstructor(TracInCPFastRandProj)),
             (
                 "check_idx",
                 "sum",
@@ -168,6 +170,8 @@ def test_tracin_regression(
             ),
             ("sum", DataInfluenceConstructor(TracInCPFast)),
             ("sum", DataInfluenceConstructor(TracInCPFastRandProj)),
+            ("mean", DataInfluenceConstructor(TracInCPFast)),
+            ("mean", DataInfluenceConstructor(TracInCPFastRandProj)),
         ],
         name_func=build_test_name_func(),
     )
@@ -248,6 +252,8 @@ def _test_tracin_identity_regression_setup(self, tmpdir: str):
             ("sample_wise_trick", None, DataInfluenceConstructor(TracInCP)),
             ("check_idx", "sum", DataInfluenceConstructor(TracInCPFast)),
             ("check_idx", "sum", DataInfluenceConstructor(TracInCPFastRandProj)),
+            ("check_idx", "mean", DataInfluenceConstructor(TracInCPFast)),
+            ("check_idx", "mean", DataInfluenceConstructor(TracInCPFastRandProj)),
         ],
         name_func=build_test_name_func(),
     )
diff --git a/tests/influence/_core/test_tracin_self_influence.py b/tests/influence/_core/test_tracin_self_influence.py
index 0b8ebc7fcc..85f68a5d6a 100644
--- a/tests/influence/_core/test_tracin_self_influence.py
+++ b/tests/influence/_core/test_tracin_self_influence.py
@@ -31,6 +31,7 @@ class TestTracInSelfInfluence(BaseTest):
                     ),
                 ),
                 ("sum", DataInfluenceConstructor(TracInCPFast)),
+                ("mean", DataInfluenceConstructor(TracInCPFast)),
             ]
         ],
         name_func=build_test_name_func(args_to_skip=["reduction"]),
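
Note on the `reduction_type` rescaling in `_jacobian_loss_wrt_inputs` above: a "mean" reduction scales every per-example gradient by 1 / (batch size), so multiplying the jacobian by the batch size (the `input_jacobians * len(input_jacobians)` line in the diff) recovers the same per-example jacobians that a "sum" reduction yields directly. Below is a minimal standalone sketch of that identity; it is not part of the diff, and the choice of `nn.CrossEntropyLoss` and the tensor shapes are illustrative assumptions.

import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, num_classes = 4, 3
out = torch.randn(batch_size, num_classes)           # stand-in for model outputs (logits)
targets = torch.randint(num_classes, (batch_size,))  # stand-in for labels

sum_loss = nn.CrossEntropyLoss(reduction="sum")
mean_loss = nn.CrossEntropyLoss(reduction="mean")

# Jacobian of the reduced loss w.r.t. the batch of outputs, for each reduction.
jac_sum = torch.autograd.functional.jacobian(lambda o: sum_loss(o, targets), out)
jac_mean = torch.autograd.functional.jacobian(lambda o: mean_loss(o, targets), out)

# A "mean" reduction scales each per-example gradient by 1 / batch_size, so
# rescaling by the batch size recovers the per-example ("sum") jacobians.
assert torch.allclose(jac_sum, jac_mean * batch_size)

This equality holds because the loss yields one value per example and its "mean" averages over the batch dimension, which is exactly the assumption the new `reduction_type` argument and docstrings encode ("the *sum* or the *mean* of the per-example losses").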