Skip to content

Commit

Permalink
re-implement the signal_distortion_ratio metric (#964)
Browse files Browse the repository at this point in the history
* reimplement signal_distortion_ratio
* sdr is now differentiable for all supported pytorch versions
* update & fix
* chlog
* Apply suggestions from code review

Co-authored-by: quancs <quancs@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
  • Loading branch information
6 people authored Apr 21, 2022
1 parent b83edf0 commit f8ef656
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 138 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

-
- Reimplemented the `signal_distortion_ratio` metric, removing the hard requirement on `fast-bss-eval` ([#964](https://github.com/PyTorchLightning/metrics/pull/964))


-
Expand Down
2 changes: 0 additions & 2 deletions requirements/audio.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@
pesq>=0.0.3
pystoi
fast-bss-eval>=0.1.0
torch_complex # needed for fast-bss-eval torch<=1.7
2 changes: 2 additions & 0 deletions requirements/audio_test.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
pypesq
mir_eval>=0.6
speechmetrics @ https://github.com/aliutkus/speechmetrics/archive/refs/heads/master.zip
fast-bss-eval>=0.1.0
torch_complex # needed for fast-bss-eval torch<=1.7
12 changes: 2 additions & 10 deletions tests/audio/test_sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from tests.helpers.testers import MetricTester
from torchmetrics.audio import SignalDistortionRatio
from torchmetrics.functional import signal_distortion_ratio
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_8, _TORCH_LOWER_1_12_DEV
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

seed_all(42)

Expand Down Expand Up @@ -99,7 +99,6 @@ def test_sdr_functional(self, preds, target, sk_metric):
metric_args=dict(),
)

@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_8, reason="sdr is not differentiable for pytorch < 1.8")
def test_sdr_differentiability(self, preds, target, sk_metric):
self.run_differentiability_test(
preds=preds,
Expand Down Expand Up @@ -155,14 +154,7 @@ def test_too_low_precision():
preds = torch.tensor(data["preds"])
target = torch.tensor(data["target"])

if _TORCH_GREATER_EQUAL_1_8 and _TORCH_LOWER_1_12_DEV:
with pytest.warns(
UserWarning,
match="Detected `nan` or `inf` value in computed metric, retrying computation in double precision",
):
sdr_tm = signal_distortion_ratio(preds, target)
else: # when pytorch < 1.8 or pytorch >= 1.12, sdr doesn't have this problem
sdr_tm = signal_distortion_ratio(preds, target).double()
sdr_tm = signal_distortion_ratio(preds, target).double()

# check equality with bss_eval_sources in every pytorch version
sdr_bss, _, _, _ = bss_eval_sources(target.numpy(), preds.numpy(), False)
Expand Down
35 changes: 8 additions & 27 deletions torchmetrics/audio/sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@

from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio, signal_distortion_ratio
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _FAST_BSS_EVAL_AVAILABLE

__doctest_requires__ = {"SignalDistortionRatio": ["fast_bss_eval"]}


class SignalDistortionRatio(Metric):
r"""Signal to Distortion Ratio (SDR) [1,2,3]
r"""Signal to Distortion Ratio (SDR) [1,2]
Forward accepts
Expand All @@ -32,10 +31,13 @@ class SignalDistortionRatio(Metric):
Args:
use_cg_iter:
If provided, an iterative method is used to solve for the distortion filter coefficients instead
of direct Gaussian elimination. This can speed up the computation of the metrics in case the filters
are long. Using a value of 10 here has been shown to provide good accuracy in most cases and is sufficient
when using this loss to train neural separation networks.
If provided, conjugate gradient descent is used to solve for the distortion
filter coefficients instead of direct Gaussian elimination. This option requires
that ``fast-bss-eval`` is installed and that the pytorch version is >= 1.8.
It can speed up the computation of the metrics in case the filters
are long. Using a value of 10 here has been shown to provide
good accuracy in most cases and is sufficient when using this
loss to train neural separation networks.
filter_length: The length of the distortion filter allowed
zero_mean:
When set to True, the mean of all signals is subtracted prior to computation of the metrics
Expand All @@ -51,10 +53,6 @@ class SignalDistortionRatio(Metric):
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Raises:
ModuleNotFoundError:
If ``fast-bss-eval`` package is not installed
Example:
>>> from torchmetrics.audio import SignalDistortionRatio
>>> import torch
Expand All @@ -73,23 +71,11 @@ class SignalDistortionRatio(Metric):
>>> pit(preds, target)
tensor(-11.6051)
.. note::
1. when pytorch<1.8.0, numpy will be used to calculate this metric, which causes ``sdr`` to be
non-differentiable and slower to calculate
2. using this metrics requires you to have ``fast-bss-eval`` install. Either install as ``pip install
torchmetrics[audio]`` or ``pip install fast-bss-eval``
3. preds and target need to have the same dtype, otherwise target will be converted to preds' dtype
References:
[1] Vincent, E., Gribonval, R., & Fevotte, C. (2006). Performance measurement in blind audio source separation.
IEEE Transactions on Audio, Speech and Language Processing, 14(4), 1462–1469.
[2] Scheibler, R. (2021). SDR -- Medium Rare with Fast Computations.
[3] https://github.com/fakufaku/fast_bss_eval
"""

sum_sdr: Tensor
Expand All @@ -106,11 +92,6 @@ def __init__(
compute_on_step: Optional[bool] = None,
**kwargs: Dict[str, Any],
) -> None:
if not _FAST_BSS_EVAL_AVAILABLE:
raise ModuleNotFoundError(
"SDR metric requires that `fast-bss-eval` is installed."
" Either install as `pip install torchmetrics[audio]` or `pip install fast-bss-eval`."
)
super().__init__(compute_on_step=compute_on_step, **kwargs)

self.use_cg_iter = use_cg_iter
Expand Down
Loading

0 comments on commit f8ef656

Please sign in to comment.