add test loss (pytorch#1073)
Summary:
Pull Request resolved: pytorch#1073

- For all `TracInCPBase` implementations, this adds an additional `test_loss_fn` initialization argument, which is the loss function to apply to test examples when computing the influence of a training example on a test example. With this change, the influence score is a sum of terms, one per checkpoint, where each term is the dot product of the gradient of `loss_fn` for a given training example with the gradient of `test_loss_fn` for a given test example (sketched in code after this list). Previously, `test_loss_fn` was assumed to be the same as `loss_fn`.
- checks regarding the reduction type of both `loss_fn` and `test_loss_fn` are now handled by the helper function `_check_loss_fn`.
- documentation is updated. One detail: for `TracInCP`, `sample_wise_grads_per_batch` is assumed to apply to both `loss_fn` and `test_loss_fn` (if provided), and this is noted in the documentation.
- `test_tracin_regression.test_tracin_regression` is slightly modified - `DataInfluenceConstructor` can now explicitly pass the same loss function for both `loss_fn` and `test_loss_fn` (done when `duplicate_loss_fn=True`). Doing so has the same effect as not passing `test_loss_fn`, so the original tests are also run with `duplicate_loss_fn=True`; the expected behavior is the same as before.
- a new test, `test_tracin_regression.test_tracin_constant_test_loss_fn`, is added. For all implementations of `TracInCPBase`, it checks that if `test_loss_fn` is a constant loss function, the influence scores are all 0. This should be the case because a constant `test_loss_fn` has gradients that are all 0, so training examples have 0 influence on test examples.
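
For illustration, here is a minimal sketch of the score described in the first bullet. The function and variable names are illustrative only; the actual implementations batch these computations, restrict gradients to the layers named in the `layers` argument, and use `_gradient_dot_product` rather than an explicit Python loop:

```python
import torch


def tracin_influence_sketch(
    model, checkpoints, checkpoints_load_func, loss_fn, test_loss_fn,
    train_example, test_example,
):
    """Influence of one training example on one test example: a sum over
    checkpoints of dot products between the gradient of `loss_fn` at the
    training example and the gradient of `test_loss_fn` at the test example,
    scaled by each checkpoint's learning rate."""
    x_train, y_train = train_example
    x_test, y_test = test_example
    params = list(model.parameters())

    score = torch.tensor(0.0)
    for checkpoint in checkpoints:
        # load this checkpoint's weights; assume the loader returns its learning rate
        learning_rate = checkpoints_load_func(model, checkpoint)
        train_grads = torch.autograd.grad(loss_fn(model(x_train), y_train), params)
        test_grads = torch.autograd.grad(test_loss_fn(model(x_test), y_test), params)
        score = score + learning_rate * sum(
            (g_tr * g_te).sum() for g_tr, g_te in zip(train_grads, test_grads)
        )
    return score
```

If `test_loss_fn` has zero gradient with respect to the model parameters (as a constant loss does), every term in the sum vanishes, which is the property the new `test_tracin_constant_test_loss_fn` test checks.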

Differential Revision: https://internalfb.com/D41202866

fbshipit-source-id: e6a261797c7e89d03e40b026001e00d7ec30853e
99warriors authored and facebook-github-bot committed Dec 9, 2022
1 parent c29155b commit fba3034
Showing 6 changed files with 362 additions and 88 deletions.
11 changes: 7 additions & 4 deletions captum/_utils/gradient.py
@@ -849,18 +849,21 @@ def _compute_jacobian_wrt_params_with_sample_wise_trick(
     if labels is not None and loss_fn is not None:
         loss = loss_fn(out, labels)
         # TODO: allow loss_fn to be Callable
-        if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"):
+        if (isinstance(loss_fn, Module) or callable(loss_fn)) and hasattr(
+            loss_fn, "reduction"
+        ):
+            reduction = loss_fn.reduction  # type: ignore
             msg0 = (
                 "Please ensure that loss_fn.reduction is set to `sum` or `mean`"
             )
 
-            assert loss_fn.reduction != "none", msg0
+            assert reduction != "none", msg0
             msg1 = (
-                f"loss_fn.reduction ({loss_fn.reduction}) does not match"
+                f"loss_fn.reduction ({reduction}) does not match"
                 f"reduction type ({reduction_type}). Please ensure they are"
                 " matching."
             )
-            assert loss_fn.reduction == reduction_type, msg1
+            assert reduction == reduction_type, msg1
             msg2 = (
                 "Please ensure custom loss function is applying either a "
                 "sum or mean reduction."
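
With the widened check above, the reduction validation now also applies to plain callables that expose a `reduction` attribute, not only to `Module` losses. A hedged sketch of one way a callable can satisfy `hasattr(loss_fn, "reduction")`; nothing in the diff mandates this exact pattern:

```python
import torch.nn.functional as F


def summed_mse(outputs, targets):
    # reduces per-example squared errors by summing them over the batch
    return F.mse_loss(outputs, targets, reduction="sum")


# advertise the reduction so it can be validated against `reduction_type`
summed_mse.reduction = "sum"
```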
118 changes: 71 additions & 47 deletions captum/influence/_core/tracincp.py
@@ -26,6 +26,7 @@
from captum._utils.progress import NullProgress, progress
from captum.influence._core.influence import DataInfluence
from captum.influence._utils.common import (
_check_loss_fn,
_format_inputs_dataset,
_get_k_most_influential_helper,
_gradient_dot_product,
@@ -102,6 +103,7 @@ def __init__(
checkpoints_load_func: Callable = _load_flexible_state_dict,
loss_fn: Optional[Union[Module, Callable]] = None,
batch_size: Union[int, None] = 1,
test_loss_fn: Optional[Union[Module, Callable]] = None,
) -> None:
r"""
Args:
@@ -152,6 +154,19 @@ def __init__(
`train_dataset` is a Dataset. If `train_dataset`
is a DataLoader, then `batch_size` is ignored as an argument.
Default: 1
test_loss_fn (Callable, optional): In some cases, one may want to use a
separate loss function for training examples, i.e. those in
`train_dataset`, and for test examples, i.e. those
represented by the `inputs` and `targets` arguments to the
`influence` method. For example, if one wants to calculate the
influence score of a training example on a test example's
prediction for a fixed class, `test_loss_fn` could map from the
logits for all classes to the logits for a fixed class.
`test_loss_fn` needs to satisfy the same constraints as `loss_fn`.
If not provided, the loss function for test examples is assumed to
be the same as the loss function for training examples, i.e.
`loss_fn`.
Default: None
"""

self.model = model
@@ -167,6 +182,8 @@ def __init__(

self.checkpoints_load_func = checkpoints_load_func
self.loss_fn = loss_fn
# If test_loss_fn not provided, it's assumed to be same as loss_fn
self.test_loss_fn = loss_fn if test_loss_fn is None else test_loss_fn
self.batch_size = batch_size

if not isinstance(train_dataset, DataLoader):
@@ -489,6 +506,7 @@ def __init__(
layers: Optional[List[str]] = None,
loss_fn: Optional[Union[Module, Callable]] = None,
batch_size: Union[int, None] = 1,
test_loss_fn: Optional[Union[Module, Callable]] = None,
sample_wise_grads_per_batch: bool = False,
) -> None:
r"""
@@ -561,6 +579,24 @@ def __init__(
`train_dataset` is a Dataset. If `train_dataset`
is a DataLoader, then `batch_size` is ignored as an argument.
Default: 1
test_loss_fn (Callable, optional): In some cases, one may want to use a
separate loss function for training examples, i.e. those in
`train_dataset`, and for test examples, i.e. those
represented by the `inputs` and `targets` arguments to the
`influence` method. For example, if one wants to calculate the
influence score of a training example on a test example's
prediction for a fixed class, `test_loss_fn` could map from the
logits for all classes to the logits for a fixed class.
`test_loss_fn` needs to satisfy the same constraints as `loss_fn`.
Thus, the same checks that we apply to `loss_fn` are also applied
to `test_loss_fn`, if the latter is provided. Note that the
constraints on both `loss_fn` and `test_loss_fn` both depend on
`sample_wise_grads_per_batch`. This means `loss_fn` and
`test_loss_fn` must either both be "per-example" loss functions,
or both be "reduction" loss functions. If not provided, the loss
function for test examples is assumed to be the same as the loss
function for training examples, i.e. `loss_fn`.
Default: None
sample_wise_grads_per_batch (bool, optional): PyTorch's native gradient
computations w.r.t. model parameters aggregates the results for a
batch and does not allow to access sample-wise gradients w.r.t.
@@ -590,51 +626,23 @@ def __init__(
checkpoints_load_func,
loss_fn,
batch_size,
test_loss_fn,
)

self.sample_wise_grads_per_batch = sample_wise_grads_per_batch

# If we are able to access the reduction used by `loss_fn`, we check whether
# the reduction is compatible with `sample_wise_grads_per_batch`
if isinstance(loss_fn, Module) and hasattr(
loss_fn, "reduction"
): # TODO: allow loss_fn to be Callable
if self.sample_wise_grads_per_batch:
assert loss_fn.reduction in ["sum", "mean"], (
'reduction for `loss_fn` must be "sum" or "mean" when '
"`sample_wise_grads_per_batch` is True"
)
self.reduction_type = str(loss_fn.reduction)
else:
assert loss_fn.reduction == "none", (
'reduction for `loss_fn` must be "none" when '
"`sample_wise_grads_per_batch` is False"
)
else:
# if we are unable to access the reduction used by `loss_fn`, we warn
# the user about the assumptions we are making regarding the reduction
# used by `loss_fn`
if self.sample_wise_grads_per_batch:
warnings.warn(
'Since `loss_fn` has no "reduction" attribute, and '
"`sample_wise_grads_per_batch` is True, the implementation assumes "
'that `loss_fn` is a "reduction" loss function that reduces the '
"per-example losses by taking their *sum*. If `loss_fn` "
"instead reduces the per-example losses by taking their mean, "
'please set the reduction attribute of `loss_fn` to "mean", i.e. '
'`loss_fn.reduction = "mean"`. Note that if '
"`sample_wise_grads_per_batch` is True, the implementation "
"assumes the reduction is either a sum or mean reduction."
)
self.reduction_type = "sum"
else:
warnings.warn(
'Since `loss_fn` has no "reduction" attribute, and '
"`sample_wise_grads_per_batch` is False, the implementation "
'assumes that `loss_fn` is a "per-example" loss function (see '
"documentation for `loss_fn` for details). Please ensure that "
"this is the case."
)
# check `loss_fn`
self.reduction_type = _check_loss_fn(
self, loss_fn, "loss_fn", sample_wise_grads_per_batch
)
# check `test_loss_fn` if it was provided
self.test_reduction_type = (
self.reduction_type
if test_loss_fn is None
else _check_loss_fn(
self, test_loss_fn, "test_loss_fn", sample_wise_grads_per_batch
)
)

r"""
TODO: Either restore model state after done (would have to place functionality
@@ -790,11 +798,15 @@ def get_checkpoint_contribution(checkpoint):
input_jacobians = self._basic_computation_tracincp(
inputs,
targets,
self.test_loss_fn,
self.test_reduction_type,
)
return (
_gradient_dot_product(
input_jacobians,
self._basic_computation_tracincp(batch[0:-1], batch[-1]),
self._basic_computation_tracincp(
batch[0:-1], batch[-1], self.loss_fn, self.reduction_type
),
)
* learning_rate
)
@@ -1042,7 +1054,10 @@ def get_checkpoint_contribution(checkpoint):
for batch in _inputs_dataset:

layer_jacobians = self._basic_computation_tracincp(
batch[0:-1], batch[-1]
batch[0:-1],
batch[-1],
self.loss_fn,
self.reduction_type,
)

# Note that all variables in this function are for an entire batch.
Expand Down Expand Up @@ -1179,11 +1194,14 @@ def _basic_computation_tracincp(
self,
inputs: Tuple[Any, ...],
targets: Optional[Tensor] = None,
loss_fn: Optional[Union[Module, Callable]] = None,
reduction_type: Optional[str] = None,
) -> Tuple[Tensor, ...]:
"""
For instances of TracInCP, computation of influence scores or self influence
scores repeatedly calls this function for different checkpoints
and batches.
and batches. In particular, this function computes the jacobian of a loss
function w.r.t. parameters in the `layers` initialization argument.
Args:
@@ -1193,20 +1211,26 @@ def _basic_computation_tracincp(
that `model(*inputs)` produces the predictions for the batch.
targets (tensor or None): If computing influence scores on a loss function,
these are the labels corresponding to the batch `inputs`.
Default: None
loss_fn (Callable, optional): The loss function to use when computing the
jacobian.
reduction_type (str, optional): The reduction type of `loss_fn`. This
argument is only used if `sample_wise_grads_per_batch` was true in
initialization.
"""
if self.sample_wise_grads_per_batch:
return _compute_jacobian_wrt_params_with_sample_wise_trick(
self.model,
inputs,
targets,
self.loss_fn,
self.reduction_type,
loss_fn,
reduction_type,
self.layer_modules,
)
return _compute_jacobian_wrt_params(
self.model,
inputs,
targets,
self.loss_fn,
loss_fn,
self.layer_modules,
)
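
To make the fixed-class use case described in the `test_loss_fn` docstrings above concrete, here is a hedged sketch; the class name and the commented usage are illustrative and not part of this diff:

```python
import torch


class FixedClassLoss:
    """Maps the logits for all classes to the (summed) logit of a single fixed
    class, so that gradients of this "loss" measure the sensitivity of that
    class's score rather than of the training loss."""

    # "sum" or "mean" is required when `sample_wise_grads_per_batch=True`
    reduction = "sum"

    def __init__(self, class_index: int) -> None:
        self.class_index = class_index

    def __call__(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # `targets` is accepted to match the `loss_fn(out, labels)` calling
        # convention but is not needed for a fixed-class score
        return outputs[:, self.class_index].sum()


# hypothetical usage (other constructor arguments elided):
# tracin = TracInCP(
#     net, train_dataset, checkpoints,
#     loss_fn=torch.nn.CrossEntropyLoss(reduction="sum"),
#     test_loss_fn=FixedClassLoss(class_index=3),
#     sample_wise_grads_per_batch=True,
# )
```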
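
The body of `_check_loss_fn` lives in `captum/influence/_utils/common.py` and is not shown in the hunks above; at the call sites it also receives the influence instance as its first argument. A plausible sketch, reconstructed from the inline checks it replaces, with details that may differ from the real helper:

```python
import warnings
from typing import Callable, Optional, Union

from torch.nn import Module


def _check_loss_fn_sketch(
    loss_fn: Optional[Union[Module, Callable]],
    loss_fn_name: str,
    sample_wise_grads_per_batch: bool,
) -> str:
    """Validate the reduction of `loss_fn` and return the reduction type that
    the influence computation should assume ("sum", "mean", or "none")."""
    if (isinstance(loss_fn, Module) or callable(loss_fn)) and hasattr(
        loss_fn, "reduction"
    ):
        reduction = loss_fn.reduction
        if sample_wise_grads_per_batch:
            assert reduction in ["sum", "mean"], (
                f'reduction for `{loss_fn_name}` must be "sum" or "mean" when '
                "`sample_wise_grads_per_batch` is True"
            )
            return str(reduction)
        assert reduction == "none", (
            f'reduction for `{loss_fn_name}` must be "none" when '
            "`sample_wise_grads_per_batch` is False"
        )
        return "none"
    # the reduction cannot be inspected: warn about the assumption being made
    if sample_wise_grads_per_batch:
        warnings.warn(
            f'Since `{loss_fn_name}` has no "reduction" attribute and '
            "`sample_wise_grads_per_batch` is True, a sum reduction is assumed."
        )
        return "sum"
    warnings.warn(
        f'Since `{loss_fn_name}` has no "reduction" attribute and '
        '`sample_wise_grads_per_batch` is False, a "per-example" loss is assumed.'
    )
    return "none"
```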