Recall at fixed precision does not scale well for large batch sizes #2041

Closed
GlavitsBalazs opened this issue Sep 1, 2023 · 2 comments
Labels: bug / fix (Something isn't working), help wanted (Extra attention is needed), v1.1.x

Comments

@GlavitsBalazs
Contributor

🐛 Bug

I would like to measure BinaryRecallAtFixedPrecision on a validation dataset whose size is on the order of hundreds of thousands of samples. When I try to train my model, execution hangs at the validation step. My profiler tells me that the method torchmetrics.functional.classification.recall_fixed_precision._recall_at_precision is consuming endless amounts of CPU time. This happens because I use BinaryRecallAtFixedPrecision with thresholds=None: in that mode, BinaryRecallAtFixedPrecision.update() appends all of my validation batches together, so BinaryRecallAtFixedPrecision.compute() has to process the entire validation dataset at once.
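
For reference, my setup is essentially the following (the min_precision value and the random data are placeholders standing in for my model and validation set):

import torch
from torchmetrics.classification import BinaryRecallAtFixedPrecision

# thresholds=None (the default) makes update() store every prediction/target pair;
# compute() then evaluates the full precision-recall curve over all stored samples.
metric = BinaryRecallAtFixedPrecision(min_precision=0.8, thresholds=None)

for _ in range(100):  # stand-in for the validation loop
    preds = torch.rand(10_000)               # model scores in [0, 1]
    target = torch.randint(0, 2, (10_000,))  # binary ground-truth labels
    metric.update(preds, target)

recall, threshold = metric.compute()  # this is the step that hangs for large datasets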

To Reproduce

Set NUM_PROCESSES = 1 in tests/unittests/conftest.py for the sake of reproducibility.

Measure the run time of test_binary_recall_at_fixed_precision with the following command:

time pytest tests/unittests/classification/test_recall_fixed_precision.py::TestBinaryRecallAtFixedPrecision::test_binary_recall_at_fixed_precision

It takes 8-9 seconds on my machine.
Now increase the testing batch size by setting BATCH_SIZE in tests/unittests/conftest.py. This simulates the real-world scenario of running BinaryRecallAtFixedPrecision.compute() with thresholds=None on a large validation dataset.
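
Concretely, the two constants I change in tests/unittests/conftest.py look like this (10000 is just one of the batch sizes I tried, not a recommended value):

# tests/unittests/conftest.py
NUM_PROCESSES = 1   # for reproducibility, as above
BATCH_SIZE = 10000  # increased to simulate a large validation set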
Here are my crude measurements:

Batch Size | Test Time (s)
-----------|--------------
100        | ~9
1000       | ~20
2000       | ~30
5000       | ~70
10000      | ~124
100000     | ~1200

Expected behavior

Admittedly, the run time scales linearly with input size; however, it is not fast enough. Other metrics derived from BinaryPrecisionRecallCurve run much faster, and I don't see why this one should be particularly slower than those.

I see one area where potential improvements could be made: this line of code in torchmetrics.functional.classification.recall_fixed_precision._recall_at_precision. My profiler shows that this is where the majority of CPU time is spent while computing BinaryRecallAtFixedPrecision. I suspect that finding the maxima of tensors by iterating over them with a Python for loop is a major source of inefficiency; using native PyTorch operations instead could give a significant speedup.
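
To sketch the idea (this is illustrative rather than the exact patch; in particular, ties between points of equal recall are broken by index here instead of by precision, and the fallback when no point qualifies should match whatever the current implementation returns):

from typing import Tuple

import torch
from torch import Tensor


def _recall_at_precision_vectorized(
    precision: Tensor,
    recall: Tensor,
    thresholds: Tensor,
    min_precision: float,
) -> Tuple[Tensor, Tensor]:
    """Highest recall among curve points with precision >= min_precision, plus its threshold."""
    # The precision/recall curve has one more point than there are thresholds;
    # truncate so the three tensors line up (zipping them in the current loop
    # does this implicitly).
    n = thresholds.numel()
    precision, recall = precision[:n], recall[:n]

    valid = precision >= min_precision
    if not valid.any():
        # No operating point satisfies the precision constraint: zero recall
        # and a sentinel threshold (the fallback values here are placeholders).
        return recall.new_tensor(0.0), thresholds.new_tensor(1e6)

    # Mask out invalid points so a single argmax replaces the Python-level loop.
    masked_recall = torch.where(valid, recall, torch.full_like(recall, -1.0))
    idx = torch.argmax(masked_recall)
    return recall[idx], thresholds[idx]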

I'm proposing changes along these lines to the method in question. Measuring my code under the same conditions as before gives the following results:

Batch Size | Test Time (s)
-----------|--------------
100        | ~8
1000       | ~9
2000       | ~10
5000       | ~14
10000      | ~19
100000     | ~110

This alone may not be a conclusive benchmark, but the time savings during model validation are noticeable in my experience.

Environment

No GPU, no CUDA, CPU only.

torchmetrics==1.1.1
torch==2.0.1+cpu
python_version==3.11.5
# and others
GlavitsBalazs added the bug / fix (Something isn't working) and help wanted (Extra attention is needed) labels on Sep 1, 2023
@github-actions

github-actions bot commented Sep 1, 2023

Hi! Thanks for your contribution, great first issue!

@SkafteNicki
Member

Hi @GlavitsBalazs, thanks for raising this issue.
That is a huge finding. In general we do not systematically test how well our implementations scale with the number of samples; we often try to do so while implementing them, but we may have missed it here.
Feel free to send a PR with the proposed fix (otherwise I will do it myself in the coming days).
Thanks!
