Improved the performance of RecallAtFixedPrecision for large batch sizes (#2042)

Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
(cherry picked from commit 694399a)
GlavitsBalazs authored and Borda committed Sep 11, 2023
1 parent 9120e4a commit 6711375
Showing 2 changed files with 28 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed performance issues in `RecallAtFixedPrecision` for large batch sizes ([#2042](https://github.com/Lightning-AI/torchmetrics/pull/2042))
-


@@ -37,21 +37,39 @@
from torchmetrics.utilities.enums import ClassificationTask


+def _lexargmax(x: Tensor) -> Tensor:
+    """Returns the index of the maximum value in a list of tuples according to lexicographic ordering.
+
+    Based on https://stackoverflow.com/a/65615160
+    """
+    idx: Optional[Tensor] = None
+    for k in range(x.shape[1]):
+        col: Tensor = x[idx, k] if idx is not None else x[:, k]
+        z = torch.where(col == col.max())[0]
+        idx = z if idx is None else idx[z]
+        if len(idx) < 2:
+            break
+    if idx is None:
+        raise ValueError("Failed to extract index")
+    return idx
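As a quick illustration of what this lexicographic argmax does, here is a minimal standalone sketch (the name `lexargmax_sketch` and the example tensor are invented for illustration, not part of the PR): rows that tie on the first column are disambiguated by later columns.

```python
from typing import Optional

import torch
from torch import Tensor


def lexargmax_sketch(x: Tensor) -> Tensor:
    """Illustrative re-implementation: index of the lexicographically largest row of ``x``."""
    idx: Optional[Tensor] = None
    for k in range(x.shape[1]):
        # Restrict each column to the rows still tied after the previous columns.
        col = x[idx, k] if idx is not None else x[:, k]
        z = torch.where(col == col.max())[0]
        idx = z if idx is None else idx[z]
        if len(idx) < 2:  # a unique winner: no later column can change it
            break
    assert idx is not None
    return idx


x = torch.tensor([
    [3.0, 1.0],
    [3.0, 2.0],  # ties row 0 on column 0, wins on column 1
    [2.0, 9.0],
])
print(lexargmax_sketch(x))  # tensor([1])
```

The early `break` is what keeps the loop cheap in the common case where the first column already has a unique maximum.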


 def _recall_at_precision(
     precision: Tensor,
     recall: Tensor,
     thresholds: Tensor,
     min_precision: float,
 ) -> Tuple[Tensor, Tensor]:
-    try:
-        max_recall, _, best_threshold = max(
-            (r, p, t) for p, r, t in zip(precision, recall, thresholds) if p >= min_precision
-        )
-    except ValueError:
-        max_recall = torch.tensor(0.0, device=recall.device, dtype=recall.dtype)
-        best_threshold = torch.tensor(0)
+    max_recall = torch.tensor(0.0, device=recall.device, dtype=recall.dtype)
+    best_threshold = torch.tensor(0)
+
+    zipped_len = min(t.shape[0] for t in (recall, precision, thresholds))
+    zipped = torch.vstack((recall[:zipped_len], precision[:zipped_len], thresholds[:zipped_len])).T
+    zipped_masked = zipped[zipped[:, 1] >= min_precision]
+    if zipped_masked.shape[0] > 0:
+        idx = _lexargmax(zipped_masked)[0]
+        max_recall, _, best_threshold = zipped_masked[idx]
     if max_recall == 0.0:
         best_threshold = torch.tensor(1e6, device=thresholds.device, dtype=thresholds.dtype)
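The vectorized path can be exercised on toy data. The numbers below are invented for illustration, and a plain `argmax` on the recall column stands in for the full `_lexargmax` tie-break; the point is that the old per-element Python `zip`/`max` loop becomes a handful of tensor operations, which is what speeds up large batches.

```python
import torch

# Assumed toy precision-recall curve values, not taken from the PR.
precision = torch.tensor([1.00, 0.90, 0.75, 0.50])
recall = torch.tensor([0.20, 0.50, 0.80, 1.00])
thresholds = torch.tensor([0.9, 0.7, 0.4, 0.1])
min_precision = 0.8

# Stack the (recall, precision, threshold) triples as rows, then mask
# out rows below the precision floor -- all in tensor ops, no Python loop.
zipped = torch.vstack((recall, precision, thresholds)).T
masked = zipped[zipped[:, 1] >= min_precision]

# Simplified tie-breaking: take the row with the highest recall.
max_recall, _, best_threshold = masked[masked[:, 0].argmax()]
# max_recall is 0.5, reached at threshold 0.7.
```

Only rows 0 and 1 survive the `>= 0.8` precision mask here, and row 1 has the larger recall, so the selected operating point is recall 0.5 at threshold 0.7.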

