Skip to content

Commit

Permalink
PSNR metric (#2483)
Browse files Browse the repository at this point in the history
* Add stub PSNR metric

* Fix linter

* Add data range as parameter

* Add tests

* Add scikit-image

* Add PSNR to regression metrics and add functional

* Refactor to functional

* Fix linter

* Fix linter, again

* Fix linter, again

* Fix typo in test

* Fix typo in another test

* Add scikit-image to conda

* Lift numpy requirement

* Add random tests

* Update CHANGELOG.md

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
InCogNiTo124 and Borda authored Jul 8, 2020
1 parent 899cd74 commit 1dc7242
Show file tree
Hide file tree
Showing 9 changed files with 163 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added a PSNR metric: peak signal-to-noise ratio ([#2483](https://github.com/PyTorchLightning/pytorch-lightning/pull/2483))

### Changed

Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies:
- autopep8
- twine==1.13.0
- pillow<7.0.0
- scikit-image

# Optional
- scipy>=0.13.3
Expand Down
4 changes: 3 additions & 1 deletion pytorch_lightning/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
RMSE,
MAE,
RMSLE,
PSNR
)
from pytorch_lightning.metrics.classification import (
Accuracy,
Expand Down Expand Up @@ -50,6 +51,7 @@
'MSE',
'RMSE',
'MAE',
'RMSLE'
'RMSLE',
'PSNR'
]
__all__ = __regression_metrics + __classification_metrics + ['SklearnMetric']
40 changes: 40 additions & 0 deletions pytorch_lightning/metrics/functional/regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import torch
from torch.nn import functional as F

from pytorch_lightning.metrics.functional.reduction import reduce


def psnr(
pred: torch.Tensor,
target: torch.Tensor,
data_range: float = None,
base: float = 10.0,
reduction: str = 'elementwise_mean'
) -> torch.Tensor:
"""
Computes the peak signal-to-noise ratio metric
Args:
pred: estimated signal
target: groun truth signal
data_range: the range of the data. If None, it is determined from the data (max - min).
base: a base of a logarithm to use (default: 10)
reduction: method for reducing psnr (default: takes the mean)
Example:
>>> from pytorch_lightning.metrics.regression import PSNR
>>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
>>> metric = PSNR()
>>> metric(pred, target)
tensor(2.5527)
"""

if data_range is None:
data_range = max(target.max() - target.min(), pred.max() - pred.min())
else:
data_range = torch.tensor(float(data_range))
mse = F.mse_loss(pred.view(-1), target.view(-1), reduction=reduction)
psnr_base_e = 2 * torch.log(data_range) - torch.log(mse)
return psnr_base_e * (10 / torch.log(torch.tensor(base)))
35 changes: 34 additions & 1 deletion pytorch_lightning/metrics/regression.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import torch.nn.functional as F
import torch
from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.metrics.functional.regression import (
psnr,
)

__all__ = ['MSE', 'RMSE', 'MAE', 'RMSLE']
__all__ = ['MSE', 'RMSE', 'MAE', 'RMSLE', 'PSNR']


class MSE(Metric):
Expand Down Expand Up @@ -187,3 +190,33 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
A Tensor with the rmsle loss.
"""
return F.mse_loss(torch.log(pred + 1), torch.log(target + 1), self.reduction)


class PSNR(Metric):
"""
Computes the peak signal-to-noise ratio metric
"""

def __init__(self, data_range: float = None, base: int = 10, reduction: str = 'elementwise_mean'):
"""
Args:
data_range: the range of the data. If None, it is determined from the data (max - min).
base: a base of a logarithm to use (default: 10)
reduction: method for reducing psnr (default: takes the mean)
Example:
>>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
>>> metric = PSNR()
>>> metric(pred, target)
tensor(2.5527)
"""
super().__init__(name='psnr')
self.data_range = data_range
self.base = float(base)
self.reduction = reduction

def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
return psnr(pred, target, self.data_range, self.base, self.reduction)
2 changes: 1 addition & 1 deletion requirements/base.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# the default package dependencies

numpy>=1.15 # because some BLAS compilation issues
numpy>=1.16.4
torch>=1.3
tensorboard>=1.14
future>=0.17.1 # required for builtins in setup.py
Expand Down
2 changes: 1 addition & 1 deletion requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ flake8
flake8-black
check-manifest
twine==1.13.0

scikit-image
black==19.10b0
pre-commit>=1.0

Expand Down
37 changes: 37 additions & 0 deletions tests/metrics/functional/test_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest
import torch

from skimage.metrics import peak_signal_noise_ratio as ski_psnr
from pytorch_lightning.metrics.functional.regression import psnr


@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
pytest.param(ski_psnr, psnr, id='peak_signal_noise_ratio')
])
def test_psnr_against_sklearn(sklearn_metric, torch_metric):
"""Compare PL metrics to sklearn version."""
device = 'cuda' if torch.cuda.is_available() else 'cpu'

pred = torch.randint(10, (500,), device=device, dtype=torch.double)
target = torch.randint(10, (500,), device=device, dtype=torch.double)
assert torch.allclose(
torch.tensor(sklearn_metric(target.cpu().detach().numpy(),
pred.cpu().detach().numpy(),
data_range=10), dtype=torch.double, device=device),
torch_metric(pred, target, data_range=10))

pred = torch.randint(5, (500,), device=device, dtype=torch.double)
target = torch.randint(10, (500,), device=device, dtype=torch.double)
assert torch.allclose(
torch.tensor(sklearn_metric(target.cpu().detach().numpy(),
pred.cpu().detach().numpy(),
data_range=10), dtype=torch.double, device=device),
torch_metric(pred, target, data_range=10))

pred = torch.randint(10, (500,), device=device, dtype=torch.double)
target = torch.randint(5, (500,), device=device, dtype=torch.double)
assert torch.allclose(
torch.tensor(sklearn_metric(target.cpu().detach().numpy(),
pred.cpu().detach().numpy(),
data_range=5), dtype=torch.double, device=device),
torch_metric(pred, target, data_range=5))
46 changes: 45 additions & 1 deletion tests/metrics/test_regression.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import pytest
import torch
from skimage.metrics import peak_signal_noise_ratio as ski_psnr
import numpy as np

from pytorch_lightning.metrics.regression import (
MAE, MSE, RMSE, RMSLE
MAE, MSE, RMSE, RMSLE, PSNR
)


Expand Down Expand Up @@ -64,3 +66,45 @@ def test_rmsle(pred, target, exp):

assert isinstance(score, torch.Tensor)
assert pytest.approx(score.item(), rel=1e-3) == exp


@pytest.mark.parametrize(['pred', 'target', 'exp'], [
pytest.param(
[0., 1., 2., 3.],
[0., 1., 2., 2.],
ski_psnr(np.array([0., 1., 2., 3.]), np.array([0., 1., 2., 2.]), data_range=3)
),
pytest.param(
[4., 3., 2., 1.],
[1., 4., 3., 2.],
ski_psnr(np.array([4., 3., 2., 1.]), np.array([1., 4., 3., 2.]), data_range=3)
)
])
def test_psnr(pred, target, exp):
psnr = PSNR()
assert psnr.name == 'psnr'
score = psnr(pred=torch.tensor(pred),
target=torch.tensor(target))
assert isinstance(score, torch.Tensor)
assert pytest.approx(score.item(), rel=1e-3) == exp


@pytest.mark.parametrize(['pred', 'target', 'exp'], [
pytest.param(
[0., 1., 2., 3.],
[0., 1., 2., 2.],
ski_psnr(np.array([0., 1., 2., 3.]), np.array([0., 1., 2., 2.]), data_range=4) * np.log(10)
),
pytest.param(
[4., 3., 2., 1.],
[1., 4., 3., 2.],
ski_psnr(np.array([4., 3., 2., 1.]), np.array([1., 4., 3., 2.]), data_range=4) * np.log(10)
)
])
def test_psnr_base_e_wider_range(pred, target, exp):
psnr = PSNR(data_range=4, base=2.718281828459045)
assert psnr.name == 'psnr'
score = psnr(pred=torch.tensor(pred),
target=torch.tensor(target))
assert isinstance(score, torch.Tensor)
assert pytest.approx(score.item(), rel=1e-3) == exp

0 comments on commit 1dc7242

Please sign in to comment.