allow "mean" reduction for TracInCPFast, TracInCPFastRandProj #913

Closed · wants to merge 1 commit
57 changes: 47 additions & 10 deletions captum/influence/_core/tracincp_fast_rand_proj.py
@@ -118,9 +118,13 @@ def __init__(
loss_fn (Callable, optional): The loss function applied to model. `loss_fn`
must be a "reduction" loss function that reduces the per-example
losses in a batch, and returns a single scalar Tensor. Furthermore,
the reduction must be the *sum* of the per-example losses. For
instance, `nn.BCELoss(reduction="sum")` is acceptable, but
`nn.BCELoss(reduction="mean")` is *not* acceptable.
the reduction must be the *sum* or the *mean* of the per-example
losses. For instance, `nn.BCELoss(reduction="sum")` is acceptable.
Also note that if `loss_fn` has no "reduction" attribute,
the implementation assumes that the reduction is the *sum* of the
per-example losses. If this is not the case, i.e. the reduction
is the *mean*, please set the "reduction" attribute of `loss_fn`
to "mean", i.e. `loss_fn.reduction = "mean"`.
Default: None
batch_size (int or None, optional): Batch size of the DataLoader created to
iterate through `influence_src_dataset`, if it is a Dataset.
@@ -156,12 +160,30 @@ def __init__(
param.requires_grad = True

assert loss_fn is not None, "loss function must not be none"

# If we are able to access the reduction used by `loss_fn`, we check whether
# the reduction is "sum", as required.
# TODO: allow loss_fn to be Callable
if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"):
msg = "`loss_fn.reduction` must be `sum`."
assert loss_fn.reduction == "sum", msg
# the reduction is either 'sum' or 'mean', as required
if isinstance(loss_fn, Module) and hasattr(
loss_fn, "reduction"
): # TODO: allow loss_fn to be Callable
assert loss_fn.reduction in [
"sum",
"mean",
], 'reduction for `loss_fn` must be "sum" or "mean"'
self.reduction_type = str(loss_fn.reduction)
else:
# if we are unable to access the reduction used by `loss_fn`, we warn
# the user about the assumptions we are making regarding the reduction
# used by `loss_fn`
warnings.warn(
'Since `loss_fn` has no "reduction" attribute, the implementation '
'assumes that `loss_fn` is a "reduction" loss function that '
"reduces the per-example losses by taking their *sum*. If "
"`loss_fn` instead reduces the per-example losses by taking their "
'mean, please set the reduction attribute of `loss_fn` to "mean", '
'i.e. `loss_fn.reduction = "mean"`.'
)
self.reduction_type = "sum"

def _influence_batch_tracincp_fast(
self,
@@ -347,6 +369,11 @@ def _basic_computation_tracincp_fast(

Args:
influence_instance (TracInCPFast): An instance of TracInCPFast or its children.
We assume `influence_instance` has a `loss_fn` attribute, i.e. the loss
function applied to the output of the last fully-connected layer, as
well as a `reduction_type` attribute, which indicates whether `loss_fn`
reduces the per-example losses by using their mean or sum. The
`reduction_type` attribute must either be "mean" or "sum".
inputs (Tuple of Any): A batch of examples, which could be a training batch
or test batch, depending which method is the caller. Does not
represent labels, which are passed as `targets`. The assumption is
@@ -360,13 +387,23 @@ def _basic_computation_tracincp_fast(
handle = influence_instance.final_fc_layer.register_forward_hook(_capture_inputs)
out = influence_instance.model(*inputs)

assert influence_instance.loss_fn is not None
assert influence_instance.loss_fn is not None, "loss function is required"
assert influence_instance.reduction_type in [
"sum",
"mean",
], 'reduction_type must be either "mean" or "sum"'
input_jacobians = _jacobian_loss_wrt_inputs(
influence_instance.loss_fn, out, targets, influence_instance.vectorize
influence_instance.loss_fn,
out,
targets,
influence_instance.vectorize,
influence_instance.reduction_type,
)
handle.remove()
_layer_inputs = layer_inputs[0]

assert len(input_jacobians.shape) == 2

return input_jacobians, _layer_inputs


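For context, here is a minimal usage sketch (not part of this diff) of what the change enables: constructing `TracInCPFast` with a loss function whose reduction is `"mean"`. The toy model, dataset, checkpoint path, and loader below are hypothetical placeholders, and the keyword arguments assume the `TracInCPFast` signature documented above.

```python
# Hypothetical usage sketch -- the toy model, dataset, and checkpoint path are
# placeholders and not part of the PR.
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

from captum.influence import TracInCPFast

net = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 3))
torch.save(net.state_dict(), "checkpoint_0.pt")  # stand-in for a real training checkpoint
train_dataset = TensorDataset(torch.randn(32, 10), torch.randint(0, 3, (32,)))


def load_checkpoint(model, path):
    # simple loader for this sketch; checkpoints_load_func is assumed to return
    # the learning rate associated with the checkpoint, so we return 1.0 here
    model.load_state_dict(torch.load(path))
    return 1.0


tracin = TracInCPFast(
    model=net,
    final_fc_layer=net[2],  # the last fully-connected layer
    influence_src_dataset=train_dataset,
    checkpoints=["checkpoint_0.pt"],
    checkpoints_load_func=load_checkpoint,
    loss_fn=nn.CrossEntropyLoss(reduction="mean"),  # "mean" now accepted alongside "sum"
    batch_size=16,
)
```

With a plain callable that has no `reduction` attribute, the constructor instead emits the warning shown above and assumes the reduction is the sum of the per-example losses.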
73 changes: 68 additions & 5 deletions captum/influence/_utils/common.py
@@ -59,22 +59,85 @@ def _gradient_dot_product(


def _jacobian_loss_wrt_inputs(
loss_fn: Union[Module, Callable], out: Tensor, targets: Tensor, vectorize: bool
loss_fn: Union[Module, Callable],
out: Tensor,
targets: Tensor,
vectorize: bool,
reduction_type: str,
) -> Tensor:
r"""
Helper function to handle dealing with pytorch version differences for vectorized
jacobian calculation of loss wrt inputs.
Often, we have a loss function that computes a per-sample loss given a 1D tensor
input, and we want to calculate the jacobian of the loss w.r.t. that input. For
example, the input could be a length K tensor specifying the probability a given
sample belongs to each of K possible classes, and the loss function could be
cross-entropy loss. This function performs that calculation, but does so for a
*batch* of inputs. We create this helper function for two reasons: 1) to handle
differences between PyTorch versions for vectorized jacobian calculations, and
2) because this function does not accept the aforementioned per-sample loss
function. Instead, it accepts a "reduction" loss function that *reduces* the
per-sample losses for a batch into a single loss; using a "reduction" loss
improves speed. We allow this reduction to be either the mean or the sum of the
per-sample losses, and this function provides a uniform way to handle the
different possible reductions, and also checks that the reduction used is
valid. Regardless of the reduction used,
this function returns the jacobian for the per-sample loss (for each sample in the
batch).

Args:
loss_fn (torch.nn.Module or Callable or None): The loss function. If a
library-defined loss function is provided, it is expected to be a
torch.nn.Module. If a custom loss is provided, it can be either type,
but must behave as a library loss function would if `reduction='sum'`
or `reduction='mean'`.
out (tensor): This is a tensor that represents the batch of inputs to
`loss_fn`. In practice, this will be the output of a model; this is
why this argument is named `out`. `out` is a 2D tensor of shape
(batch size, model output dimensionality). We will call `loss_fn` via
`loss_fn(out, targets)`.
targets (tensor): The labels for the batch of inputs.
vectorize (bool): Flag to use experimental vectorize functionality for
`torch.autograd.functional.jacobian`.
reduction_type (str): The type of reduction used by `loss_fn`. If `loss_fn`
has the "reduction" attribute, we will check that they match. Can
only be "mean" or "sum".

Returns:
jacobians (tensor): Returns the jacobian of the per-sample loss (implicitly
defined by `loss_fn` and `reduction_type`) w.r.t each sample
in the batch represented by `out`. This is a 2D tensor, where the
first dimension is the batch dimension.
"""
# TODO: allow loss_fn to be Callable
if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"):
msg0 = "Please ensure that loss_fn.reduction is set to `sum` or `mean`"

assert loss_fn.reduction != "none", msg0
msg1 = (
f"loss_fn.reduction ({loss_fn.reduction}) does not match"
f"reduction type ({reduction_type}). Please ensure they are"
" matching."
)
assert loss_fn.reduction == reduction_type, msg1

if reduction_type != "sum" and reduction_type != "mean":
raise ValueError(
f"{reduction_type} is not a valid value for reduction_type. "
"Must be either 'sum' or 'mean'."
)

if torch.__version__ >= "1.8":
return torch.autograd.functional.jacobian(
input_jacobians = torch.autograd.functional.jacobian(
lambda out: loss_fn(out, targets), out, vectorize=vectorize
)
else:
return torch.autograd.functional.jacobian(
input_jacobians = torch.autograd.functional.jacobian(
lambda out: loss_fn(out, targets), out
)

if reduction_type == "mean":
input_jacobians = input_jacobians * len(input_jacobians)

return input_jacobians


def _load_flexible_state_dict(
model: Module, path: str, device_ids: str = "cpu", keyname: Optional[str] = None
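As a sanity check on the new `reduction_type == "mean"` branch in `_jacobian_loss_wrt_inputs`: for a mean-reduced loss over a batch of N samples, the jacobian of the batch loss w.r.t. each sample's output carries an extra 1/N factor, so multiplying by the batch size recovers the jacobian one would get with a sum reduction. The snippet below is not part of the diff; it just verifies this numerically with a cross-entropy loss.

```python
# Minimal numerical check (not part of the PR) that scaling the jacobian of a
# mean-reduced loss by the batch size matches the jacobian of the sum-reduced loss.
import torch
import torch.nn as nn

torch.manual_seed(0)
out = torch.randn(4, 3)  # batch of 4 samples, 3 classes
targets = torch.tensor([0, 2, 1, 0])

mean_jac = torch.autograd.functional.jacobian(
    lambda o: nn.CrossEntropyLoss(reduction="mean")(o, targets), out
)
sum_jac = torch.autograd.functional.jacobian(
    lambda o: nn.CrossEntropyLoss(reduction="sum")(o, targets), out
)

# len(mean_jac) is the batch size N; multiplying by N undoes the 1/N from "mean"
assert torch.allclose(mean_jac * len(mean_jac), sum_jac)
```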
2 changes: 2 additions & 0 deletions tests/influence/_core/test_tracin_get_k_most_influential.py
@@ -48,6 +48,8 @@ class TestTracInGetKMostInfluential(BaseTest):
),
("sum", DataInfluenceConstructor(TracInCPFast)),
("sum", DataInfluenceConstructor(TracInCPFastRandProj)),
("mean", DataInfluenceConstructor(TracInCPFast)),
("mean", DataInfluenceConstructor(TracInCPFastRandProj)),
]
],
name_func=build_test_name_func(),
6 changes: 6 additions & 0 deletions tests/influence/_core/test_tracin_regression.py
@@ -50,6 +50,8 @@ def _test_tracin_regression_setup(self, tmpdir: str, features: int):
("sample_wise_trick", None, DataInfluenceConstructor(TracInCP)),
("check_idx", "sum", DataInfluenceConstructor(TracInCPFast)),
("check_idx", "sum", DataInfluenceConstructor(TracInCPFastRandProj)),
("check_idx", "mean", DataInfluenceConstructor(TracInCPFast)),
("check_idx", "mean", DataInfluenceConstructor(TracInCPFastRandProj)),
(
"check_idx",
"sum",
@@ -168,6 +170,8 @@ def test_tracin_regression(
),
("sum", DataInfluenceConstructor(TracInCPFast)),
("sum", DataInfluenceConstructor(TracInCPFastRandProj)),
("mean", DataInfluenceConstructor(TracInCPFast)),
("mean", DataInfluenceConstructor(TracInCPFastRandProj)),
],
name_func=build_test_name_func(),
)
@@ -248,6 +252,8 @@ def _test_tracin_identity_regression_setup(self, tmpdir: str):
("sample_wise_trick", None, DataInfluenceConstructor(TracInCP)),
("check_idx", "sum", DataInfluenceConstructor(TracInCPFast)),
("check_idx", "sum", DataInfluenceConstructor(TracInCPFastRandProj)),
("check_idx", "mean", DataInfluenceConstructor(TracInCPFast)),
("check_idx", "mean", DataInfluenceConstructor(TracInCPFastRandProj)),
],
name_func=build_test_name_func(),
)
1 change: 1 addition & 0 deletions tests/influence/_core/test_tracin_self_influence.py
@@ -31,6 +31,7 @@ class TestTracInSelfInfluence(BaseTest):
),
),
("sum", DataInfluenceConstructor(TracInCPFast)),
("mean", DataInfluenceConstructor(TracInCPFast)),
]
],
name_func=build_test_name_func(args_to_skip=["reduction"]),