
The sdr metric in TM sometimes gives NaN for some input #895

Closed
quancs opened this issue Mar 19, 2022 · 3 comments · Fixed by #899
Labels: bug / fix, help wanted, topic: Audio
Comments

quancs (Member) commented Mar 19, 2022

🐛 Bug

This issue is related to the torch version of fast-bss-eval; see fakufaku/fast_bss_eval#5.

To Reproduce

import numpy as np
import torch
from mir_eval.separation import bss_eval_sources
from torchmetrics.functional.audio import signal_distortion_ratio

# Load the attached data (unzip data.zip to get debug.npz).
x = np.load('debug.npz')
preds = torch.tensor(x['preds'])
target = torch.tensor(x['target'])
print(preds.shape, target.shape)

# TorchMetrics SDR: the second source comes out as NaN.
sdr = signal_distortion_ratio(preds, target)
print(sdr)

# mir_eval reference: both sources get a finite SDR.
sdr, _, _, _ = bss_eval_sources(target.numpy(), preds.numpy(), False)
print(sdr)

outputs:

torch.Size([2, 64000]) torch.Size([2, 64000])
tensor([-2.6815,     nan])
[-2.68156071 44.58523729]

Unzip data.zip to get debug.npz.


Expected behavior

The result given by signal_distortion_ratio should be close to the one given by mir_eval.

Environment

  • OS (e.g., Linux):
  • Python & PyTorch Version (e.g., 1.0):
  • How you installed PyTorch (conda, pip, build command if you used source):
  • Any other relevant information:

Additional context

quancs added the bug / fix and help wanted labels on Mar 19, 2022
SkafteNicki (Member) commented

Hi @quancs,
From the issue that you linked, it seems that the author's solution is basically to do the evaluation in double instead of float. I can confirm that doing this fixes the example you sent. Do you think we should cast the user's input here:
https://github.com/PyTorchLightning/metrics/blob/865a08fcf102c2eb1b776b13643bc87aadf7f4f7/torchmetrics/functional/audio/sdr.py#L140-L141
to double instead of float? Alternatively, we can add a note to the docstring that in some cases it is better to evaluate using double precision.
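
A hedged illustration of this suggestion, applied from the caller side to the reproduction script above rather than as the actual change inside sdr.py:

# Same reproduction data as above, but evaluated in double precision,
# which reportedly avoids the NaN for this example.
sdr = signal_distortion_ratio(preds.double(), target.double())
print(sdr)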

quancs (Member, Author) commented Mar 21, 2022


This issue happens with the torch version of the metric on CPU; on GPU it's OK.
And I haven't seen any such violations (NaN values) in my past experiment results tested on GPU.
I have some ideas to fix this:

  1. convert to double anyway, no matter whether on CPU or GPU, but it may make the metric slow
  2. convert to double on CPU only, but it may make the metric slow, and we don't know whether GPU is really always OK
  3. convert to double when we detect that the result is not a valid number (NaN or Inf), and run again (a rough sketch is below)

I prefer 3). What do you think? Or do you have other ideas?
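
A minimal sketch of option 3, written as a wrapper around the public functional API; sdr_with_double_fallback is a hypothetical helper name and not part of torchmetrics or any merged fix:

import torch
from torchmetrics.functional.audio import signal_distortion_ratio

def sdr_with_double_fallback(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Try the inputs' own precision (typically float32) first.
    sdr = signal_distortion_ratio(preds, target)
    # If any value is NaN/Inf, redo the computation in float64 and cast back.
    if not torch.isfinite(sdr).all():
        sdr = signal_distortion_ratio(preds.double(), target.double()).to(sdr.dtype)
    return sdr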

SkafteNicki (Member) commented

  1. is fine by me :]
