Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PSNR implementation #2483

Merged
merged 17 commits into from
Jul 8, 2020
35 changes: 35 additions & 0 deletions pytorch_lightning/metrics/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,38 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
A Tensor with the rmsle loss.
"""
return F.mse_loss(torch.log(pred + 1), torch.log(target + 1), self.reduction)

class PSNR(Metric):
"""
Computes the peak signal-to-noise ratio metric
"""

def __init__(self, base: int = 10):
"""
Args:
base: a base of a logarithm to use (default: 10)


Example:

>>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
>>> metric = PSNR()
>>> metric(pred, target)
tensor([2.5527])
"""
self.base = torch.tensor(float(base))


def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
mse = F.mse_loss(pred.view(-1), torch.view(-1))

# The calculation is troublesome because it is dependant of the maximum value possible.
# For integer inputs that should not be a problem (it's 255) but for floats there is a problem
# because the floats can be in [0, 1] range or they can be normalized with unknown mean and variance.
# Since mean and variance are unknown, we cannot know what's the maximum value to use in calculation.
# This implementation, therefore, finds the maximum empirically.
maximum = max(torch.max(torch.abs(pred)), torch.max(torch.abs(target)))
PSNR_base_e = 2*torch.log(maximum) - torch.log(mse)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

variable names should always be lowercase according to pep8

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed

return PSNR_base_e * (10 / torch.log(self.base)) # change the logarithm basis