diff --git a/CHANGELOG.md b/CHANGELOG.md index 3814d5349f675..085335e4ca090 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -65,6 +65,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) +- Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505)) + + ### Removed - Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164)) diff --git a/docs/source/extensions/metrics.rst b/docs/source/extensions/metrics.rst index 1fa5b64098e52..74a4a15deb2be 100644 --- a/docs/source/extensions/metrics.rst +++ b/docs/source/extensions/metrics.rst @@ -1,911 +1,9 @@ -.. testsetup:: * - - import torch - from torch.nn import Module - from pytorch_lightning.core.lightning import LightningModule - from pytorch_lightning.metrics import Metric - -.. _metrics: - ####### Metrics ####### -``pytorch_lightning.metrics`` is a Metrics API created for easy metric development and usage in -PyTorch and PyTorch Lightning. It is rigorously tested for all edge cases and includes a growing list of -common metric implementations. - -The metrics API provides ``update()``, ``compute()``, ``reset()`` functions to the user. The metric base class inherits -``nn.Module`` which allows us to call ``metric(...)`` directly. The ``forward()`` method of the base ``Metric`` class -serves the dual purpose of calling ``update()`` on its input and simultaneously returning the value of the metric over the -provided input. +``pytorch_lightning.metrics`` has been moved to a separate package `TorchMetrics `_. +We will preserve compatibility for the next few releases, nevertheless, we encourage users to update to use this stand-alone package. .. warning:: - From v1.2 onward ``compute()`` will no longer automatically call ``reset()``, - and it is up to the user to reset metrics between epochs, except in the case where the - metric is directly passed to ``LightningModule``'s ``self.log``. - -These metrics work with DDP in PyTorch and PyTorch Lightning by default. When ``.compute()`` is called in -distributed mode, the internal state of each metric is synced and reduced across each process, so that the -logic present in ``.compute()`` is applied to state information from all processes. - -The example below shows how to use a metric in your ``LightningModule``: - -.. code-block:: python - - def __init__(self): - ... - self.accuracy = pl.metrics.Accuracy() - - def training_step(self, batch, batch_idx): - x, y = batch - preds = self(x) - ... - # log step metric - self.log('train_acc_step', self.accuracy(preds, y)) - ... - - def training_epoch_end(self, outs): - # log epoch metric - self.log('train_acc_epoch', self.accuracy.compute()) - - -``Metric`` objects can also be directly logged, in which case Lightning will log -the metric based on ``on_step`` and ``on_epoch`` flags present in ``self.log(...)``. -If ``on_epoch`` is True, the logger automatically logs the end of epoch metric value by calling -``.compute()``. - -.. note:: - ``sync_dist``, ``sync_dist_op``, ``sync_dist_group``, ``reduce_fx`` and ``tbptt_reduce_fx`` - flags from ``self.log(...)`` don't affect the metric logging in any manner. The metric class - contains its own distributed synchronization logic. - - This however is only true for metrics that inherit the base class ``Metric``, - and thus the functional metric API provides no support for in-built distributed synchronization - or reduction functions. - - -.. code-block:: python - - def __init__(self): - ... - self.train_acc = pl.metrics.Accuracy() - self.valid_acc = pl.metrics.Accuracy() - - def training_step(self, batch, batch_idx): - x, y = batch - preds = self(x) - ... - self.train_acc(preds, y) - self.log('train_acc', self.train_acc, on_step=True, on_epoch=False) - - def validation_step(self, batch, batch_idx): - logits = self(x) - ... - self.valid_acc(logits, y) - self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True) - -.. note:: - - If using metrics in data parallel mode (dp), the metric update/logging should be done - in the ``_step_end`` method (where ```` is either ``training``, ``validation`` - or ``test``). This is due to metric states else being destroyed after each forward pass, - leading to wrong accumulation. In practice do the following: - - .. code-block:: python - - def training_step(self, batch, batch_idx): - data, target = batch - preds = self(data) - ... - return {'loss' : loss, 'preds' : preds, 'target' : target} - - def training_step_end(self, outputs): - #update and log - self.metric(outputs['preds'], outputs['target']) - self.log('metric', self.metric) - -This metrics API is independent of PyTorch Lightning. Metrics can directly be used in PyTorch as shown in the example: - -.. code-block:: python - - from pytorch_lightning import metrics - - train_accuracy = metrics.Accuracy() - valid_accuracy = metrics.Accuracy(compute_on_step=False) - - for epoch in range(epochs): - for x, y in train_data: - y_hat = model(x) - - # training step accuracy - batch_acc = train_accuracy(y_hat, y) - - for x, y in valid_data: - y_hat = model(x) - valid_accuracy(y_hat, y) - - # total accuracy over all training batches - total_train_accuracy = train_accuracy.compute() - - # total accuracy over all validation batches - total_valid_accuracy = valid_accuracy.compute() - -.. note:: - - Metrics contain internal states that keep track of the data seen so far. - Do not mix metric states across training, validation and testing. - It is highly recommended to re-initialize the metric per mode as - shown in the examples above. For easy initializing the same metric multiple - times, the ``.clone()`` method can be used: - - .. testcode:: - - from pytorch_lightning.metrics import Accuracy - - def __init__(self): - ... - metric = Accuracy() - self.train_acc = metric.clone() - self.val_acc = metric.clone() - self.test_acc = metric.clone() - -.. note:: - - Metric states are **not** added to the models ``state_dict`` by default. - To change this, after initializing the metric, the method ``.persistent(mode)`` can - be used to enable (``mode=True``) or disable (``mode=False``) this behaviour. - -******************* -Metrics and devices -******************* - -Metrics are simple subclasses of :class:`~torch.nn.Module` and their metric states behave -similar to buffers and parameters of modules. This means that metrics states should -be moved to the same device as the input of the metric: - -.. code-block:: python - - from pytorch_lightning.metrics import Accuracy - - target = torch.tensor([1, 1, 0, 0], device=torch.device("cuda", 0)) - preds = torch.tensor([0, 1, 0, 0], device=torch.device("cuda", 0)) - - # Metric states are always initialized on cpu, and needs to be moved to - # the correct device - confmat = Accuracy(num_classes=2).to(torch.device("cuda", 0)) - out = confmat(preds, target) - print(out.device) # cuda:0 - -However, when **properly defined** inside a :class:`~pytorch_lightning.core.lightning.LightningModule` -, Lightning will automatically move the metrics to the same device as the data. Being -**properly defined** means that the metric is correctly identified as a child module of the -model (check ``.children()`` attribute of the model). Therefore, metrics cannot be placed -in native python ``list`` and ``dict``, as they will not be correctly identified -as child modules. Instead of ``list`` use :class:`~torch.nn.ModuleList` and instead of -``dict`` use :class:`~torch.nn.ModuleDict`. - -.. testcode:: - - from pytorch_lightning.metrics import Accuracy - - class MyModule(LightningModule): - def __init__(self): - ... - # valid ways metrics will be identified as child modules - self.metric1 = Accuracy() - self.metric2 = nn.ModuleList(Accuracy()) - self.metric3 = nn.ModuleDict({'accuracy': Accuracy()}) - - def training_step(self, batch, batch_idx): - # all metrics will be on the same device as the input batch - data, target = batch - preds = self(data) - ... - val1 = self.metric1(preds, target) - val2 = self.metric2[0](preds, target) - val3 = self.metric3['accuracy'](preds, target) - - -********************* -Implementing a Metric -********************* - -To implement your custom metric, subclass the base ``Metric`` class and implement the following methods: - -- ``__init__()``: Each state variable should be called using ``self.add_state(...)``. -- ``update()``: Any code needed to update the state given any inputs to the metric. -- ``compute()``: Computes a final value from the state of the metric. - -All you need to do is call ``add_state`` correctly to implement a custom metric with DDP. -``reset()`` is called on metric state variables added using ``add_state()``. - -To see how metric states are synchronized across distributed processes, refer to ``add_state()`` docs -from the base ``Metric`` class. - -Example implementation: - -.. testcode:: - - from pytorch_lightning.metrics import Metric - - class MyAccuracy(Metric): - def __init__(self, dist_sync_on_step=False): - super().__init__(dist_sync_on_step=dist_sync_on_step) - - self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - preds, target = self._input_format(preds, target) - assert preds.shape == target.shape - - self.correct += torch.sum(preds == target) - self.total += target.numel() - - def compute(self): - return self.correct.float() / self.total - -Metrics support backpropagation, if all computations involved in the metric calculation -are differentiable. However, note that the cached state is detached from the computational -graph and cannot be backpropagated. Not doing this would mean storing the computational -graph for each update call, which can lead to out-of-memory errors. -In practise this means that: - -.. code-block:: python - - metric = MyMetric() - val = metric(pred, target) # this value can be backpropagated - val = metric.compute() # this value cannot be backpropagated - - -Metric API ----------- - -.. autoclass:: pytorch_lightning.metrics.Metric - :noindex: - -Internal implementation details -------------------------------- - -This section briefly describe how metrics work internally. We encourage looking at the source code for more info. -Internally, Lightning wraps the user defined ``update()`` and ``compute()`` method. We do this to automatically -synchronize and reduce metric states across multiple devices. More precisely, calling ``update()`` does the -following internally: - -1. Clears computed cache -2. Calls user-defined ``update()`` - -Simiarly, calling ``compute()`` does the following internally - -1. Syncs metric states between processes -2. Reduce gathered metric states -3. Calls the user defined ``compute()`` method on the gathered metric states -4. Cache computed result - -From a user's standpoint this has one important side-effect: computed results are cached. This means that no -matter how many times ``compute`` is called after one and another, it will continue to return the same result. -The cache is first emptied on the next call to ``update``. - -``forward`` serves the dual purpose of both returning the metric on the current data and updating the internal -metric state for accumulating over multiple batches. The ``forward()`` method achives this by combining calls -to ``update`` and ``compute`` in the following way (assuming metric is initialized with ``compute_on_step=True``): - -1. Calls ``update()`` to update the global metric states (for accumulation over multiple batches) -2. Caches the global state -3. Calls ``reset()`` to clear global metric state -4. Calls ``update()`` to update local metric state -5. Calls ``compute()`` to calculate metric for current batch -6. Restores the global state - -This procedure has the consequence of calling the user defined ``update`` **twice** during a single -forward call (one to update global statistics and one for getting the batch statistics). - - -****************** -Metric Arithmetics -****************** - -Metrics support most of python built-in operators for arithmetic, logic and bitwise operations. - -For example for a metric that should return the sum of two different metrics, implementing a new metric is an overhead that is not necessary. -It can now be done with: - -.. code-block:: python - - first_metric = MyFirstMetric() - second_metric = MySecondMetric() - - new_metric = first_metric + second_metric - -``new_metric.update(*args, **kwargs)`` now calls update of ``first_metric`` and ``second_metric``. It forwards all positional arguments but -forwards only the keyword arguments that are available in respective metric's update declaration. - -Similarly ``new_metric.compute()`` now calls compute of ``first_metric`` and ``second_metric`` and adds the results up. - -This pattern is implemented for the following operators (with ``a`` being metrics and ``b`` being metrics, tensors, integer or floats): - -* Addition (``a + b``) -* Bitwise AND (``a & b``) -* Equality (``a == b``) -* Floordivision (``a // b``) -* Greater Equal (``a >= b``) -* Greater (``a > b``) -* Less Equal (``a <= b``) -* Less (``a < b``) -* Matrix Multiplication (``a @ b``) -* Modulo (``a % b``) -* Multiplication (``a * b``) -* Inequality (``a != b``) -* Bitwise OR (``a | b``) -* Power (``a ** b``) -* Substraction (``a - b``) -* True Division (``a / b``) -* Bitwise XOR (``a ^ b``) -* Absolute Value (``abs(a)``) -* Inversion (``~a``) -* Negative Value (``neg(a)``) -* Positive Value (``pos(a)``) - -**************** -MetricCollection -**************** - -In many cases it is beneficial to evaluate the model output by multiple metrics. -In this case the `MetricCollection` class may come in handy. It accepts a sequence -of metrics and wraps theses into a single callable metric class, with the same -interface as any other metric. - -Example: - -.. testcode:: - - from pytorch_lightning.metrics import MetricCollection, Accuracy, Precision, Recall - target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2]) - preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2]) - metric_collection = MetricCollection([ - Accuracy(), - Precision(num_classes=3, average='macro'), - Recall(num_classes=3, average='macro') - ]) - print(metric_collection(preds, target)) - -.. testoutput:: - :options: +NORMALIZE_WHITESPACE - - {'Accuracy': tensor(0.1250), - 'Precision': tensor(0.0667), - 'Recall': tensor(0.1111)} - -Similarly it can also reduce the amount of code required to log multiple metrics -inside your LightningModule - -.. code-block:: python - - def __init__(self): - ... - metrics = pl.metrics.MetricCollection(...) - self.train_metrics = metrics.clone() - self.valid_metrics = metrics.clone() - - def training_step(self, batch, batch_idx): - logits = self(x) - ... - self.train_metrics(logits, y) - # use log_dict instead of log - self.log_dict(self.train_metrics, on_step=True, on_epoch=False, prefix='train') - - def validation_step(self, batch, batch_idx): - logits = self(x) - ... - self.valid_metrics(logits, y) - # use log_dict instead of log - self.log_dict(self.valid_metrics, on_step=True, on_epoch=True, prefix='val') - -.. note:: - - `MetricCollection` as default assumes that all the metrics in the collection - have the same call signature. If this is not the case, input that should be - given to different metrics can given as keyword arguments to the collection. - -.. autoclass:: pytorch_lightning.metrics.MetricCollection - :noindex: - - -*************************** -Class vs Functional Metrics -*************************** - -The functional metrics follow the simple paradigm input in, output out. This means, they don't provide any advanced mechanisms for syncing across DDP nodes or aggregation over batches. They simply compute the metric value based on the given inputs. - -Also, the integration within other parts of PyTorch Lightning will never be as tight as with the class-based interface. -If you look for just computing the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also using the class interface. - -********************** -Classification Metrics -********************** - -Input types ------------ - -For the purposes of classification metrics, inputs (predictions and targets) are split -into these categories (``N`` stands for the batch size and ``C`` for number of classes): - -.. csv-table:: \*dtype ``binary`` means integers that are either 0 or 1 - :header: "Type", "preds shape", "preds dtype", "target shape", "target dtype" - :widths: 20, 10, 10, 10, 10 - - "Binary", "(N,)", "``float``", "(N,)", "``binary``\*" - "Multi-class", "(N,)", "``int``", "(N,)", "``int``" - "Multi-class with probabilities", "(N, C)", "``float``", "(N,)", "``int``" - "Multi-label", "(N, ...)", "``float``", "(N, ...)", "``binary``\*" - "Multi-dimensional multi-class", "(N, ...)", "``int``", "(N, ...)", "``int``" - "Multi-dimensional multi-class with probabilities", "(N, C, ...)", "``float``", "(N, ...)", "``int``" - -.. note:: - All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so - that, for example, a tensor of shape ``(N, 1)`` is treated as ``(N, )``. - -When predictions or targets are integers, it is assumed that class labels start at 0, i.e. -the possible class labels are 0, 1, 2, 3, etc. Below are some examples of different input types - -.. testcode:: - - # Binary inputs - binary_preds = torch.tensor([0.6, 0.1, 0.9]) - binary_target = torch.tensor([1, 0, 2]) - - # Multi-class inputs - mc_preds = torch.tensor([0, 2, 1]) - mc_target = torch.tensor([0, 1, 2]) - - # Multi-class inputs with probabilities - mc_preds_probs = torch.tensor([[0.8, 0.2, 0], [0.1, 0.2, 0.7], [0.3, 0.6, 0.1]]) - mc_target_probs = torch.tensor([0, 1, 2]) - - # Multi-label inputs - ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]]) - ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]]) - - -Using the is_multiclass parameter -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -In some cases, you might have inputs which appear to be (multi-dimensional) multi-class -but are actually binary/multi-label - for example, if both predictions and targets are -integer (binary) tensors. Or it could be the other way around, you want to treat -binary/multi-label inputs as 2-class (multi-dimensional) multi-class inputs. - -For these cases, the metrics where this distinction would make a difference, expose the -``is_multiclass`` argument. Let's see how this is used on the example of -:class:`~pytorch_lightning.metrics.StatScores` metric. - -First, let's consider the case with label predictions with 2 classes, which we want to -treat as binary. - -.. testcode:: - - from pytorch_lightning.metrics.functional import stat_scores - - # These inputs are supposed to be binary, but appear as multi-class - preds = torch.tensor([0, 1, 0]) - target = torch.tensor([1, 1, 0]) - -As you can see below, by default the inputs are treated -as multi-class. We can set ``is_multiclass=False`` to treat the inputs as binary - -which is the same as converting the predictions to float beforehand. - -.. doctest:: - - >>> stat_scores(preds, target, reduce='macro', num_classes=2) - tensor([[1, 1, 1, 0, 1], - [1, 0, 1, 1, 2]]) - >>> stat_scores(preds, target, reduce='macro', num_classes=1, is_multiclass=False) - tensor([[1, 0, 1, 1, 2]]) - >>> stat_scores(preds.float(), target, reduce='macro', num_classes=1) - tensor([[1, 0, 1, 1, 2]]) - -Next, consider the opposite example: inputs are binary (as predictions are probabilities), -but we would like to treat them as 2-class multi-class, to obtain the metric for both classes. - -.. testcode:: - - preds = torch.tensor([0.2, 0.7, 0.3]) - target = torch.tensor([1, 1, 0]) - -In this case we can set ``is_multiclass=True``, to treat the inputs as multi-class. - -.. doctest:: - - >>> stat_scores(preds, target, reduce='macro', num_classes=1) - tensor([[1, 0, 1, 1, 2]]) - >>> stat_scores(preds, target, reduce='macro', num_classes=2, is_multiclass=True) - tensor([[1, 1, 1, 0, 1], - [1, 0, 1, 1, 2]]) - - -Class Metrics (Classification) ------------------------------- - -Accuracy -~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.Accuracy - :noindex: - -AveragePrecision -~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.AveragePrecision - :noindex: - -AUC -~~~ - -.. autoclass:: pytorch_lightning.metrics.AUC - :noindex: - -AUROC -~~~~~ - -.. autoclass:: pytorch_lightning.metrics.AUROC - :noindex: - -ConfusionMatrix -~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.ConfusionMatrix - :noindex: - -F1 -~~ - -.. autoclass:: pytorch_lightning.metrics.F1 - :noindex: - -FBeta -~~~~~ - -.. autoclass:: pytorch_lightning.metrics.FBeta - :noindex: - -IoU -~~~ - -.. autoclass:: pytorch_lightning.metrics.IoU - :noindex: - -Hamming Distance -~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.HammingDistance - :noindex: - -Precision -~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.Precision - :noindex: - -PrecisionRecallCurve -~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.PrecisionRecallCurve - :noindex: - -Recall -~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.Recall - :noindex: - -ROC -~~~ - -.. autoclass:: pytorch_lightning.metrics.ROC - :noindex: - - -StatScores -~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.StatScores - :noindex: - - -Functional Metrics (Classification) ------------------------------------ - -accuracy [func] -~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.accuracy - :noindex: - - -auc [func] -~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.auc - :noindex: - - -auroc [func] -~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.auroc - :noindex: - - -average_precision [func] -~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.average_precision - :noindex: - - -confusion_matrix [func] -~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.confusion_matrix - :noindex: - - -dice_score [func] -~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.dice_score - :noindex: - - -f1 [func] -~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.f1 - :noindex: - - -fbeta [func] -~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.fbeta - :noindex: - -hamming_distance [func] -~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.hamming_distance - :noindex: - -iou [func] -~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.iou - :noindex: - - -roc [func] -~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.roc - :noindex: - - -precision [func] -~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.precision - :noindex: - - -precision_recall [func] -~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.precision_recall - :noindex: - - -precision_recall_curve [func] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.precision_recall_curve - :noindex: - - -recall [func] -~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.recall - :noindex: - -select_topk [func] -~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.utils.select_topk - :noindex: - - -stat_scores [func] -~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.stat_scores - :noindex: - - -stat_scores_multiple_classes [func] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.stat_scores_multiple_classes - :noindex: - - -to_categorical [func] -~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.utils.to_categorical - :noindex: - - -to_onehot [func] -~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.utils.to_onehot - :noindex: - -****************** -Regression Metrics -****************** - -Class Metrics (Regression) --------------------------- - -ExplainedVariance -~~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.ExplainedVariance - :noindex: - - -MeanAbsoluteError -~~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.MeanAbsoluteError - :noindex: - - -MeanSquaredError -~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.MeanSquaredError - :noindex: - - -MeanSquaredLogError -~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.MeanSquaredLogError - :noindex: - - -PSNR -~~~~ - -.. autoclass:: pytorch_lightning.metrics.PSNR - :noindex: - - -SSIM -~~~~ - -.. autoclass:: pytorch_lightning.metrics.SSIM - :noindex: - - -R2Score -~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.R2Score - :noindex: - -Functional Metrics (Regression) -------------------------------- - -explained_variance [func] -~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.explained_variance - :noindex: - - -image_gradients [func] -~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.image_gradients - :noindex: - - -mean_absolute_error [func] -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.mean_absolute_error - :noindex: - - -mean_squared_error [func] -~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.mean_squared_error - :noindex: - - -mean_squared_log_error [func] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.mean_squared_log_error - :noindex: - - -psnr [func] -~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.psnr - :noindex: - - -ssim [func] -~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.ssim - :noindex: - - -r2score [func] -~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.r2score - :noindex: - - -*** -NLP -*** - -bleu_score [func] ------------------ - -.. autofunction:: pytorch_lightning.metrics.functional.bleu_score - :noindex: - -***************************** -Information Retrieval Metrics -***************************** - -Class Metrics (IR) ------------------- - -Mean Average Precision -~~~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.retrieval.RetrievalMAP - :noindex: - - -Functional Metrics (IR) ------------------------ - -average_precision_retrieval [func] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.ir_average_precision.retrieval_average_precision - :noindex: - - -******** -Pairwise -******** - -embedding_similarity [func] ---------------------------- - -.. autofunction:: pytorch_lightning.metrics.functional.embedding_similarity - :noindex: + ``pytorch_lightning.metrics`` is deprecated from v1.3 and will be removed in v1.5. diff --git a/pl_examples/basic_examples/conv_sequential_example.py b/pl_examples/basic_examples/conv_sequential_example.py index 1c35c69d29f37..95ee66e1d5a14 100644 --- a/pl_examples/basic_examples/conv_sequential_example.py +++ b/pl_examples/basic_examples/conv_sequential_example.py @@ -31,7 +31,7 @@ import pytorch_lightning as pl from pl_examples import cli_lightning_logo from pytorch_lightning import Trainer -from pytorch_lightning.metrics.functional import accuracy +from torchmetrics.functional import accuracy from pytorch_lightning.plugins import RPCSequentialPlugin from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _FAIRSCALE_PIPE_AVAILABLE diff --git a/pl_examples/domain_templates/computer_vision_fine_tuning.py b/pl_examples/domain_templates/computer_vision_fine_tuning.py index 88f4e66605741..4e148a18433a6 100644 --- a/pl_examples/domain_templates/computer_vision_fine_tuning.py +++ b/pl_examples/domain_templates/computer_vision_fine_tuning.py @@ -49,6 +49,7 @@ from torch.optim.lr_scheduler import MultiStepLR from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader +from torchmetrics import Accuracy from torchvision import models, transforms from torchvision.datasets import ImageFolder from torchvision.datasets.utils import download_and_extract_archive @@ -188,8 +189,8 @@ def __init__( self.__build_model() - self.train_acc = pl.metrics.Accuracy() - self.valid_acc = pl.metrics.Accuracy() + self.train_acc = Accuracy() + self.valid_acc = Accuracy() self.save_hyperparameters() def __build_model(self): diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index 2c30563b5efbe..500689f3182fb 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from warnings import warn + from pytorch_lightning.metrics.classification import ( # noqa: F401 Accuracy, AUC, @@ -38,3 +40,8 @@ SSIM, ) from pytorch_lightning.metrics.retrieval import RetrievalMAP # noqa: F401 + +warn( + "`pytorch_lightning.metrics.*` module has been renamed to `torchmetrics.*` and split off to its own package" + " (https://github.com/PyTorchLightning/metrics) since v1.3 and will be removed in v1.5", DeprecationWarning +) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 9d97cbec1a387..343e979dd3e0c 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -38,8 +38,6 @@ class Accuracy(Metric): changed to subset accuracy (which requires all labels or sub-samples in the sample to be correctly predicted) by setting ``subset_accuracy=True``. - Accepts all input types listed in :ref:`extensions/metrics:input types`. - Args: threshold: Threshold probability value for transforming probability predictions to binary @@ -133,8 +131,7 @@ def __init__( def update(self, preds: torch.Tensor, target: torch.Tensor): """ - Update state with predictions and targets. See :ref:`extensions/metrics:input types` for more information - on input types. + Update state with predictions and targets. Args: preds: Predictions from model (probabilities, or labels) diff --git a/pytorch_lightning/metrics/classification/hamming_distance.py b/pytorch_lightning/metrics/classification/hamming_distance.py index 62b4ae824a6d1..adf1086f3c85f 100644 --- a/pytorch_lightning/metrics/classification/hamming_distance.py +++ b/pytorch_lightning/metrics/classification/hamming_distance.py @@ -35,8 +35,6 @@ class HammingDistance(Metric): treats each possible label separately - meaning that, for example, multi-class data is treated as if it were multi-label. - Accepts all input types listed in :ref:`extensions/metrics:input types`. - Args: threshold: Threshold probability value for transforming probability predictions to binary @@ -92,8 +90,7 @@ def __init__( def update(self, preds: torch.Tensor, target: torch.Tensor): """ - Update state with predictions and targets. See :ref:`extensions/metrics:input types` for more information - on input types. + Update state with predictions and targets. Args: preds: Predictions from model (probabilities, or labels) diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py index ea6d5722b3041..58d3142de72f2 100644 --- a/pytorch_lightning/metrics/classification/helpers.py +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -282,9 +282,7 @@ def _check_classification_inputs( Should be left unset (``None``) for inputs with label predictions. is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. + than what they appear to be. Return: @@ -408,9 +406,7 @@ def _input_format_classification( Should be left unset (``None``) for all other types of inputs. is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. + than what they appear to be. Returns: diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py index 11862769e62a8..5a163097ee0bc 100644 --- a/pytorch_lightning/metrics/classification/precision_recall.py +++ b/pytorch_lightning/metrics/classification/precision_recall.py @@ -31,7 +31,7 @@ class Precision(StatScores): The reduction method (how the precision scores are aggregated) is controlled by the ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`. + multi-dimensional multi-class case. Args: num_classes: @@ -67,11 +67,9 @@ class Precision(StatScores): - ``'samplewise'``: In this case, the statistics are computed separately for each sample on the ``N`` axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ``...`` - (see :ref:`extensions/metrics:input types`) as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. + as the ``N`` dimension within the sample, and computing the metric for the sample based on that. - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - (see :ref:`extensions/metrics:input types`) are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. @@ -89,9 +87,7 @@ class Precision(StatScores): Should be left unset (``None``) for inputs with label predictions. is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. + than what they appear to be. compute_on_step: Forward only calls ``update()`` and return ``None`` if this is set to ``False``. @@ -185,7 +181,7 @@ class Recall(StatScores): The reduction method (how the recall scores are aggregated) is controlled by the ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`. + multi-dimensional multi-class case. Args: num_classes: @@ -221,11 +217,9 @@ class Recall(StatScores): - ``'samplewise'``: In this case, the statistics are computed separately for each sample on the ``N`` axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ``...`` - (see :ref:`extensions/metrics:input types`) as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. + as the ``N`` dimension within the sample, and computing the metric for the sample based on that. - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - (see :ref:`extensions/metrics:input types`) are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. @@ -244,9 +238,7 @@ class Recall(StatScores): is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. + than what they appear to be. compute_on_step: Forward only calls ``update()`` and return ``None`` if this is set to ``False``. diff --git a/pytorch_lightning/metrics/classification/stat_scores.py b/pytorch_lightning/metrics/classification/stat_scores.py index 4ac47ea466ada..3807d7079b508 100644 --- a/pytorch_lightning/metrics/classification/stat_scores.py +++ b/pytorch_lightning/metrics/classification/stat_scores.py @@ -28,8 +28,6 @@ class StatScores(Metric): ``reduce`` parameter, and additionally by the ``mdmc_reduce`` parameter in the multi-dimensional multi-class case. - Accepts all inputs listed in :ref:`extensions/metrics:input types`. - Args: threshold: Threshold probability value for transforming probability predictions to binary @@ -70,8 +68,7 @@ class StatScores(Metric): Defines how the multi-dimensional multi-class inputs are handeled. Should be one of the following: - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class (see :ref:`extensions/metrics:input types` for the definition of input types). + - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional multi-class. - ``'samplewise'``: In this case, the statistics are computed separately for each sample on the ``N`` axis, and then the outputs are concatenated together. In each @@ -85,9 +82,7 @@ class StatScores(Metric): is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. + than what they appear to be. compute_on_step: Forward only calls ``update()`` and return ``None`` if this is set to ``False``. @@ -188,8 +183,7 @@ def __init__( def update(self, preds: torch.Tensor, target: torch.Tensor): """ - Update state with predictions and targets. See :ref:`extensions/metrics:input types` for more information - on input types. + Update state with predictions and targets. Args: preds: Predictions from model (probabilities or labels) diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index 05001c8bf5e35..35a5fa85a9763 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -72,8 +72,6 @@ def accuracy( changed to subset accuracy (which requires all labels or sub-samples in the sample to be correctly predicted) by setting ``subset_accuracy=True``. - Accepts all input types listed in :ref:`extensions/metrics:input types`. - Args: preds: Predictions from model (probabilities, or labels) target: Ground truth labels diff --git a/pytorch_lightning/metrics/functional/hamming_distance.py b/pytorch_lightning/metrics/functional/hamming_distance.py index 60409751fc9f0..d288a87fc3aaf 100644 --- a/pytorch_lightning/metrics/functional/hamming_distance.py +++ b/pytorch_lightning/metrics/functional/hamming_distance.py @@ -51,8 +51,6 @@ def hamming_distance(preds: torch.Tensor, target: torch.Tensor, threshold: float treats each possible label separately - meaning that, for example, multi-class data is treated as if it were multi-label. - Accepts all input types listed in :ref:`extensions/metrics:input types`. - Args: preds: Predictions from model target: Ground truth diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py index f0beb564f23ee..09632e216560b 100644 --- a/pytorch_lightning/metrics/functional/precision_recall.py +++ b/pytorch_lightning/metrics/functional/precision_recall.py @@ -60,7 +60,7 @@ def precision( The reduction method (how the precision scores are aggregated) is controlled by the ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`. + multi-dimensional multi-class case. Args: preds: Predictions from model (probabilities or labels) @@ -91,11 +91,9 @@ def precision( - ``'samplewise'``: In this case, the statistics are computed separately for each sample on the ``N`` axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ``...`` - (see :ref:`extensions/metrics:input types`) as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. + as the ``N`` dimension within the sample, and computing the metric for the sample based on that. - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - (see :ref:`extensions/metrics:input types`) are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. @@ -119,9 +117,7 @@ def precision( Should be left unset (``None``) for inputs with label predictions. is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. + than what they appear to be. class_reduction: .. warning :: This parameter is deprecated, use ``average``. Will be removed in v1.4.0. @@ -234,7 +230,7 @@ def recall( The reduction method (how the recall scores are aggregated) is controlled by the ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`. + multi-dimensional multi-class case. Args: preds: Predictions from model (probabilities, or labels) @@ -265,11 +261,9 @@ def recall( - ``'samplewise'``: In this case, the statistics are computed separately for each sample on the ``N`` axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ``...`` - (see :ref:`extensions/metrics:input types`) as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. + as the ``N`` dimension within the sample, and computing the metric for the sample based on that. - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - (see :ref:`extensions/metrics:input types`) are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. @@ -293,9 +287,7 @@ def recall( Should be left unset (``None``) for inputs with label predictions. is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. + than what they appear to be. class_reduction: .. warning :: This parameter is deprecated, use ``average``. Will be removed in v1.4.0. @@ -394,7 +386,7 @@ def precision_recall( The reduction method (how the recall scores are aggregated) is controlled by the ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`. + multi-dimensional multi-class case. Args: preds: Predictions from model (probabilities, or labels) @@ -425,11 +417,9 @@ def precision_recall( - ``'samplewise'``: In this case, the statistics are computed separately for each sample on the ``N`` axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ``...`` - (see :ref:`extensions/metrics:input types`) as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. + as the ``N`` dimension within the sample, and computing the metric for the sample based on that. - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - (see :ref:`extensions/metrics:input types`) are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. @@ -453,9 +443,7 @@ def precision_recall( Should be left unset (``None``) for inputs with label predictions. is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. + than what they appear to be. class_reduction: .. warning :: This parameter is deprecated, use ``average``. Will be removed in v1.4.0. diff --git a/pytorch_lightning/metrics/functional/stat_scores.py b/pytorch_lightning/metrics/functional/stat_scores.py index a69e18782c5fe..108cdf7a5b88a 100644 --- a/pytorch_lightning/metrics/functional/stat_scores.py +++ b/pytorch_lightning/metrics/functional/stat_scores.py @@ -153,7 +153,7 @@ def stat_scores( The reduction method (how the statistics are aggregated) is controlled by the ``reduce`` parameter, and additionally by the ``mdmc_reduce`` parameter in the - multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`. + multi-dimensional multi-class case. Args: preds: Predictions from model (probabilities or labels) @@ -197,8 +197,7 @@ def stat_scores( Defines how the multi-dimensional multi-class inputs are handeled. Should be one of the following: - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class (see :ref:`extensions/metrics:input types` for the definition of input types). + - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional multi-class. - ``'samplewise'``: In this case, the statistics are computed separately for each sample on the ``N`` axis, and then the outputs are concatenated together. In each @@ -212,9 +211,7 @@ def stat_scores( is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. + than what they appear to be. Return: The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index f857ad50399cf..901e9c85b162b 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -179,9 +179,7 @@ def pre_configure_ddp(self): # Many models require setting this parameter to True, as there are corner cases # when not all parameter backward hooks are fired by the autograd engine even if require_grad is set to True. # This flag does come with a performance hit, so it is suggested to disable in cases where it is possible. - self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get( - "find_unused_parameters", True - ) + self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True) # todo: PyTorch 1.7.0 DDP introduces ``self.reducer._rebuild_buckets()`` breaking manual_optimization if _TORCH_GREATER_EQUAL_1_7 and not self.lightning_module.automatic_optimization and not self._ddp_kwargs.get( "find_unused_parameters", False diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 3636b2fb92fa2..d362912fa0185 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -172,9 +172,7 @@ def pre_configure_ddp(self): # Many models require setting this parameter to True, as there are corner cases # when not all parameter backward hooks are fired by the autograd engine even if require_grad is set to True. # This flag does come with a performance hit, so it is suggested to disable in cases where it is possible. - self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get( - "find_unused_parameters", True - ) + self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True) # todo: PyTorch 1.7.0 DDP introduces ``self.reducer._rebuild_buckets()`` breaking manual_optimization if _TORCH_GREATER_EQUAL_1_7 and not self.lightning_module.automatic_optimization and not self._ddp_kwargs.get( "find_unused_parameters", False diff --git a/requirements.txt b/requirements.txt index bdfd6601ba4c2..f196b5e639bf5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ PyYAML>=5.1, !=5.4.* # OmegaConf requirement >=5.1 tqdm>=4.41.0 fsspec[http]>=0.8.1 tensorboard>=2.2.0 +torchmetrics>=0.2.0 diff --git a/requirements/test.txt b/requirements/test.txt index 84ddb2f981b54..099a6fe43b6e6 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,4 +1,4 @@ -coverage>=5.2 +coverage>5.2.0 codecov>=2.1 pytest>=6.0 pytest-cov>2.10 diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index 6962af7249d1b..bd8636ba839f9 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -8,11 +8,13 @@ from tests.helpers.runif import RunIf -@pytest.mark.parametrize("trainer_kwargs", ( - pytest.param({"gpus": 1}, marks=RunIf(min_gpus=1)), - pytest.param({"accelerator": "dp", "gpus": 2}, marks=RunIf(min_gpus=2)), - pytest.param({"accelerator": "ddp_spawn", "gpus": 2}, marks=RunIf(min_gpus=2)), -)) +@pytest.mark.parametrize( + "trainer_kwargs", ( + pytest.param(dict(gpus=1), marks=RunIf(min_gpus=1)), + pytest.param(dict(accelerator="dp", gpus=2), marks=RunIf(min_gpus=2)), + pytest.param(dict(accelerator="ddp_spawn", gpus=2), marks=RunIf(min_gpus=2)), + ) +) def test_evaluate(tmpdir, trainer_kwargs): tutils.set_random_master_port() diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index a48f048160ee5..655e12f046e04 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -259,10 +259,12 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): @RunIf(skip_windows=True, special=True, fairscale=True) -@pytest.mark.parametrize("trainer_kwargs", ( - {'num_processes': 2}, - pytest.param({'gpus': 2}, marks=RunIf(min_gpus=2)) -)) +@pytest.mark.parametrize( + "trainer_kwargs", ( + dict(num_processes=2), + pytest.param(dict(gpus=2), marks=RunIf(min_gpus=2)), + ) +) def test_ddp_sharded_plugin_test_multigpu(tmpdir, trainer_kwargs): """ Test to ensure we can use validate and test without fit diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index 34845c46b45eb..f13448187364c 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -292,7 +292,9 @@ def test_init_optimizers_during_evaluation(tmpdir, fn): """ Test that optimizers is an empty list during evaluation """ + class TestModel(BoringModel): + def configure_optimizers(self): optimizer1 = torch.optim.Adam(self.parameters(), lr=0.1) optimizer2 = torch.optim.Adam(self.parameters(), lr=0.1) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index e4aea38fb7f37..52c51777e2a89 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -137,6 +137,7 @@ def test_multiple_eval_dataloader(tmpdir, ckpt_path): """Verify multiple evaluation dataloaders.""" class MultipleTestDataloaderModel(EvalModelTemplate): + def test_dataloader(self): return [self.dataloader(train=False), self.dataloader(train=False)] diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 5b06879b1f6d1..3375b02c5496b 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -603,7 +603,9 @@ def test_benchmark_option(tmpdir): @pytest.mark.parametrize("save_top_k", (-1, 0, 1, 2)) @pytest.mark.parametrize("fn", ("validate", "test")) def test_tested_checkpoint_path(tmpdir, ckpt_path, save_top_k, fn): + class TestModel(BoringModel): + def validation_step(self, batch, batch_idx): self.log("foo", -batch_idx) return super().validation_step(batch, batch_idx)