PSNR implementation #2483

Merged: 17 commits (Jul 8, 2020)
4 changes: 3 additions & 1 deletion pytorch_lightning/metrics/__init__.py
@@ -5,6 +5,7 @@
RMSE,
MAE,
RMSLE,
PSNR
)
from pytorch_lightning.metrics.classification import (
Accuracy,
@@ -50,6 +51,7 @@
'MSE',
'RMSE',
'MAE',
'RMSLE'
'RMSLE',
'PSNR'
]
__all__ = __regression_metrics + __classification_metrics + ['SklearnMetric']
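Since PSNR is now listed in both the regression import and __all__, it should be available from the package root like the other metrics. A minimal usage sketch (not part of the diff; the output value is taken from the docstring example further below):

import torch
from pytorch_lightning.metrics import PSNR  # top-level export added by this change

pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])

metric = PSNR()              # data_range is inferred from the data when not given
print(metric(pred, target))  # tensor([2.5527])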
38 changes: 38 additions & 0 deletions pytorch_lightning/metrics/functional/regression.py
@@ -0,0 +1,38 @@
import torch
from torch.nn import functional as F

from pytorch_lightning.metrics.functional.reduction import reduce

def psnr(
pred: torch.Tensor,
target: torch.Tensor,
data_range: float = None,
base: float = 10.0,
reduction: str = 'elementwise_mean'
) -> torch.Tensor:
"""
Computes the peak signal-to-noise ratio metric

Args:
pred: estimated signal
target: ground truth signal
data_range: the range of the data. If None, it is determined from the data (max - min).
base: base of the logarithm to use (default: 10)
reduction: method for reducing the metric score (default: 'elementwise_mean')

Example:

>>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
>>> psnr(pred, target)
tensor(2.5527)
"""

if data_range is None:
data_range = max(target.max() - target.min(), pred.max() - pred.min())
else:
data_range = torch.tensor(float(data_range))
mse = F.mse_loss(pred.view(-1), target.view(-1), reduction=reduction)
psnr_base_e = 2 * torch.log(data_range) - torch.log(mse)  # work in natural logs for numerical precision
return psnr_base_e * (10 / torch.log(torch.tensor(float(base))))  # convert to the requested logarithm base
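For context on those last two lines: PSNR in base b equals (2·ln(data_range) − ln(MSE)) · 10 / ln(b), which for b = 10 is the usual 10·log10(data_range² / MSE). A rough cross-check against scikit-image (which this PR adds to the test requirements below); the snippet is illustrative only and not part of the diff:

import numpy as np
import torch
from skimage.metrics import peak_signal_noise_ratio as ski_psnr

pred = torch.tensor([0.0, 1.0, 2.0, 3.0])
target = torch.tensor([0.0, 1.0, 2.0, 2.0])

mse = torch.mean((pred - target) ** 2)   # 0.25
data_range = torch.tensor(3.0)           # max(pred range, target range), as in the function above

psnr_base_e = 2 * torch.log(data_range) - torch.log(mse)
psnr_base_10 = psnr_base_e * (10 / torch.log(torch.tensor(10.0)))

print(psnr_base_10.item())                                    # ~15.563 dB
print(ski_psnr(target.numpy(), pred.numpy(), data_range=3))   # ~15.563 dB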
37 changes: 36 additions & 1 deletion pytorch_lightning/metrics/regression.py
@@ -2,7 +2,7 @@
import torch
from pytorch_lightning.metrics.metric import Metric

__all__ = ['MSE', 'RMSE', 'MAE', 'RMSLE']
__all__ = ['MSE', 'RMSE', 'MAE', 'RMSLE', 'PSNR']


class MSE(Metric):
@@ -187,3 +187,38 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
A Tensor with the rmsle loss.
"""
return F.mse_loss(torch.log(pred + 1), torch.log(target + 1), self.reduction)


class PSNR(Metric):
"""
Computes the peak signal-to-noise ratio metric
"""

def __init__(self, data_range: float = None, base: int = 10):
"""
Args:
data_range: the range of the data. If None, it is determined from the data (max - min).
base: base of the logarithm to use (default: 10)


Example:

>>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
>>> metric = PSNR()
>>> metric(pred, target)
tensor([2.5527])
"""
super().__init__(name='psnr')
self.data_range = data_range
self.base = torch.tensor([float(base)])

def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if self.data_range is None:
data_range = max(target.max() - target.min(), pred.max() - pred.min())
else:
data_range = torch.tensor(float(self.data_range))
mse = F.mse_loss(pred.view(-1), target.view(-1))
# numerical precision tricks
psnr_base_e = 2 * torch.log(data_range) - torch.log(mse)
return psnr_base_e * (10 / torch.log(self.base))  # change the logarithm base
Contributor:

You don't have to replicate the whole procedure again here. Just import from functional and use that one. :) Also add reduction parameter here.

Contributor Author:

Refactored and added

BTW, in case you didn't know, pytest says 'elementwise_mean' is deprecated; it might be smart to refactor the default to 'mean'. I left it that way to match rohitgr7's PR.
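For reference, a sketch of what the delegating class could look like after that refactor, passing a reduction argument through to the functional psnr added in this PR. This is a hypothetical illustration; the merged code may differ:

from pytorch_lightning.metrics.functional.regression import psnr as psnr_func


class PSNR(Metric):
    def __init__(self, data_range: float = None, base: float = 10.0,
                 reduction: str = 'elementwise_mean'):
        super().__init__(name='psnr')
        self.data_range = data_range
        self.base = base
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # reuse the functional implementation instead of duplicating the math
        return psnr_func(pred, target, data_range=self.data_range,
                         base=self.base, reduction=self.reduction)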

Contributor:

Nice! Also, can you fix the pep8 issues mentioned above by @pep8speaks?

Contributor Author:

@rohitgr7 I'm confused about the missing whitespace. I thought that was a bad thing :|

Contributor:

Hmm, I don't know about that :| I never use an extra ',' after the last element in a list.

Contributor Author:

Nvm, I figured out what was wrong.

2 changes: 1 addition & 1 deletion requirements/test.txt
@@ -7,7 +7,7 @@ flake8
flake8-black
check-manifest
twine==1.13.0

scikit-image
black==19.10b0
pre-commit>=1.0

45 changes: 44 additions & 1 deletion tests/metrics/test_regression.py
@@ -1,8 +1,10 @@
import pytest
import torch
from skimage.metrics import peak_signal_noise_ratio as ski_psnr
import numpy as np

from pytorch_lightning.metrics.regression import (
MAE, MSE, RMSE, RMSLE
MAE, MSE, RMSE, RMSLE, PSNR
)


@@ -64,3 +66,44 @@ def test_rmsle(pred, target, exp):

assert isinstance(score, torch.Tensor)
assert pytest.approx(score.item(), rel=1e-3) == exp

@pytest.mark.parametrize(['pred', 'target', 'exp'], [
pytest.param(
[0., 1., 2., 3.],
[0., 1., 2., 2.],
ski_psnr(np.array([0., 1., 2., 3.]), np.array([0., 1., 2., 2.]), data_range=3)
),
pytest.param(
[4., 3., 2., 1.],
[1., 4., 3., 2.],
ski_psnr(np.array([4., 3., 2., 1.]), np.array([1., 4., 3., 2.]), data_range=3)
)
])
def test_psnr(pred, target, exp):
psnr = PSNR()
assert psnr.name == 'psnr'
score = psnr(pred=torch.tensor(pred),
target=torch.tensor(target))
assert isinstance(score, torch.Tensor)
assert pytest.approx(score.item(), rel=1e-3) == exp

@pytest.mark.parametrize(['pred', 'target', 'exp'], [
pytest.param(
[0., 1., 2., 3.],
[0., 1., 2., 2.],
ski_psnr(np.array([0., 1., 2., 3.]), np.array([0., 1., 2., 2.]), data_range=4) * np.log(10)
),
pytest.param(
[4., 3., 2., 1.],
[1., 4., 3., 2.],
ski_psnr(np.array([4., 3., 2., 1.]), np.array([1., 4., 3., 2.]), data_range=4) * np.log(10)
)
])
def test_psnr_base_e_wider_range(pred, target, exp):
psnr = PSNR(data_range=4, base=2.718281828459045)
assert psnr.name == 'psnr'
score = psnr(pred=torch.tensor(pred),
target=torch.tensor(target))
assert isinstance(score, torch.Tensor)
assert pytest.approx(score.item(), rel=1e-3) == exp
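One note on the expected values in the base-e test above: since log_e(x) = log_10(x) · ln(10), a base-e PSNR is just the base-10 value scaled by ln(10), which is why the skimage result is multiplied by np.log(10). A small illustrative check (not part of the diff):

import numpy as np
from skimage.metrics import peak_signal_noise_ratio as ski_psnr

pred = np.array([0., 1., 2., 3.])
target = np.array([0., 1., 2., 2.])

psnr_10 = ski_psnr(target, pred, data_range=4)  # base-10 PSNR in dB, ~18.062
psnr_e = psnr_10 * np.log(10)                   # same quantity in base e, ~41.59
print(psnr_10, psnr_e)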