-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add regression metrics * solve tests * add docs
- Loading branch information
Showing
3 changed files
with
279 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
import torch.nn.functional as F | ||
import torch | ||
from pytorch_lightning.metrics.metric import Metric | ||
|
||
__all__ = ['MSE', 'RMSE', 'MAE', 'RMSLE'] | ||
|
||
|
||
class MSE(Metric): | ||
""" | ||
Computes the mean squared loss. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
reduction: str = 'elementwise_mean', | ||
): | ||
""" | ||
Args: | ||
reduction: a method for reducing mse over labels (default: takes the mean) | ||
Available reduction methods: | ||
- elementwise_mean: takes the mean | ||
- none: pass array | ||
- sum: add elements | ||
Example: | ||
>>> pred = torch.tensor([0., 1, 2, 3]) | ||
>>> target = torch.tensor([0., 1, 2, 2]) | ||
>>> metric = MSE() | ||
>>> metric(pred, target) | ||
tensor(0.2500) | ||
""" | ||
super().__init__(name='mse') | ||
if reduction == 'elementwise_mean': | ||
reduction = 'mean' | ||
self.reduction = reduction | ||
|
||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | ||
""" | ||
Actual metric computation | ||
Args: | ||
pred: predicted labels | ||
target: ground truth labels | ||
Return: | ||
A Tensor with the mse loss. | ||
""" | ||
return F.mse_loss(pred, target, self.reduction) | ||
|
||
|
||
class RMSE(Metric): | ||
""" | ||
Computes the root mean squared loss. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
reduction: str = 'elementwise_mean', | ||
): | ||
""" | ||
Args: | ||
reduction: a method for reducing mse over labels (default: takes the mean) | ||
Available reduction methods: | ||
- elementwise_mean: takes the mean | ||
- none: pass array | ||
- sum: add elements | ||
Example: | ||
>>> pred = torch.tensor([0., 1, 2, 3]) | ||
>>> target = torch.tensor([0., 1, 2, 2]) | ||
>>> metric = RMSE() | ||
>>> metric(pred, target) | ||
tensor(0.5000) | ||
""" | ||
super().__init__(name='rmse') | ||
if reduction == 'elementwise_mean': | ||
reduction = 'mean' | ||
self.reduction = reduction | ||
|
||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | ||
""" | ||
Actual metric computation | ||
Args: | ||
pred: predicted labels | ||
target: ground truth labels | ||
Return: | ||
A Tensor with the rmse loss. | ||
""" | ||
return torch.sqrt(F.mse_loss(pred, target, self.reduction)) | ||
|
||
|
||
class MAE(Metric): | ||
""" | ||
Computes the root mean absolute loss or L1-loss. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
reduction: str = 'elementwise_mean', | ||
): | ||
""" | ||
Args: | ||
reduction: a method for reducing mse over labels (default: takes the mean) | ||
Available reduction methods: | ||
- elementwise_mean: takes the mean | ||
- none: pass array | ||
- sum: add elements | ||
Example: | ||
>>> pred = torch.tensor([0., 1, 2, 3]) | ||
>>> target = torch.tensor([0., 1, 2, 2]) | ||
>>> metric = MAE() | ||
>>> metric(pred, target) | ||
tensor(0.2500) | ||
""" | ||
super().__init__(name='mae') | ||
if reduction == 'elementwise_mean': | ||
reduction = 'mean' | ||
self.reduction = reduction | ||
|
||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | ||
""" | ||
Actual metric computation | ||
Args: | ||
pred: predicted labels | ||
target: ground truth labels | ||
Return: | ||
A Tensor with the mae loss. | ||
""" | ||
return F.l1_loss(pred, target, self.reduction) | ||
|
||
|
||
class RMSLE(Metric): | ||
""" | ||
Computes the root mean squared log loss. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
reduction: str = 'elementwise_mean', | ||
): | ||
""" | ||
Args: | ||
reduction: a method for reducing mse over labels (default: takes the mean) | ||
Available reduction methods: | ||
- elementwise_mean: takes the mean | ||
- none: pass array | ||
- sum: add elements | ||
Example: | ||
>>> pred = torch.tensor([0., 1, 2, 3]) | ||
>>> target = torch.tensor([0., 1, 2, 2]) | ||
>>> metric = RMSLE() | ||
>>> metric(pred, target) | ||
tensor(0.0207) | ||
""" | ||
super().__init__(name='rmsle') | ||
if reduction == 'elementwise_mean': | ||
reduction = 'mean' | ||
self.reduction = reduction | ||
|
||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | ||
""" | ||
Actual metric computation | ||
Args: | ||
pred: predicted labels | ||
target: ground truth labels | ||
Return: | ||
A Tensor with the rmsle loss. | ||
""" | ||
return F.mse_loss(torch.log(pred + 1), torch.log(target + 1), self.reduction) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import pytest | ||
import torch | ||
|
||
from pytorch_lightning.metrics.regression import ( | ||
MAE, MSE, RMSE, RMSLE | ||
) | ||
|
||
|
||
@pytest.mark.parametrize(['pred', 'target', 'exp'], [ | ||
pytest.param([0., 1., 2., 3.], [0., 1., 2., 2.], .25), | ||
pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.], 3.) | ||
]) | ||
def test_mse(pred, target, exp): | ||
mse = MSE() | ||
assert mse.name == 'mse' | ||
|
||
score = mse(pred=torch.tensor(pred), | ||
target=torch.tensor(target)) | ||
|
||
assert isinstance(score, torch.Tensor) | ||
assert score.item() == exp | ||
|
||
|
||
@pytest.mark.parametrize(['pred', 'target', 'exp'], [ | ||
pytest.param([0., 1., 2., 3.], [0., 1., 2., 2.], .5), | ||
pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.], 1.7321) | ||
]) | ||
def test_rmse(pred, target, exp): | ||
rmse = RMSE() | ||
assert rmse.name == 'rmse' | ||
|
||
score = rmse(pred=torch.tensor(pred), | ||
target=torch.tensor(target)) | ||
|
||
assert isinstance(score, torch.Tensor) | ||
assert pytest.approx(score.item(), rel=1e-3) == exp | ||
|
||
|
||
@pytest.mark.parametrize(['pred', 'target', 'exp'], [ | ||
pytest.param([0., 1., 2., 3.], [0., 1., 2., 2.], .25), | ||
pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.], 1.5) | ||
]) | ||
def test_mae(pred, target, exp): | ||
mae = MAE() | ||
assert mae.name == 'mae' | ||
|
||
score = mae(pred=torch.tensor(pred), | ||
target=torch.tensor(target)) | ||
|
||
assert isinstance(score, torch.Tensor) | ||
assert score.item() == exp | ||
|
||
|
||
@pytest.mark.parametrize(['pred', 'target', 'exp'], [ | ||
pytest.param([0., 1., 2., 3.], [0., 1., 2., 2.], .0207), | ||
pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.], .2841) | ||
]) | ||
def test_rmsle(pred, target, exp): | ||
rmsle = RMSLE() | ||
assert rmsle.name == 'rmsle' | ||
|
||
score = rmsle(pred=torch.tensor(pred), | ||
target=torch.tensor(target)) | ||
|
||
assert isinstance(score, torch.Tensor) | ||
assert pytest.approx(score.item(), rel=1e-3) == exp |