From 3eb79fb4fef33e90cbc0a5d110352063ff640ba7 Mon Sep 17 00:00:00 2001 From: Xavier Sumba Date: Tue, 16 Jun 2020 19:41:31 -0400 Subject: [PATCH 1/3] add regression metrics --- pytorch_lightning/metrics/regression.py | 188 ++++++++++++++++++++++++ tests/metrics/test_regression.py | 63 ++++++++ 2 files changed, 251 insertions(+) create mode 100644 pytorch_lightning/metrics/regression.py create mode 100644 tests/metrics/test_regression.py diff --git a/pytorch_lightning/metrics/regression.py b/pytorch_lightning/metrics/regression.py new file mode 100644 index 0000000000000..e91cb61735ef4 --- /dev/null +++ b/pytorch_lightning/metrics/regression.py @@ -0,0 +1,188 @@ +import torch.nn.functional as F +import torch +from pytorch_lightning.metrics.metric import Metric + +__all__ = ['MSE', 'RMSE', 'MAE', 'RMSLE'] + +class MSE(Metric): + """ + Computes the mean squared loss. + """ + + def __init__( + self, + reduction: str = 'elementwise_mean', + ): + """ + Args: + reduction: a method for reducing mse over labels (default: takes the mean) + Available reduction methods: + - elementwise_mean: takes the mean + - none: pass array + - sum: add elements + + + Example: + + >>> pred = torch.tensor([0., 1, 2, 3]) + >>> target = torch.tensor([0., 1, 2, 2]) + >>> metric = MSE() + >>> metric(pred, target) + tensor(0.2500) + + """ + super().__init__(name='mse') + if reduction == 'elementwise_mean': + reduction = 'mean' + self.reduction = reduction + + def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Actual metric computation + + Args: + pred: predicted labels + target: ground truth labels + + Return: + A Tensor with the mse loss. + """ + return F.mse_loss(pred, target, self.reduction) + + +class RMSE(Metric): + """ + Computes the root mean squared loss. + """ + + def __init__( + self, + reduction: str = 'elementwise_mean', + ): + """ + Args: + reduction: a method for reducing mse over labels (default: takes the mean) + Available reduction methods: + - elementwise_mean: takes the mean + - none: pass array + - sum: add elements + + + Example: + + >>> pred = torch.tensor([0., 1, 2, 3]) + >>> target = torch.tensor([0., 1, 2, 2]) + >>> metric = RMSE() + >>> metric(pred, target) + tensor(0.5000) + + """ + super().__init__(name='rmse') + if reduction == 'elementwise_mean': + reduction = 'mean' + self.reduction = reduction + + def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Actual metric computation + + Args: + pred: predicted labels + target: ground truth labels + + Return: + A Tensor with the rmse loss. + """ + return torch.sqrt(F.mse_loss(pred, target, self.reduction)) + + +class MAE(Metric): + """ + Computes the root mean absolute loss or L1-loss. + """ + + def __init__( + self, + reduction: str = 'elementwise_mean', + ): + """ + Args: + reduction: a method for reducing mse over labels (default: takes the mean) + Available reduction methods: + - elementwise_mean: takes the mean + - none: pass array + - sum: add elements + + + Example: + + >>> pred = torch.tensor([0., 1, 2, 3]) + >>> target = torch.tensor([0., 1, 2, 2]) + >>> metric = MAE() + >>> metric(pred, target) + tensor(0.2500) + + """ + super().__init__(name='mae') + if reduction == 'elementwise_mean': + reduction = 'mean' + self.reduction = reduction + + def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Actual metric computation + + Args: + pred: predicted labels + target: ground truth labels + + Return: + A Tensor with the mae loss. + """ + return F.l1_loss(pred, target, self.reduction) + + +class RMSLE(Metric): + """ + Computes the root mean squared log loss. + """ + + def __init__( + self, + reduction: str = 'elementwise_mean', + ): + """ + Args: + reduction: a method for reducing mse over labels (default: takes the mean) + Available reduction methods: + - elementwise_mean: takes the mean + - none: pass array + - sum: add elements + + + Example: + + >>> pred = torch.tensor([0., 1, 2, 3]) + >>> target = torch.tensor([0., 1, 2, 2]) + >>> metric = RMSLE() + >>> metric(pred, target) + tensor(0.0207) + + """ + super().__init__(name='rmsle') + if reduction == 'elementwise_mean': + reduction = 'mean' + self.reduction = reduction + + def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Actual metric computation + + Args: + pred: predicted labels + target: ground truth labels + + Return: + A Tensor with the rmsle loss. + """ + return F.mse_loss(torch.log(pred + 1), torch.log(target + 1), self.reduction) diff --git a/tests/metrics/test_regression.py b/tests/metrics/test_regression.py new file mode 100644 index 0000000000000..d4a99de821b92 --- /dev/null +++ b/tests/metrics/test_regression.py @@ -0,0 +1,63 @@ +import pytest +import torch + +from pytorch_lightning.metrics.regression import ( + MAE, MSE, RMSE, RMSLE +) + +@pytest.mark.parametrize(['pred', 'target', 'exp'], [ + pytest.param([0., 1., 2., 3.], [0., 1., 2., 2.], .25), + pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.], 3.) +]) +def test_mse(pred, target, exp): + mse = MSE() + assert mse.name == 'mse' + + score = mse(pred=torch.tensor(pred), + target=torch.tensor(target)) + + assert isinstance(score, torch.Tensor) + assert score.item() == pytest.approx(exp,1e-4) + +@pytest.mark.parametrize(['pred', 'target', 'exp'], [ + pytest.param([0., 1., 2., 3.], [0., 1., 2., 2.], .5), + pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.], 1.7321) +]) +def test_rmse(pred, target, exp): + rmse = RMSE() + assert rmse.name == 'rmse' + + score = rmse(pred=torch.tensor(pred), + target=torch.tensor(target)) + + assert isinstance(score, torch.Tensor) + assert score.item() == pytest.approx(exp,1e-4) + +@pytest.mark.parametrize(['pred', 'target', 'exp'], [ + pytest.param([0., 1., 2., 3.], [0., 1., 2., 2.], .25), + pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.], 1.7321) +]) +def test_mae(pred, target, exp): + mae = MAE() + assert mae.name == 'mae' + + score = mae(pred=torch.tensor(pred), + target=torch.tensor(target)) + + assert isinstance(score, torch.Tensor) + assert score.item() == pytest.approx(exp,1e-4) + +@pytest.mark.parametrize(['pred', 'target', 'exp'], [ + pytest.param([0., 1., 2., 3.], [0., 1., 2., 2.], .0207), + pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.], 1.7321) +]) +def test_rmsle(pred, target, exp): + rmsle = RMSLE() + assert rmsle.name == 'rmsle' + + score = rmsle(pred=torch.tensor(pred), + target=torch.tensor(target)) + + assert isinstance(score, torch.Tensor) + assert pytest.approx(score.item(),1e-4) == exp + From 4c3d3a01f2f1fc95c9cf8ba41cb24e2db8288454 Mon Sep 17 00:00:00 2001 From: Xavier Sumba Date: Wed, 17 Jun 2020 09:38:35 -0400 Subject: [PATCH 2/3] solve tests --- pytorch_lightning/metrics/regression.py | 1 + tests/metrics/test_regression.py | 23 +++++++++++++---------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/metrics/regression.py b/pytorch_lightning/metrics/regression.py index e91cb61735ef4..964d88cd9604f 100644 --- a/pytorch_lightning/metrics/regression.py +++ b/pytorch_lightning/metrics/regression.py @@ -4,6 +4,7 @@ __all__ = ['MSE', 'RMSE', 'MAE', 'RMSLE'] + class MSE(Metric): """ Computes the mean squared loss. diff --git a/tests/metrics/test_regression.py b/tests/metrics/test_regression.py index d4a99de821b92..5675051cff090 100644 --- a/tests/metrics/test_regression.py +++ b/tests/metrics/test_regression.py @@ -5,6 +5,7 @@ MAE, MSE, RMSE, RMSLE ) + @pytest.mark.parametrize(['pred', 'target', 'exp'], [ pytest.param([0., 1., 2., 3.], [0., 1., 2., 2.], .25), pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.], 3.) @@ -14,10 +15,11 @@ def test_mse(pred, target, exp): assert mse.name == 'mse' score = mse(pred=torch.tensor(pred), - target=torch.tensor(target)) + target=torch.tensor(target)) assert isinstance(score, torch.Tensor) - assert score.item() == pytest.approx(exp,1e-4) + assert score.item() == exp + @pytest.mark.parametrize(['pred', 'target', 'exp'], [ pytest.param([0., 1., 2., 3.], [0., 1., 2., 2.], .5), @@ -28,14 +30,15 @@ def test_rmse(pred, target, exp): assert rmse.name == 'rmse' score = rmse(pred=torch.tensor(pred), - target=torch.tensor(target)) + target=torch.tensor(target)) assert isinstance(score, torch.Tensor) - assert score.item() == pytest.approx(exp,1e-4) + assert pytest.approx(score.item(), rel=1e-3) == exp + @pytest.mark.parametrize(['pred', 'target', 'exp'], [ pytest.param([0., 1., 2., 3.], [0., 1., 2., 2.], .25), - pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.], 1.7321) + pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.], 1.5) ]) def test_mae(pred, target, exp): mae = MAE() @@ -45,19 +48,19 @@ def test_mae(pred, target, exp): target=torch.tensor(target)) assert isinstance(score, torch.Tensor) - assert score.item() == pytest.approx(exp,1e-4) + assert score.item() == exp + @pytest.mark.parametrize(['pred', 'target', 'exp'], [ pytest.param([0., 1., 2., 3.], [0., 1., 2., 2.], .0207), - pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.], 1.7321) + pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.], .2841) ]) def test_rmsle(pred, target, exp): rmsle = RMSLE() assert rmsle.name == 'rmsle' score = rmsle(pred=torch.tensor(pred), - target=torch.tensor(target)) + target=torch.tensor(target)) assert isinstance(score, torch.Tensor) - assert pytest.approx(score.item(),1e-4) == exp - + assert pytest.approx(score.item(), rel=1e-3) == exp From 33cebcc1e12172aee4635fec574642ebf94909ce Mon Sep 17 00:00:00 2001 From: Xavier Sumba Date: Wed, 17 Jun 2020 13:14:39 -0400 Subject: [PATCH 3/3] add docs --- docs/source/metrics.rst | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 8d7322d4702b9..81fa315aa8216 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -180,6 +180,18 @@ ROC .. autoclass:: pytorch_lightning.metrics.classification.ROC :noindex: +MAE +^^^ + +.. autoclass:: pytorch_lightning.metrics.regression.MAE + :noindex: + +MSE +^^^ + +.. autoclass:: pytorch_lightning.metrics.regression.MSE + :noindex: + MulticlassROC ^^^^^^^^^^^^^ @@ -192,6 +204,18 @@ MulticlassPrecisionRecall .. autoclass:: pytorch_lightning.metrics.classification.MulticlassPrecisionRecall :noindex: +RMSE +^^^^ + +.. autoclass:: pytorch_lightning.metrics.regression.RMSLE + :noindex: + +RMSLE +^^^^^ + +.. autoclass:: pytorch_lightning.metrics.regression.RMSE + :noindex: + -------------- Functional Metrics