Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PSNR implementation #2483

Merged
merged 17 commits into from
Jul 8, 2020
Merged
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [unreleased] - YYYY-MM-DD

### Added

- Added a new metric: peak signal-to-noise ratio ([#2483](https://github.com/PyTorchLightning/pytorch-lightning/pull/2483))
Borda marked this conversation as resolved.
Show resolved Hide resolved

### Changed

Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies:
- autopep8
- twine==1.13.0
- pillow<7.0.0
- scikit-image

# Optional
- scipy>=0.13.3
Expand Down
4 changes: 3 additions & 1 deletion pytorch_lightning/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
RMSE,
MAE,
RMSLE,
PSNR
)
from pytorch_lightning.metrics.classification import (
Accuracy,
Expand Down Expand Up @@ -50,6 +51,7 @@
'MSE',
'RMSE',
'MAE',
'RMSLE'
'RMSLE',
'PSNR'
]
__all__ = __regression_metrics + __classification_metrics + ['SklearnMetric']
40 changes: 40 additions & 0 deletions pytorch_lightning/metrics/functional/regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import torch
from torch.nn import functional as F

from pytorch_lightning.metrics.functional.reduction import reduce


def psnr(
pred: torch.Tensor,
target: torch.Tensor,
data_range: float = None,
base: float = 10.0,
reduction: str = 'elementwise_mean'
) -> torch.Tensor:
"""
Computes the peak signal-to-noise ratio metric

Args:
pred: estimated signal
target: groun truth signal
data_range: the range of the data. If None, it is determined from the data (max - min).
base: a base of a logarithm to use (default: 10)
reduction: method for reducing psnr (default: takes the mean)

Example:

>>> from pytorch_lightning.metrics.regression import PSNR
>>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
>>> metric = PSNR()
>>> metric(pred, target)
tensor(2.5527)
"""

if data_range is None:
data_range = max(target.max() - target.min(), pred.max() - pred.min())
else:
data_range = torch.tensor(float(data_range))
mse = F.mse_loss(pred.view(-1), target.view(-1), reduction=reduction)
psnr_base_e = 2 * torch.log(data_range) - torch.log(mse)
return psnr_base_e * (10 / torch.log(torch.tensor(base)))
35 changes: 34 additions & 1 deletion pytorch_lightning/metrics/regression.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import torch.nn.functional as F
import torch
from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.metrics.functional.regression import (
psnr,
)

__all__ = ['MSE', 'RMSE', 'MAE', 'RMSLE']
__all__ = ['MSE', 'RMSE', 'MAE', 'RMSLE', 'PSNR']


class MSE(Metric):
Expand Down Expand Up @@ -187,3 +190,33 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
A Tensor with the rmsle loss.
"""
return F.mse_loss(torch.log(pred + 1), torch.log(target + 1), self.reduction)


class PSNR(Metric):
"""
Computes the peak signal-to-noise ratio metric
"""

def __init__(self, data_range: float = None, base: int = 10, reduction: str = 'elementwise_mean'):
"""
Args:
data_range: the range of the data. If None, it is determined from the data (max - min).
base: a base of a logarithm to use (default: 10)
reduction: method for reducing psnr (default: takes the mean)


Example:

>>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
>>> metric = PSNR()
>>> metric(pred, target)
tensor(2.5527)
"""
super().__init__(name='psnr')
self.data_range = data_range
self.base = float(base)
self.reduction = reduction

def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
return psnr(pred, target, self.data_range, self.base, self.reduction)
2 changes: 1 addition & 1 deletion requirements/base.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# the default package dependencies

numpy>=1.15 # because some BLAS compilation issues
numpy>=1.16.4
torch>=1.3
tensorboard>=1.14
future>=0.17.1 # required for builtins in setup.py
Expand Down
2 changes: 1 addition & 1 deletion requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ flake8
flake8-black
check-manifest
twine==1.13.0

scikit-image
black==19.10b0
pre-commit>=1.0

Expand Down
37 changes: 37 additions & 0 deletions tests/metrics/functional/test_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest
import torch

from skimage.metrics import peak_signal_noise_ratio as ski_psnr
from pytorch_lightning.metrics.functional.regression import psnr


@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
pytest.param(ski_psnr, psnr, id='peak_signal_noise_ratio')
])
def test_psnr_against_sklearn(sklearn_metric, torch_metric):
"""Compare PL metrics to sklearn version."""
device = 'cuda' if torch.cuda.is_available() else 'cpu'

pred = torch.randint(10, (500,), device=device, dtype=torch.double)
target = torch.randint(10, (500,), device=device, dtype=torch.double)
assert torch.allclose(
torch.tensor(sklearn_metric(target.cpu().detach().numpy(),
pred.cpu().detach().numpy(),
data_range=10), dtype=torch.double, device=device),
torch_metric(pred, target, data_range=10))

pred = torch.randint(5, (500,), device=device, dtype=torch.double)
target = torch.randint(10, (500,), device=device, dtype=torch.double)
assert torch.allclose(
torch.tensor(sklearn_metric(target.cpu().detach().numpy(),
pred.cpu().detach().numpy(),
data_range=10), dtype=torch.double, device=device),
torch_metric(pred, target, data_range=10))

pred = torch.randint(10, (500,), device=device, dtype=torch.double)
target = torch.randint(5, (500,), device=device, dtype=torch.double)
assert torch.allclose(
torch.tensor(sklearn_metric(target.cpu().detach().numpy(),
pred.cpu().detach().numpy(),
data_range=5), dtype=torch.double, device=device),
torch_metric(pred, target, data_range=5))
46 changes: 45 additions & 1 deletion tests/metrics/test_regression.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import pytest
import torch
from skimage.metrics import peak_signal_noise_ratio as ski_psnr
import numpy as np

from pytorch_lightning.metrics.regression import (
MAE, MSE, RMSE, RMSLE
MAE, MSE, RMSE, RMSLE, PSNR
)


Expand Down Expand Up @@ -64,3 +66,45 @@ def test_rmsle(pred, target, exp):

assert isinstance(score, torch.Tensor)
assert pytest.approx(score.item(), rel=1e-3) == exp


@pytest.mark.parametrize(['pred', 'target', 'exp'], [
pytest.param(
[0., 1., 2., 3.],
[0., 1., 2., 2.],
ski_psnr(np.array([0., 1., 2., 3.]), np.array([0., 1., 2., 2.]), data_range=3)
),
pytest.param(
[4., 3., 2., 1.],
[1., 4., 3., 2.],
ski_psnr(np.array([4., 3., 2., 1.]), np.array([1., 4., 3., 2.]), data_range=3)
)
])
def test_psnr(pred, target, exp):
psnr = PSNR()
assert psnr.name == 'psnr'
score = psnr(pred=torch.tensor(pred),
target=torch.tensor(target))
assert isinstance(score, torch.Tensor)
assert pytest.approx(score.item(), rel=1e-3) == exp


@pytest.mark.parametrize(['pred', 'target', 'exp'], [
pytest.param(
[0., 1., 2., 3.],
[0., 1., 2., 2.],
ski_psnr(np.array([0., 1., 2., 3.]), np.array([0., 1., 2., 2.]), data_range=4) * np.log(10)
),
pytest.param(
[4., 3., 2., 1.],
[1., 4., 3., 2.],
ski_psnr(np.array([4., 3., 2., 1.]), np.array([1., 4., 3., 2.]), data_range=4) * np.log(10)
)
])
def test_psnr_base_e_wider_range(pred, target, exp):
psnr = PSNR(data_range=4, base=2.718281828459045)
assert psnr.name == 'psnr'
score = psnr(pred=torch.tensor(pred),
target=torch.tensor(target))
assert isinstance(score, torch.Tensor)
assert pytest.approx(score.item(), rel=1e-3) == exp