add compute_intermediate_quantities to TracInCP (pytorch#1068)
Summary:
Pull Request resolved: pytorch#1068

This diff adds the `compute_intermediate_quantities` method to `TracInCP`, which returns influence embeddings such that the influence of one example on another is the dot-product of their respective influence embeddings. In the case of `TracInCP`, its influence embeddings are simply the parameter-gradients for an example, concatenated over different checkpoints.
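As a rough illustration of that dot-product property, here is a minimal hypothetical sketch. The toy model, data, checkpoint file, and `load_checkpoint` helper are placeholders for illustration only; `TracInCP` and the new `compute_intermediate_quantities` method are the parts that come from this diff.

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from captum.influence import TracInCP

# toy model and training data, purely for illustration
model = nn.Linear(10, 2)
train_dataset = TensorDataset(torch.randn(20, 10), torch.randn(20, 2))

# save a single toy checkpoint
torch.save(model.state_dict(), "checkpoint-0.pt")

def load_checkpoint(model, path):
    # `checkpoints_load_func` loads a checkpoint into `model` and returns the
    # learning rate used to weight that checkpoint's contribution
    model.load_state_dict(torch.load(path))
    return 1.0

tracin = TracInCP(
    model,
    train_dataset,
    checkpoints=["checkpoint-0.pt"],
    checkpoints_load_func=load_checkpoint,
    loss_fn=nn.MSELoss(reduction="none"),
    batch_size=5,
)

# a "batch" is a tuple whose last element is the labels
train_batch = (torch.randn(4, 10), torch.randn(4, 2))
test_batch = (torch.randn(3, 10), torch.randn(3, 2))

# one row per example: parameter-gradients concatenated over checkpoints
train_emb = tracin.compute_intermediate_quantities(train_batch)  # (4, D * C)
test_emb = tracin.compute_intermediate_quantities(test_batch)    # (3, D * C)

# influence of training example i on test example j = dot-product of rows i and j
pairwise_influence = train_emb @ test_emb.T                      # (4, 3)
```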

There is also an `aggregate` option that, if True, returns not the influence embeddings of each example in the given dataset, but instead their *sum*. This is useful for the validation diff workflow (which is the next diff in the stack), where we want to calculate the influence of a given training example on an entire validation dataset. This can be accomplished by taking the dot-product of the training example's influence embedding with the *sum* of the influence embeddings over the validation dataset (i.e. computed with `aggregate=True`).
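Continuing the hypothetical sketch above, that validation workflow might look roughly like the following (again, `val_loader` and the toy shapes are illustrative):

```python
from torch.utils.data import DataLoader

# embeddings for individual training examples
train_emb = tracin.compute_intermediate_quantities(train_batch)  # (4, D * C)

# a single summed embedding for an entire (toy) validation set
val_loader = DataLoader(
    TensorDataset(torch.randn(30, 10), torch.randn(30, 2)), batch_size=5
)
val_sum = tracin.compute_intermediate_quantities(val_loader, aggregate=True)  # (1, D * C)

# influence of each training example on the total validation loss, computed
# without holding per-example embeddings for the whole validation set in memory
influence_on_val = train_emb @ val_sum.T  # (4, 1)
```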

For tests, the tests currently used for `TracInCPFastRandProj.compute_intermediate_quantities` (`test_tracin_intermediate_quantities.test_tracin_intermediate_quantities_api`, `test_tracin_intermediate_quantities.test_tracin_intermediate_quantities_consistent`) are applied to `TracInCP.compute_intermediate_quantities`. In addition, `test_tracin_intermediate_quantities.test_tracin_intermediate_quantities_aggregate` is added to test the `aggregate=True` option, checking that with `aggregate=True`, the returned influence embedding is indeed the sum of the influence embeddings for the given dataset.

Reviewed By: cyrjano

Differential Revision: D40688327

fbshipit-source-id: 505eddc34da93391975ba2579abd8dcc9c7560c5
Fulton Wang authored and facebook-github-bot committed Nov 18, 2022
1 parent 3977f79 commit befd927
Showing 3 changed files with 202 additions and 0 deletions.
137 changes: 137 additions & 0 deletions captum/influence/_core/tracincp.py
@@ -780,6 +780,142 @@ def influence( # type: ignore[override]
show_progress,
)

def _sum_jacobians(self, inputs_dataset: DataLoader):
"""
sums the jacobians of all examples in `inputs_dataset`. result is of the
same format as layer_jacobians, but the batch dimension has size 1
"""
inputs_dataset_iter = iter(inputs_dataset)

inputs_batch = next(inputs_dataset_iter)

def get_batch_contribution(inputs_batch):
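# sums, over the examples in `inputs_batch`, their jacobians; keeps the
# tuple-of-tensors format, but with a batch dimension of size 1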
_input_jacobians = self._basic_computation_tracincp(
inputs_batch[0:-1],
inputs_batch[-1],
)

return tuple(
torch.sum(jacobian, dim=0).unsqueeze(0) for jacobian in _input_jacobians
)

inputs_jacobians = get_batch_contribution(inputs_batch)

for inputs_batch in inputs_dataset_iter:
inputs_batch_jacobians = get_batch_contribution(inputs_batch)
inputs_jacobians = tuple(
[
inputs_jacobian + inputs_batch_jacobian
for (inputs_jacobian, inputs_batch_jacobian) in zip(
inputs_jacobians, inputs_batch_jacobians
)
]
)

return inputs_jacobians

def _concat_jacobians(self, inputs_dataset: DataLoader):
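"""
concatenates the jacobians of all examples in all batches of `inputs_dataset`
along the batch dimension. result is of the same format as layer_jacobians.
"""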
all_inputs_batch_jacobians = [
self._basic_computation_tracincp(
inputs_batch[0:-1],
inputs_batch[-1],
)
for inputs_batch in inputs_dataset
]

return tuple(
torch.cat(all_inputs_batch_jacobian, dim=0)
for all_inputs_batch_jacobian in zip(*all_inputs_batch_jacobians)
)

@log_usage()
def compute_intermediate_quantities(
self,
inputs_dataset: Union[Tuple[Any, ...], DataLoader],
aggregate: bool = False,
) -> Tensor:
"""
Computes "embedding" vectors for all examples in a single batch, or a
`DataLoader` that yields batches. These embedding vectors are constructed so
that the influence score of a training example on a test example is simply the
dot-product of their corresponding vectors. Allowing a `DataLoader`
yielding batches to be passed in (as opposed to a single batch) gives the
potential to improve efficiency, because we load each checkpoint only once in
this method call. Thus if a `DataLoader` yielding batches is passed in, this
reduces the total number of times each checkpoint is loaded for a dataset,
compared to if a single batch is passed in. The reason we do not just increase
the batch size is that for large models, large batches do not fit in memory.
If `aggregate` is True, the *sum* of the vectors for all examples is returned,
instead of the vectors for each example. This can be useful for computing the
influence of a given training example on the total loss over a validation
dataset, because due to properties of the dot-product, this influence is the
dot-product of the training example's vector with the sum of the vectors in the
validation dataset. Also, doing the sum aggregation within this method, as
opposed to outside of it (by computing all vectors for the validation dataset,
then taking the sum), reduces memory usage.

Args:
inputs_dataset (Tuple, or DataLoader): Either a single tuple of any, or a
`DataLoader`, where each batch yielded is a tuple of any. In
either case, the tuple represents a single batch, where the last
element is assumed to be the labels for the batch. That is,
`model(*batch[0:-1])` produces the output for `model`, and
`batch[-1]` are the labels, if any. Here, `model` is the model
provided in initialization. This is the same assumption made for
each batch yielded by training dataset `train_dataset`.

Returns:
intermediate_quantities (Tensor): A tensor of dimension
(N, D * C). Here, N is the total number of examples in
`inputs_dataset` if `aggregate` is False, and 1, otherwise (so that
a 2D tensor is always returned). C is the number of checkpoints
passed as the `checkpoints` argument of `TracInCP.__init__`, and
each row represents the vector for an example. Regarding D: Let I
be the dimension of the output of the last fully-connected layer
times the dimension of the input of the last fully-connected layer.
If `self.projection_dim` is specified in initialization,
D = min(I * C, `self.projection_dim` * C). Otherwise, D = I * C.
In summary, if `self.projection_dim` is None, the dimension of each
vector will be determined by the size of the input and output of
the last fully-connected layer of `model`. Otherwise,
`self.projection_dim` must be an int, and random projection will be
performed to ensure that the vector is of dimension no more than
`self.projection_dim` * C. `self.projection_dim` corresponds to
the variable d in the top of page 15 of the TracIn paper:
https://arxiv.org/pdf/2002.08484.pdf.
"""
# If `inputs_dataset` is not a `DataLoader`, turn it into one.
inputs_dataset = _format_inputs_dataset(inputs_dataset)

def get_checkpoint_contribution(checkpoint):
assert (
checkpoint is not None
), "None returned from `checkpoints`, cannot load."

learning_rate = self.checkpoints_load_func(self.model, checkpoint)
# get jacobians as tuple of tensors
if aggregate:
inputs_jacobians = self._sum_jacobians(inputs_dataset)
else:
inputs_jacobians = self._concat_jacobians(inputs_dataset)
# flatten into single tensor
return learning_rate * torch.cat(
[
input_jacobian.flatten(start_dim=1)
for input_jacobian in inputs_jacobians
],
dim=1,
)

return torch.cat(
[
get_checkpoint_contribution(checkpoint)
for checkpoint in self.checkpoints
],
dim=1,
)

def _influence_batch_tracincp(
self,
inputs: Tuple[Any, ...],
@@ -1109,6 +1245,7 @@ def get_checkpoint_contribution(checkpoint):

return batches_self_tracin_scores

@log_usage()
def self_influence(
self,
inputs_dataset: Union[Tuple[Any, ...], DataLoader],
3 changes: 3 additions & 0 deletions captum/influence/_core/tracincp_fast_rand_proj.py
@@ -640,6 +640,7 @@ def get_checkpoint_contribution(checkpoint):
checkpoints_progress.update()
return batches_self_tracin_scores

@log_usage()
def self_influence(
self,
inputs_dataset: Union[Tuple[Any, ...], DataLoader],
@@ -1093,6 +1094,7 @@ def _get_k_most_influential( # type: ignore[override]

return KMostInfluentialResults(indices, distances)

@log_usage()
def self_influence(
self,
inputs_dataset: Union[Tuple[Any, ...], DataLoader],
@@ -1531,6 +1533,7 @@ def _get_intermediate_quantities_tracincp_fast_rand_proj(
# each row in this result is the "embedding" vector for an example in `batch`
return torch.cat(checkpoint_contributions, dim=1) # type: ignore

@log_usage()
def compute_intermediate_quantities(
self,
inputs_dataset: Union[Tuple[Any, ...], DataLoader],
62 changes: 62 additions & 0 deletions tests/influence/_core/test_tracin_intermediate_quantities.py
@@ -4,6 +4,7 @@
import torch

import torch.nn as nn
from captum.influence._core.tracincp import TracInCP
from captum.influence._core.tracincp_fast_rand_proj import (
TracInCPFast,
TracInCPFastRandProj,
@@ -19,12 +20,68 @@


class TestTracInIntermediateQuantities(BaseTest):
@parameterized.expand(
[
(reduction, constructor, unpack_inputs)
for unpack_inputs in [True, False]
for (reduction, constructor) in [
("none", DataInfluenceConstructor(TracInCP)),
]
],
name_func=build_test_name_func(),
)
def test_tracin_intermediate_quantities_aggregate(
self, reduction: str, tracin_constructor: Callable, unpack_inputs: bool
) -> None:
"""
tests that calling `compute_intermediate_quantities` with `aggregate=True`
does give the same result as calling it with `aggregate=False`, and then
summing
"""
with tempfile.TemporaryDirectory() as tmpdir:
(net, train_dataset,) = get_random_model_and_data(
tmpdir,
unpack_inputs,
return_test_data=False,
)

# create a dataloader that yields batches from the dataset
train_dataset = DataLoader(train_dataset, batch_size=5)

# create tracin instance
criterion = nn.MSELoss(reduction=reduction)
batch_size = 5

tracin = tracin_constructor(
net,
train_dataset,
tmpdir,
batch_size,
criterion,
)

intermediate_quantities = tracin.compute_intermediate_quantities(
train_dataset, aggregate=False
)
aggregated_intermediate_quantities = tracin.compute_intermediate_quantities(
train_dataset, aggregate=True
)

assertTensorAlmostEqual(
self,
torch.sum(intermediate_quantities, dim=0, keepdim=True),
aggregated_intermediate_quantities,
delta=1e-4, # due to numerical issues, we can't set this to 0.0
mode="max",
)

@parameterized.expand(
[
(reduction, constructor, unpack_inputs)
for unpack_inputs in [True, False]
for (reduction, constructor) in [
("sum", DataInfluenceConstructor(TracInCPFastRandProj)),
("none", DataInfluenceConstructor(TracInCP)),
]
],
name_func=build_test_name_func(),
@@ -103,6 +160,11 @@ def test_tracin_intermediate_quantities_api(
DataInfluenceConstructor(TracInCPFast),
DataInfluenceConstructor(TracInCPFastRandProj),
),
(
"none",
DataInfluenceConstructor(TracInCP),
DataInfluenceConstructor(TracInCP),
),
]
],
name_func=build_test_name_func(),