Metrics docs #2184
@@ -1,4 +1,318 @@

.. automodule:: pytorch_lightning.metrics
    :members:
    :noindex:
    :exclude-members:

.. testsetup:: *

    from torch.nn import Module
    from pytorch_lightning.core.lightning import LightningModule
    from pytorch_lightning.metrics import TensorMetric, NumpyMetric
Metrics
=======

This is a general package for PyTorch metrics. These can also be used with regular,
non-Lightning PyTorch code. Metrics are used to monitor model performance.

This package provides two major pieces of functionality:

1. A ``Metric`` class you can use to implement metrics with built-in distributed (DDP) support, which are device agnostic.
2. A collection of popular metrics already implemented for you.

Example::

    from pytorch_lightning.metrics.functional import accuracy
Review comment: are these imports wrong?
    pred = torch.tensor([0, 1, 2, 3])
    target = torch.tensor([0, 1, 2, 2])

    # calculates accuracy across all GPUs and all nodes used in training
    accuracy(pred, target)

Out::

    tensor(0.7500)
Review comment (on lines +27 to +29): this does not make sense
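For intuition, on a single process ``accuracy`` reduces to the fraction of positions where prediction and target agree; a minimal pure-Python sketch of that arithmetic (the hypothetical ``accuracy_sketch`` is not the library implementation, which also handles DDP reduction):

```python
def accuracy_sketch(pred, target):
    # fraction of positions where the prediction equals the target
    correct = sum(p == t for p, t in zip(pred, target))
    return correct / len(target)

print(accuracy_sketch([0, 1, 2, 3], [0, 1, 2, 2]))  # -> 0.75
```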
--------------

Implement a metric
------------------

You can implement metrics as either a PyTorch metric or a NumPy metric. NumPy
metrics will slow down training; use PyTorch metrics when possible.

Use :class:`TensorMetric` to implement native PyTorch metrics. This class
handles automated DDP syncing and converts all inputs and outputs to tensors.

Use :class:`NumpyMetric` to implement NumPy metrics. This class handles
automated DDP syncing, converts inputs to and from numpy, and returns outputs
as tensors.

.. warning::
    NumPy metrics might slow down your training substantially,
    since every metric computation requires a GPU sync to convert tensors to numpy.
TensorMetric
^^^^^^^^^^^^

Here's an example showing how to implement a TensorMetric:

.. testcode::

    class RMSE(TensorMetric):
        def forward(self, x, y):
            return torch.sqrt(torch.mean(torch.pow(x - y, 2.0)))

.. autoclass:: pytorch_lightning.metrics.metric.TensorMetric
    :noindex:
NumpyMetric
^^^^^^^^^^^

Here's an example showing how to implement a NumpyMetric:

.. testcode::

    class RMSE(NumpyMetric):
        def forward(self, x, y):
            return np.sqrt(np.mean(np.power(x - y, 2.0)))

.. autoclass:: pytorch_lightning.metrics.metric.NumpyMetric
    :noindex:
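Both RMSE implementations above compute the same quantity; a dependency-free sketch of that arithmetic, useful for checking expected values by hand (``rmse_sketch`` is a hypothetical helper, not part of the package):

```python
import math

def rmse_sketch(x, y):
    # root of the mean squared difference over paired values,
    # mirroring the forward() of the RMSE examples above
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)) / len(x))

print(rmse_sketch([1.0, 2.0, 3.0], [1.0, 2.0, 5.0]))
```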
--------------

Class Metrics
-------------

The following metrics can be instantiated as part of a module definition (even
in plain PyTorch).

.. testcode::

    from pytorch_lightning.metrics import Accuracy

    # Plain PyTorch
    class MyModule(Module):
        def __init__(self):
            super().__init__()
            self.metric = Accuracy()

        def forward(self, x, y):
            y_hat = ...
            acc = self.metric(y_hat, y)

    # PyTorch Lightning
    class MyModule(LightningModule):
        def __init__(self):
            super().__init__()
            self.metric = Accuracy()

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = ...
            acc = self.metric(y_hat, y)

These metrics also work when using distributed training:

.. code-block:: python

    model = MyModule()
    trainer = Trainer(gpus=8, num_nodes=2)

    # any metric automatically reduces across GPUs (even the ones you implement using Lightning)
    trainer.fit(model)
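To see why cross-GPU reduction gives the same answer as computing on one device, note that for accuracy the per-rank results only need a count-weighted mean; a pure-Python sketch of that reduction (``ddp_reduce_accuracy`` is a hypothetical illustration, not the library's sync code):

```python
def ddp_reduce_accuracy(per_rank_acc, per_rank_count):
    # a count-weighted mean across ranks equals the accuracy
    # computed over all samples at once
    total = sum(per_rank_count)
    return sum(a * n for a, n in zip(per_rank_acc, per_rank_count)) / total

# two "ranks": 3/4 correct and 1/2 correct -> 4/6 correct overall
print(ddp_reduce_accuracy([0.75, 0.5], [4, 2]))
```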
Accuracy
^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.Accuracy
    :noindex:

AveragePrecision
^^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.AveragePrecision
    :noindex:

AUROC
^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.AUROC
    :noindex:

ConfusionMatrix
^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.ConfusionMatrix
    :noindex:

DiceCoefficient
^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.DiceCoefficient
    :noindex:

F1
^^

.. autoclass:: pytorch_lightning.metrics.classification.F1
    :noindex:

FBeta
^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.FBeta
    :noindex:

PrecisionRecall
^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.PrecisionRecall
    :noindex:

Precision
^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.Precision
    :noindex:

Recall
^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.Recall
    :noindex:

ROC
^^^

.. autoclass:: pytorch_lightning.metrics.classification.ROC
    :noindex:

MulticlassROC
^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.MulticlassROC
    :noindex:

MulticlassPrecisionRecall
^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.MulticlassPrecisionRecall
    :noindex:
--------------

Functional Metrics
------------------

accuracy (F)
^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.accuracy
    :noindex:

auc (F)
^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.auc
    :noindex:

auroc (F)
^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.auroc
    :noindex:

average_precision (F)
^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.average_precision
    :noindex:

confusion_matrix (F)
^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.confusion_matrix
    :noindex:

dice_score (F)
^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.dice_score
    :noindex:

f1_score (F)
^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.f1_score
    :noindex:

fbeta_score (F)
^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.fbeta_score
    :noindex:

multiclass_precision_recall_curve (F)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.multiclass_precision_recall_curve
    :noindex:

multiclass_roc (F)
^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.multiclass_roc
    :noindex:

precision (F)
^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.precision
    :noindex:

precision_recall (F)
^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.precision_recall
    :noindex:

precision_recall_curve (F)
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.precision_recall_curve
    :noindex:

recall (F)
^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.recall
    :noindex:

roc (F)
^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.roc
    :noindex:

stat_scores (F)
^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.stat_scores
    :noindex:

stat_scores_multiple_classes (F)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.stat_scores_multiple_classes
    :noindex:
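Many of the classification metrics above derive from per-class stat scores (true/false positives and negatives); a dependency-free sketch of that counting for a single class (``stat_scores_sketch`` is a hypothetical illustration, not the library function):

```python
def stat_scores_sketch(pred, target, cls):
    # counts of true positives, false positives, true negatives,
    # and false negatives with respect to one class
    tp = sum(p == cls and t == cls for p, t in zip(pred, target))
    fp = sum(p == cls and t != cls for p, t in zip(pred, target))
    tn = sum(p != cls and t != cls for p, t in zip(pred, target))
    fn = sum(p != cls and t == cls for p, t in zip(pred, target))
    return tp, fp, tn, fn

print(stat_scores_sketch([0, 1, 1, 0], [0, 1, 0, 0], 1))  # -> (1, 1, 2, 0)
```

Precision and recall for that class then follow as ``tp / (tp + fp)`` and ``tp / (tp + fn)``.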
----------------

Metric pre-processing
---------------------

These helpers convert between label representations before a metric is computed.

to_categorical (F)
^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.to_categorical
    :noindex:

to_onehot (F)
^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.to_onehot
    :noindex:
@@ -1,30 +1,15 @@
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would keep a short version here... |
||
Metrics
=======

Metrics are generally used to monitor model performance.

This package aims to provide the most convenient ones, as well as a structure
for implementing your custom metrics for all the fancy research you want to do.

For native PyTorch implementations of metrics, it is recommended to use
:class:`TensorMetric`, which handles automated DDP syncing and conversion
of all inputs and outputs to tensors.

If your metric implementation works on numpy, just use :class:`NumpyMetric`,
which handles the automated conversion of inputs to and outputs from numpy,
as well as automated DDP syncing.

.. warning:: Employing numpy in your metric calculation might slow
    down your training substantially, since every metric computation
    requires a GPU sync to convert tensors to numpy.
"""
from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric
from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
from pytorch_lightning.metrics.sklearn import (
    SklearnMetric,
    Accuracy,
    AveragePrecision,
    AUC,
    ConfusionMatrix,
    F1,
    FBeta,
    Precision,
    Recall,
    PrecisionRecallCurve,
    ROC,
    AUROC)
Review comment: Should we mention that the package also comes with an interface
to sklearn metrics (with a warning that this is slow due to casting the tensors
back and forth)?

Review comment: We should. We should also mention that sklearn has to be
installed separately.

Review comment: We should make it clear how to import the different backends.
Maybe something like:

# Use the native backend
import pytorch_lightning.metrics.native as plm

# Use the sklearn backend
import pytorch_lightning.metrics.sklearn as plm

# Use the default (native if available, else sklearn)
import pytorch_lightning.metrics as plm

Review comment: @SkafteNicki can you add the sklearn details in here?