From 67113757e939c45eb8485cb82ee78a8ee6950f4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bal=C3=A1zs=20Gl=C3=A1vits?= Date: Fri, 1 Sep 2023 14:52:04 +0000 Subject: [PATCH] Improved the performance of RecallAtFixedPrecision for large batch sizes (#2042) Co-authored-by: Nicki Skafte Detlefsen Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> (cherry picked from commit 694399abc0f4647847abb9bb8fbdc9732aa307aa) --- CHANGELOG.md | 1 + .../classification/recall_fixed_precision.py | 36 ++++++++++++++----- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a6b93947d27..b8e03ddacf4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed performance issues in `RecallAtFixedPrecision` for large batch sizes ([#2042](https://github.com/Lightning-AI/torchmetrics/pull/2042)) - diff --git a/src/torchmetrics/functional/classification/recall_fixed_precision.py b/src/torchmetrics/functional/classification/recall_fixed_precision.py index 725d0ad7eae..3cac182e519 100644 --- a/src/torchmetrics/functional/classification/recall_fixed_precision.py +++ b/src/torchmetrics/functional/classification/recall_fixed_precision.py @@ -37,21 +37,39 @@ from torchmetrics.utilities.enums import ClassificationTask +def _lexargmax(x: Tensor) -> Tensor: + """Returns the index of the maximum value in a list of tuples according to lexicographic ordering. 
+ + Based on https://stackoverflow.com/a/65615160 + + """ + idx: Optional[Tensor] = None + for k in range(x.shape[1]): + col: Tensor = x[idx, k] if idx is not None else x[:, k] + z = torch.where(col == col.max())[0] + idx = z if idx is None else idx[z] + if len(idx) < 2: + break + if idx is None: + raise ValueError("Failed to extract index") + return idx + + def _recall_at_precision( precision: Tensor, recall: Tensor, thresholds: Tensor, min_precision: float, ) -> Tuple[Tensor, Tensor]: - try: - max_recall, _, best_threshold = max( - (r, p, t) for p, r, t in zip(precision, recall, thresholds) if p >= min_precision - ) - - except ValueError: - max_recall = torch.tensor(0.0, device=recall.device, dtype=recall.dtype) - best_threshold = torch.tensor(0) - + max_recall = torch.tensor(0.0, device=recall.device, dtype=recall.dtype) + best_threshold = torch.tensor(0) + + zipped_len = min(t.shape[0] for t in (recall, precision, thresholds)) + zipped = torch.vstack((recall[:zipped_len], precision[:zipped_len], thresholds[:zipped_len])).T + zipped_masked = zipped[zipped[:, 1] >= min_precision] + if zipped_masked.shape[0] > 0: + idx = _lexargmax(zipped_masked)[0] + max_recall, _, best_threshold = zipped_masked[idx] if max_recall == 0.0: best_threshold = torch.tensor(1e6, device=thresholds.device, dtype=thresholds.dtype)