-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Mean Average Precision metric for Information Retrieval (1/5) (#5032)
* init information retrieval metrics * changed retrieval metrics names, expanded arguments and fixed typo * added 'Retrieval' prefix to metrics and fixed conflict with already-present 'average_precision' file * improved code formatting * pep8 code compatibility * features/implemented new Mean Average Precision metrics for Information Retrieval + doc * fixed pep8 compatibility * removed threshold parameter and fixed typo on types in RetrievalMAP and improved doc * improved doc, put first class-specific args in RetrievalMetric and transformed RetrievalMetric in abstract class * implemented tests for functional and class metric. fixed typo when input tensors are empty or when all targets are False * fixed typos in doc and changed torch.true_divide to torch.div * fixed typos pep8 compatibility * fixed types in long division in ir_average_precision and example in mean_average_precision * RetrievalMetric states are not lists and _metric method accepts predictions and targets for easier extension * updated CHANGELOG file * added '# noqa: F401' flag to not used imports * added double space before '# noqa: F401' flag * Update CHANGELOG.md Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * change get_mini_groups in get_group_indexes * added checks on target inputs * minor refactoring for code cleanness * split tests over exception raising in separate function && refactored test code into multiple functions * fixed pep8 compatibility * implemented suggestions of @SkafteNicki * fixed imports for isort and added types annontations to functions in test_map.py * isort on test_map and fixed typing * isort on retrieval and on __init__.py and utils.py in metrics package * fixed typo in pytorch_lightning/metrics/__init__.py regarding code style * fixed yapf compatibility * fixed yapf compatibility * fixed typo in doc Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
- Loading branch information
1 parent
06756a8
commit 5d73fbb
Showing
12 changed files
with
484 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,3 +37,4 @@ | |
R2Score, | ||
SSIM, | ||
) | ||
from pytorch_lightning.metrics.retrieval import RetrievalMAP # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
54 changes: 54 additions & 0 deletions
54
pytorch_lightning/metrics/functional/ir_average_precision.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import torch | ||
|
||
|
||
def retrieval_average_precision(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | ||
r""" | ||
Computes average precision (for information retrieval), as explained | ||
`here <https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision>`_. | ||
`preds` and `target` should be of the same shape and live on the same device. If no `target` is ``True``, | ||
0 is returned. Target must be of type `bool` or `int`, otherwise an error is raised. | ||
Args: | ||
preds: estimated probabilities of each document to be relevant. | ||
target: ground truth about each document being relevant or not. Requires `bool` or `int` tensor. | ||
Return: | ||
a single-value tensor with the average precision (AP) of the predictions `preds` wrt the labels `target`. | ||
Example: | ||
>>> preds = torch.tensor([0.2, 0.3, 0.5]) | ||
>>> target = torch.tensor([True, False, True]) | ||
>>> retrieval_average_precision(preds, target) | ||
tensor(0.8333) | ||
""" | ||
|
||
if preds.shape != target.shape or preds.device != target.device: | ||
raise ValueError("`preds` and `target` must have the same shape and live on the same device") | ||
|
||
if target.dtype not in (torch.bool, torch.int16, torch.int32, torch.int64): | ||
raise ValueError("`target` must be a tensor of booleans or integers") | ||
|
||
if target.dtype is not torch.bool: | ||
target = target.bool() | ||
|
||
if target.sum() == 0: | ||
return torch.tensor(0, device=preds.device) | ||
|
||
target = target[torch.argsort(preds, dim=-1, descending=True)] | ||
positions = torch.arange(1, len(target) + 1, device=target.device, dtype=torch.float32)[target > 0] | ||
res = torch.div((torch.arange(len(positions), device=positions.device, dtype=torch.float32) + 1), positions).mean() | ||
return res |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from pytorch_lightning.metrics.retrieval.mean_average_precision import RetrievalMAP # noqa: F401 | ||
from pytorch_lightning.metrics.retrieval.retrieval_metric import RetrievalMetric # noqa: F401 |
61 changes: 61 additions & 0 deletions
61
pytorch_lightning/metrics/retrieval/mean_average_precision.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import torch | ||
|
||
from pytorch_lightning.metrics.functional.ir_average_precision import retrieval_average_precision | ||
from pytorch_lightning.metrics.retrieval.retrieval_metric import RetrievalMetric | ||
|
||
|
||
class RetrievalMAP(RetrievalMetric): | ||
r""" | ||
Computes `Mean Average Precision | ||
<https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Mean_average_precision>`_. | ||
Works with binary data. Accepts integer or float predictions from a model output. | ||
Forward accepts | ||
- ``indexes`` (long tensor): ``(N, ...)`` | ||
- ``preds`` (float tensor): ``(N, ...)`` | ||
- ``target`` (long or bool tensor): ``(N, ...)`` | ||
`indexes`, `preds` and `target` must have the same dimension. | ||
`indexes` indicate to which query a prediction belongs. | ||
Predictions will be first grouped by indexes and then MAP will be computed as the mean | ||
of the Average Precisions over each query. | ||
Args: | ||
query_without_relevant_docs: | ||
Specify what to do with queries that do not have at least a positive target. Choose from: | ||
- ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned | ||
- ``'error'``: raise a ``ValueError`` | ||
- ``'pos'``: score on those queries is counted as ``1.0`` | ||
- ``'neg'``: score on those queries is counted as ``0.0`` | ||
exclude: | ||
Do not take into account predictions where the target is equal to this value. default `-100` | ||
compute_on_step: | ||
Forward only calls ``update()`` and return None if this is set to False. default: True | ||
dist_sync_on_step: | ||
Synchronize metric state across processes at each ``forward()`` | ||
before returning the value at the step. default: False | ||
process_group: | ||
Specify the process group on which synchronization is called. default: None (which selects | ||
the entire world) | ||
dist_sync_fn: | ||
Callback that performs the allgather operation on the metric state. When `None`, DDP | ||
will be used to perform the allgather. default: None | ||
Example: | ||
>>> from pytorch_lightning.metrics import RetrievalMAP | ||
>>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1]) | ||
>>> preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) | ||
>>> target = torch.tensor([False, False, True, False, True, False, False]) | ||
>>> map = RetrievalMAP() | ||
>>> map(indexes, preds, target) | ||
tensor(0.7500) | ||
>>> map.compute() | ||
tensor(0.7500) | ||
""" | ||
|
||
def _metric(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | ||
valid_indexes = target != self.exclude | ||
return retrieval_average_precision(preds[valid_indexes], target[valid_indexes]) |
140 changes: 140 additions & 0 deletions
140
pytorch_lightning/metrics/retrieval/retrieval_metric.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Any, Callable, Optional | ||
|
||
import torch | ||
|
||
from pytorch_lightning.metrics import Metric | ||
from pytorch_lightning.metrics.utils import get_group_indexes | ||
|
||
#: get_group_indexes is used to group predictions belonging to the same query | ||
|
||
IGNORE_IDX = -100 | ||
|
||
|
||
class RetrievalMetric(Metric, ABC): | ||
r""" | ||
Works with binary data. Accepts integer or float predictions from a model output. | ||
Forward accepts | ||
- ``indexes`` (long tensor): ``(N, ...)`` | ||
- ``preds`` (float or int tensor): ``(N, ...)`` | ||
- ``target`` (long or bool tensor): ``(N, ...)`` | ||
`indexes`, `preds` and `target` must have the same dimension and will be flatten | ||
to single dimension once provided. | ||
`indexes` indicate to which query a prediction belongs. | ||
Predictions will be first grouped by indexes. Then the | ||
real metric, defined by overriding the `_metric` method, | ||
will be computed as the mean of the scores over each query. | ||
Args: | ||
query_without_relevant_docs: | ||
Specify what to do with queries that do not have at least a positive target. Choose from: | ||
- ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned | ||
- ``'error'``: raise a ``ValueError`` | ||
- ``'pos'``: score on those queries is counted as ``1.0`` | ||
- ``'neg'``: score on those queries is counted as ``0.0`` | ||
exclude: | ||
Do not take into account predictions where the target is equal to this value. default `-100` | ||
compute_on_step: | ||
Forward only calls ``update()`` and return None if this is set to False. default: True | ||
dist_sync_on_step: | ||
Synchronize metric state across processes at each ``forward()`` | ||
before returning the value at the step. default: False | ||
process_group: | ||
Specify the process group on which synchronization is called. default: None (which selects | ||
the entire world) | ||
dist_sync_fn: | ||
Callback that performs the allgather operation on the metric state. When `None`, DDP | ||
will be used to perform the allgather. default: None | ||
""" | ||
|
||
def __init__( | ||
self, | ||
query_without_relevant_docs: str = 'skip', | ||
exclude: int = IGNORE_IDX, | ||
compute_on_step: bool = True, | ||
dist_sync_on_step: bool = False, | ||
process_group: Optional[Any] = None, | ||
dist_sync_fn: Callable = None | ||
): | ||
super().__init__( | ||
compute_on_step=compute_on_step, | ||
dist_sync_on_step=dist_sync_on_step, | ||
process_group=process_group, | ||
dist_sync_fn=dist_sync_fn | ||
) | ||
|
||
query_without_relevant_docs_options = ('error', 'skip', 'pos', 'neg') | ||
if query_without_relevant_docs not in query_without_relevant_docs_options: | ||
raise ValueError( | ||
f"`query_without_relevant_docs` received a wrong value {query_without_relevant_docs}. " | ||
f"Allowed values are {query_without_relevant_docs_options}" | ||
) | ||
|
||
self.query_without_relevant_docs = query_without_relevant_docs | ||
self.exclude = exclude | ||
|
||
self.add_state("idx", default=[], dist_reduce_fx=None) | ||
self.add_state("preds", default=[], dist_reduce_fx=None) | ||
self.add_state("target", default=[], dist_reduce_fx=None) | ||
|
||
def update(self, idx: torch.Tensor, preds: torch.Tensor, target: torch.Tensor) -> None: | ||
if not (idx.shape == target.shape == preds.shape): | ||
raise ValueError("`idx`, `preds` and `target` must be of the same shape") | ||
|
||
idx = idx.to(dtype=torch.int64).flatten() | ||
preds = preds.to(dtype=torch.float32).flatten() | ||
target = target.to(dtype=torch.int64).flatten() | ||
|
||
self.idx.append(idx) | ||
self.preds.append(preds) | ||
self.target.append(target) | ||
|
||
def compute(self) -> torch.Tensor: | ||
r""" | ||
First concat state `idx`, `preds` and `target` since they were stored as lists. After that, | ||
compute list of groups that will help in keeping together predictions about the same query. | ||
Finally, for each group compute the `_metric` if the number of positive targets is at least | ||
1, otherwise behave as specified by `self.query_without_relevant_docs`. | ||
""" | ||
|
||
idx = torch.cat(self.idx, dim=0) | ||
preds = torch.cat(self.preds, dim=0) | ||
target = torch.cat(self.target, dim=0) | ||
|
||
res = [] | ||
kwargs = {'device': idx.device, 'dtype': torch.float32} | ||
|
||
groups = get_group_indexes(idx) | ||
for group in groups: | ||
|
||
mini_preds = preds[group] | ||
mini_target = target[group] | ||
|
||
if not mini_target.sum(): | ||
if self.query_without_relevant_docs == 'error': | ||
raise ValueError( | ||
f"`{self.__class__.__name__}.compute()` was provided with " | ||
f"a query without positive targets, indexes: {group}" | ||
) | ||
if self.query_without_relevant_docs == 'pos': | ||
res.append(torch.tensor(1.0, **kwargs)) | ||
elif self.query_without_relevant_docs == 'neg': | ||
res.append(torch.tensor(0.0, **kwargs)) | ||
else: | ||
res.append(self._metric(mini_preds, mini_target)) | ||
|
||
if len(res) > 0: | ||
return torch.stack(res).mean() | ||
return torch.tensor(0.0, **kwargs) | ||
|
||
@abstractmethod | ||
def _metric(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | ||
r""" | ||
Compute a metric over a predictions and target of a single group. | ||
This method should be overridden by subclasses. | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.