From 0ba7d632c25b006cd238464818d198651a3a53c6 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 26 Jun 2020 11:56:21 +0200 Subject: [PATCH 01/33] new base structure --- pytorch_lightning/metrics/converters.py | 28 +++---- pytorch_lightning/metrics/metric.py | 100 +++++++++++++++++------- 2 files changed, 86 insertions(+), 42 deletions(-) diff --git a/pytorch_lightning/metrics/converters.py b/pytorch_lightning/metrics/converters.py index 9803500445618..330d911d1d708 100644 --- a/pytorch_lightning/metrics/converters.py +++ b/pytorch_lightning/metrics/converters.py @@ -63,7 +63,7 @@ def new_func(*args, **kwargs): return decorator_fn -def _convert_to_tensor(data: Any) -> Any: +def convert_to_tensor(data: Any) -> Any: """ Maps all kind of collections and numbers to tensors. @@ -84,7 +84,7 @@ def _convert_to_tensor(data: Any) -> Any: raise TypeError(f"The given type ('{type(data).__name__}') cannot be converted to a tensor!") -def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray: +def convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray: """Convert all tensors and numpy arrays to numpy arrays. Args: @@ -114,7 +114,7 @@ def _numpy_metric_input_conversion(func_to_decorate: Callable) -> Callable: Callable: the decorated function """ return _apply_to_inputs( - apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)(func_to_decorate) + apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)(func_to_decorate) def _tensor_metric_output_conversion(func_to_decorate: Callable) -> Callable: @@ -161,7 +161,7 @@ def _tensor_metric_input_conversion(func_to_decorate: Callable) -> Callable: Callable: the decorated function """ return _apply_to_inputs( - apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_tensor)(func_to_decorate) + apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor)(func_to_decorate) def _tensor_collection_metric_output_conversion(func_to_decorate: Callable) -> Callable: @@ -175,7 +175,7 @@ def _tensor_collection_metric_output_conversion(func_to_decorate: Callable) -> C Callable: the decorated function """ return _apply_to_outputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), - _convert_to_tensor)(func_to_decorate) + convert_to_tensor)(func_to_decorate) def _tensor_metric_conversion(func_to_decorate: Callable) -> Callable: @@ -215,10 +215,10 @@ def _tensor_collection_metric_conversion(func_to_decorate: Callable) -> Callable return _tensor_collection_metric_output_conversion(func_convert_inputs) -def _sync_ddp_if_available(result: Union[torch.Tensor], - group: Optional[Any] = None, - reduce_op: Optional[torch.distributed.ReduceOp] = None, - ) -> torch.Tensor: +def sync_ddp_if_available(result: Union[torch.Tensor], + group: Optional[Any] = None, + reduce_op: Optional[Any] = None, + ) -> torch.Tensor: """ Function to reduce the tensors from several ddp processes to one master process @@ -247,7 +247,7 @@ def _sync_ddp_if_available(result: Union[torch.Tensor], def sync_ddp(group: Optional[Any] = None, - reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable: + reduce_op: Optional[Any] = None) -> Callable: """ This decorator syncs a functions outputs across different processes for DDP. 
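Note: the renames above drop the leading underscore so that the conversion and sync helpers become importable by the hook-based Metric introduced in this patch's metric.py changes below. For reference, a minimal sketch of the behaviour sync_ddp_if_available keeps after the rename — identity without an initialized process group, a summed all-reduce inside one (sync_sketch is an illustrative stand-in, not part of the patch):

    import torch
    import torch.distributed

    def sync_sketch(result: torch.Tensor) -> torch.Tensor:
        # without an initialized process group the tensor passes through untouched
        if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
            return result
        # under DDP every process contributes; the default reduce op is SUM
        torch.distributed.all_reduce(result, op=torch.distributed.ReduceOp.SUM)
        return result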
@@ -262,14 +262,14 @@ def sync_ddp(group: Optional[Any] = None, def decorator_fn(func_to_decorate): return _apply_to_outputs(apply_to_collection, torch.Tensor, - _sync_ddp_if_available, group=group, + sync_ddp_if_available, group=group, reduce_op=reduce_op)(func_to_decorate) return decorator_fn def numpy_metric(group: Optional[Any] = None, - reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable: + reduce_op: Optional[Any] = None) -> Callable: """ This decorator shall be used on all function metrics working on numpy arrays. It handles the argument conversion and DDP reduction for metrics working on numpy. @@ -292,7 +292,7 @@ def decorator_fn(func_to_decorate): def tensor_metric(group: Optional[Any] = None, - reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable: + reduce_op: Optional[Any] = None) -> Callable: """ This decorator shall be used on all function metrics working on tensors. It handles the argument conversion and DDP reduction for metrics working on tensors. @@ -314,7 +314,7 @@ def decorator_fn(func_to_decorate): def tensor_collection_metric(group: Optional[Any] = None, - reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable: + reduce_op: Optional[Any] = None) -> Callable: """ This decorator shall be used on all function metrics working on tensors and returning collections that cannot be converted to tensors. diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 349a6ecfa2f82..074e93274d922 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -5,7 +5,8 @@ import torch.distributed from pytorch_lightning.metrics.converters import ( - tensor_metric, numpy_metric, tensor_collection_metric) + tensor_metric, numpy_metric, tensor_collection_metric, + sync_ddp_if_available, convert_to_tensor, convert_to_numpy) from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin @@ -29,19 +30,37 @@ def __init__(self, name: str): self.name = name self._dtype = torch.get_default_dtype() self._device = torch.device('cpu') - + self.register_forward_pre_hook(self.input_convert) + self.register_forward_hook(self.ddp_sync) + self.register_forward_hook(self.compute) + self.register_forward_hook(self.output_convert) + @abstractmethod def forward(self, *args, **kwargs) -> torch.Tensor: """ - Implements the actual metric computation. + Implements the actual metric computation. Returns: - metric value + metric value or metric state """ raise NotImplementedError - - + + def compute(self, module, input, output) -> torch.Tensor: + """ + Output contains the + """ + return output + + def ddp_sync(self, module, input, output): + return output + + def input_convert(self, module, input): + return input + + def output_convert(self, module, input, output): + return output + class TensorMetric(Metric): """ Base class for metric implementation operating directly on tensors. @@ -62,15 +81,20 @@ def __init__(self, name: str, Defaults to sum. 
""" super().__init__(name) - self._orig_call = tensor_metric(group=reduce_group, - reduce_op=reduce_op)(super().__call__) - - def __call__(self, *args, **kwargs) -> torch.Tensor: - def _to_device_dtype(x: torch.Tensor) -> torch.Tensor: - return x.to(device=self.device, dtype=self.dtype, non_blocking=True) - - return apply_to_collection(self._orig_call(*args, **kwargs), torch.Tensor, - _to_device_dtype) + self.reduce_group = reduce_group + self.reduce_op = reduce_op + + def input_convert(self, module, input): + return apply_to_collection(input, + (torch.Tensor, np.ndarray, numbers.Number), + convert_to_tensor) + + def ddp_sync(self, module, input, output): + return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, + self.reduce_group, self.reduce_op) + + def output_convert(self, module, input, output): + return apply_to_collection(output, torch.Tensor, convert_to_tensor) class TensorCollectionMetric(Metric): @@ -103,16 +127,23 @@ def __init__(self, name: str, Defaults to sum. """ super().__init__(name) - self._orig_call = tensor_collection_metric(group=reduce_group, - reduce_op=reduce_op)(super().__call__) + self.reduce_group = reduce_group + self.reduce_op = reduce_op - def __call__(self, *args, **kwargs) -> torch.Tensor: - def _to_device_dtype(x: torch.Tensor) -> torch.Tensor: - return x.to(device=self.device, dtype=self.dtype, non_blocking=True) + def input_convert(self, module, input): + return apply_to_collection(input, + (torch.Tensor, np.ndarray, numbers.Number), + convert_to_tensor) - return apply_to_collection(self._orig_call(*args, **kwargs), torch.Tensor, - _to_device_dtype) + def ddp_sync(self, module, input, output): + return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, + self.reduce_group, self.reduce_op) + + def output_convert(self, module, input, output): + return apply_to_collection(output, + (torch.Tensor, np.ndarray, numbers.Number), + convert_to_tensor) class NumpyMetric(Metric): """ @@ -138,9 +169,22 @@ def __init__(self, name: str, self._orig_call = numpy_metric(group=reduce_group, reduce_op=reduce_op)(super().__call__) - def __call__(self, *args, **kwargs) -> torch.Tensor: - def _to_device_dtype(x: torch.Tensor) -> torch.Tensor: - return x.to(device=self.device, dtype=self.dtype, non_blocking=True) - - return apply_to_collection(self._orig_call(*args, **kwargs), torch.Tensor, - _to_device_dtype) + def input_convert(self, module, input): + return apply_to_collection(input, + (torch.Tensor, np.ndarray, numbers.Number), + convert_to_numpy) + + def ddp_sync(self, module, input, output): + # For numpy we need to convert the output of forward before ddp sync + output = apply_to_collection(output, + (torch.Tensor, np.ndarray, numbers.Number), + convert_to_tensor) + return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, + self.reduce_group, self.reduce_op) + + def output_convert(self, module, input, output): + return apply_to_collection(output, + (torch.Tensor, np.ndarray, numbers.Number), + convert_to_tensor) + + \ No newline at end of file From 6481b6b43f1f59a3c74a2d4d54fba7ecff1682b3 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 26 Jun 2020 12:02:03 +0200 Subject: [PATCH 02/33] missing packages --- pytorch_lightning/metrics/metric.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 074e93274d922..c6cf1b851333a 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -1,6 +1,8 @@ from abc import 
ABC, abstractmethod +import numbers from typing import Any, Optional +import numpy as np import torch import torch.distributed From f6a0a4d7bd5c04365845e4b2d2948d24fc800d7e Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 30 Jun 2020 09:55:24 +0200 Subject: [PATCH 03/33] updated interface --- pytorch_lightning/metrics/classification.py | 10 +-- pytorch_lightning/metrics/converters.py | 16 ++-- .../metrics/functional/classification.py | 2 +- pytorch_lightning/metrics/metric.py | 89 +++++-------------- tests/metrics/test_classification.py | 1 - tests/metrics/test_converters.py | 14 +-- tests/metrics/test_metrics.py | 60 +------------ tests/metrics/test_sklearn.py | 8 +- 8 files changed, 50 insertions(+), 150 deletions(-) diff --git a/pytorch_lightning/metrics/classification.py b/pytorch_lightning/metrics/classification.py index aad3c60183400..3b7ffd8504612 100644 --- a/pytorch_lightning/metrics/classification.py +++ b/pytorch_lightning/metrics/classification.py @@ -18,7 +18,7 @@ dice_score, iou, ) -from pytorch_lightning.metrics.metric import TensorMetric, TensorCollectionMetric +from pytorch_lightning.metrics.metric import TensorMetric class Accuracy(TensorMetric): @@ -123,7 +123,7 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: normalize=self.normalize) -class PrecisionRecall(TensorCollectionMetric): +class PrecisionRecall(TensorMetric): """ Computes the precision recall curve @@ -515,7 +515,7 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: reduction=self.reduction) -class ROC(TensorCollectionMetric): +class ROC(TensorMetric): """ Computes the Receiver Operator Characteristic (ROC) @@ -576,7 +576,7 @@ def forward( pos_label=self.pos_label) -class MulticlassROC(TensorCollectionMetric): +class MulticlassROC(TensorMetric): """ Computes the multiclass ROC @@ -642,7 +642,7 @@ def forward( num_classes=self.num_classes) -class MulticlassPrecisionRecall(TensorCollectionMetric): +class MulticlassPrecisionRecall(TensorMetric): """Computes the multiclass PR Curve Example: diff --git a/pytorch_lightning/metrics/converters.py b/pytorch_lightning/metrics/converters.py index 330d911d1d708..e8d68f6d7d531 100644 --- a/pytorch_lightning/metrics/converters.py +++ b/pytorch_lightning/metrics/converters.py @@ -63,7 +63,7 @@ def new_func(*args, **kwargs): return decorator_fn -def convert_to_tensor(data: Any) -> Any: +def convert_to_tensor(data: Any, dtype=None, device=None) -> Any: """ Maps all kind of collections and numbers to tensors. 
@@ -74,12 +74,12 @@ def convert_to_tensor(data: Any) -> Any: the converted data """ if isinstance(data, numbers.Number): - return torch.tensor([data]) + return torch.tensor([data], dtype=dtype, device=device) # is not array of object elif isinstance(data, np.ndarray) and np_str_obj_array_pattern.search(data.dtype.str) is None: - return torch.from_numpy(data) + return torch.from_numpy(data).to(device=device, dtype=dtype) elif isinstance(data, torch.Tensor): - return data + return data.to(device=device, dtype=dtype) raise TypeError(f"The given type ('{type(data).__name__}') cannot be converted to a tensor!") @@ -127,7 +127,7 @@ def _tensor_metric_output_conversion(func_to_decorate: Callable) -> Callable: Return: Callable: the decorated function """ - return _apply_to_outputs(_convert_to_tensor)(func_to_decorate) + return _apply_to_outputs(convert_to_tensor)(func_to_decorate) def _numpy_metric_conversion(func_to_decorate: Callable) -> Callable: @@ -218,6 +218,7 @@ def _tensor_collection_metric_conversion(func_to_decorate: Callable) -> Callable def sync_ddp_if_available(result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Any] = None, + ddp_normalize = False, ) -> torch.Tensor: """ Function to reduce the tensors from several ddp processes to one master process @@ -242,7 +243,10 @@ def sync_ddp_if_available(result: Union[torch.Tensor], torch.distributed.barrier(group=group) torch.distributed.all_reduce(result, op=reduce_op, group=group, async_op=False) - + + if ddp_normalize: + result / torch.distributed.get_world_size(group) + return result diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index 8b392e5a117e3..7580db8b5a8ea 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -256,7 +256,7 @@ def confusion_matrix( """ num_classes = get_num_classes(pred, target, None) - unique_labels = target.view(-1) * num_classes + pred.view(-1) + unique_labels = (target.view(-1) * num_classes + pred.view(-1)).to(torch.int) bins = torch.bincount(unique_labels, minlength=num_classes ** 2) cm = bins.reshape(num_classes, num_classes).squeeze().float() diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index c6cf1b851333a..a58dad6e1be8a 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -7,7 +7,6 @@ import torch.distributed from pytorch_lightning.metrics.converters import ( - tensor_metric, numpy_metric, tensor_collection_metric, sync_ddp_if_available, convert_to_tensor, convert_to_numpy) from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin @@ -33,9 +32,9 @@ def __init__(self, name: str): self._dtype = torch.get_default_dtype() self._device = torch.device('cpu') self.register_forward_pre_hook(self.input_convert) + self.register_forward_hook(self.output_convert) self.register_forward_hook(self.ddp_sync) self.register_forward_hook(self.compute) - self.register_forward_hook(self.output_convert) @abstractmethod def forward(self, *args, **kwargs) -> torch.Tensor: @@ -72,7 +71,8 @@ class TensorMetric(Metric): def __init__(self, name: str, reduce_group: Optional[Any] = None, - reduce_op: Optional[Any] = None): + reduce_op: Optional[Any] = None, + ddp_normalize: bool = False): """ Args: @@ -85,67 +85,22 @@ def __init__(self, name: str, super().__init__(name) 
self.reduce_group = reduce_group self.reduce_op = reduce_op + self.ddp_normalize = ddp_normalize def input_convert(self, module, input): return apply_to_collection(input, (torch.Tensor, np.ndarray, numbers.Number), - convert_to_tensor) - - def ddp_sync(self, module, input, output): - return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, - self.reduce_group, self.reduce_op) - - def output_convert(self, module, input, output): - return apply_to_collection(output, torch.Tensor, convert_to_tensor) - - -class TensorCollectionMetric(Metric): - """ - Base class for metric implementation operating directly on tensors. - All inputs will be casted to tensors if necessary. Outputs won't be casted. - Already handles DDP sync and input conversions. - - This class differs from :class:`TensorMetric`, as it assumes all outputs to - be collections of tensors and does not explicitly convert them. This is - necessary, since some collections (like for ROC, Precision-Recall Curve etc.) - cannot be converted to tensors at the highest level. - All numpy arrays and numbers occuring in these outputs will still be converted. - - Use this class as a baseclass, whenever you want to ensure inputs are - tensors and outputs cannot be converted to tensors automatically - - """ - - def __init__(self, name: str, - reduce_group: Optional[Any] = None, - reduce_op: Optional[Any] = None): - """ - - Args: - name: the metric's name - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. - """ - super().__init__(name) - self.reduce_group = reduce_group - self.reduce_op = reduce_op + convert_to_tensor, self.dtype, self.device) - def input_convert(self, module, input): - return apply_to_collection(input, + def output_convert(self, module, input, output): + return apply_to_collection(output, (torch.Tensor, np.ndarray, numbers.Number), - convert_to_tensor) - + convert_to_tensor, self.dtype, self.device) + def ddp_sync(self, module, input, output): return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, - self.reduce_group, self.reduce_op) + self.reduce_group, self.reduce_op, self.ddp_normalize) - - def output_convert(self, module, input, output): - return apply_to_collection(output, - (torch.Tensor, np.ndarray, numbers.Number), - convert_to_tensor) class NumpyMetric(Metric): """ @@ -157,7 +112,8 @@ class NumpyMetric(Metric): def __init__(self, name: str, reduce_group: Optional[Any] = None, - reduce_op: Optional[Any] = None): + reduce_op: Optional[Any] = None, + ddp_normalize: bool = False): """ Args: @@ -168,25 +124,24 @@ def __init__(self, name: str, Defaults to sum. 
""" super().__init__(name) - self._orig_call = numpy_metric(group=reduce_group, - reduce_op=reduce_op)(super().__call__) + self.reduce_group = reduce_group + self.reduce_op = reduce_op + self.ddp_normalize = ddp_normalize def input_convert(self, module, input): return apply_to_collection(input, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy) + + def output_convert(self, module, input, output): + return apply_to_collection(output, + (torch.Tensor, np.ndarray, numbers.Number), + convert_to_tensor, self.dtype, self.device) + def ddp_sync(self, module, input, output): - # For numpy we need to convert the output of forward before ddp sync - output = apply_to_collection(output, - (torch.Tensor, np.ndarray, numbers.Number), - convert_to_tensor) return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, - self.reduce_group, self.reduce_op) + self.reduce_group, self.reduce_op, self.ddp_normalize) - def output_convert(self, module, input, output): - return apply_to_collection(output, - (torch.Tensor, np.ndarray, numbers.Number), - convert_to_tensor) \ No newline at end of file diff --git a/tests/metrics/test_classification.py b/tests/metrics/test_classification.py index 52fffd4ad72d1..53ec95c907092 100644 --- a/tests/metrics/test_classification.py +++ b/tests/metrics/test_classification.py @@ -45,7 +45,6 @@ def test_confusion_matrix(normalize): target = (torch.arange(120) % 3).view(-1, 1) pred = target.clone() - cm = conf_matrix(pred, target) assert isinstance(cm, torch.Tensor) diff --git a/tests/metrics/test_converters.py b/tests/metrics/test_converters.py index b6c102b1fd83b..801dc0feac125 100644 --- a/tests/metrics/test_converters.py +++ b/tests/metrics/test_converters.py @@ -8,11 +8,11 @@ from pytorch_lightning.metrics.converters import ( _apply_to_inputs, _apply_to_outputs, - _convert_to_tensor, - _convert_to_numpy, + convert_to_tensor, + convert_to_numpy, _numpy_metric_conversion, _tensor_metric_conversion, - _sync_ddp_if_available, + sync_ddp_if_available, tensor_metric, numpy_metric ) @@ -61,14 +61,14 @@ def test_fn(*args, **kwargs): def test_convert_to_tensor(): for test_item in [1., np.array([1.])]: - result_tensor = _convert_to_tensor(test_item) + result_tensor = convert_to_tensor(test_item) assert isinstance(result_tensor, torch.Tensor) assert result_tensor.item() == 1. def test_convert_to_numpy(): for test_item in [1., torch.tensor([1.])]: - result = _convert_to_numpy(test_item) + result = convert_to_numpy(test_item) assert isinstance(result, np.ndarray) assert result.item() == 1. 
@@ -118,7 +118,7 @@ def _ddp_test_fn(rank, worldsize): _setup_ddp(rank, worldsize) tensor = torch.tensor([1.], device='cuda:0') - reduced_tensor = _sync_ddp_if_available(tensor) + reduced_tensor = sync_ddp_if_available(tensor) assert reduced_tensor.item() == dist.get_world_size(), \ 'Sync-Reduce does not work properly with DDP and Tensors' @@ -141,7 +141,7 @@ def test_sync_reduce_simple(): """Make sure sync-reduce works without DDP""" tensor = torch.tensor([1.], device='cpu') - reduced_tensor = _sync_ddp_if_available(tensor) + reduced_tensor = sync_ddp_if_available(tensor) assert torch.allclose(tensor, reduced_tensor), \ 'Sync-Reduce does not work properly without DDP and Tensors' diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 2176c94dfb925..ada9e8cbd2c68 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -2,7 +2,7 @@ import pytest import torch -from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric, TensorCollectionMetric +from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric class DummyTensorMetric(TensorMetric): @@ -25,64 +25,6 @@ def forward(self, input1, input2): return 1. -class DummyTensorCollectionMetric(TensorCollectionMetric): - def __init__(self): - super().__init__('dummy') - - def forward(self, input1, input2): - assert isinstance(input1, torch.Tensor) - assert isinstance(input2, torch.Tensor) - return 1., 2., 3., 4. - - -@pytest.mark.parametrize('metric', [DummyTensorCollectionMetric()]) -def test_collection_metric(metric: Metric): - """ Test that metric.device, metric.dtype works for metric collection """ - input1, input2 = torch.tensor([1.]), torch.tensor([2.]) - - def change_and_check_device_dtype(device, dtype): - metric.to(device=device, dtype=dtype) - - metric_val = metric(input1, input2) - assert not isinstance(metric_val, torch.Tensor) - - if device is not None: - assert metric.device in [device, torch.device(device)] - - if dtype is not None: - assert metric.dtype == dtype - - devices = [None, 'cpu'] - if torch.cuda.is_available(): - devices += ['cuda:0'] - - for device in devices: - for dtype in [None, torch.float32, torch.float64]: - change_and_check_device_dtype(device=device, dtype=dtype) - - if torch.cuda.is_available(): - metric.cuda(0) - assert metric.device == torch.device('cuda', index=0) - - metric.cpu() - assert metric.device == torch.device('cpu') - - metric.type(torch.int8) - assert metric.dtype == torch.int8 - - metric.float() - assert metric.dtype == torch.float32 - - metric.double() - assert metric.dtype == torch.float64 - assert all(out.dtype == torch.float64 for out in metric(input1, input2)) - - if torch.cuda.is_available(): - metric.cuda() - metric.half() - assert metric.dtype == torch.float16 - - @pytest.mark.parametrize('metric', [ DummyTensorMetric(), DummyNumpyMetric(), diff --git a/tests/metrics/test_sklearn.py b/tests/metrics/test_sklearn.py index 335d4dc767b42..7243ca35d6311 100644 --- a/tests/metrics/test_sklearn.py +++ b/tests/metrics/test_sklearn.py @@ -19,7 +19,7 @@ roc_auc_score as sk_roc_auc_score, ) -from pytorch_lightning.metrics.converters import _convert_to_numpy +from pytorch_lightning.metrics.converters import convert_to_numpy from pytorch_lightning.metrics.sklearns import ( Accuracy, AveragePrecision, @@ -79,17 +79,17 @@ def new_func(*args, **kwargs): id='AUROC'), ]) def test_sklearn_metric(metric_class, sklearn_func, inputs): - numpy_inputs = apply_to_collection(inputs, (torch.Tensor, np.ndarray, numbers.Number), 
_convert_to_numpy) + numpy_inputs = apply_to_collection(inputs, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy) sklearn_result = sklearn_func(**numpy_inputs) lightning_result = metric_class(**inputs) assert np.allclose(sklearn_result, lightning_result, atol=1e-5) sklearn_result = apply_to_collection( - sklearn_result, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy) + sklearn_result, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy) lightning_result = apply_to_collection( - lightning_result, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy) + lightning_result, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy) assert np.allclose(sklearn_result, lightning_result, atol=1e-5) assert isinstance(lightning_result, type(sklearn_result)) From 9368ac66e8ce69a9f058f31e3144b055a5589eee Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 6 Jul 2020 12:40:22 +0200 Subject: [PATCH 04/33] revert some changes --- pytorch_lightning/metrics/classification.py | 10 +-- pytorch_lightning/metrics/metric.py | 91 +++++++++++++++------ tests/metrics/test_metrics.py | 60 +++++++++++++- 3 files changed, 131 insertions(+), 30 deletions(-) diff --git a/pytorch_lightning/metrics/classification.py b/pytorch_lightning/metrics/classification.py index 3b7ffd8504612..aad3c60183400 100644 --- a/pytorch_lightning/metrics/classification.py +++ b/pytorch_lightning/metrics/classification.py @@ -18,7 +18,7 @@ dice_score, iou, ) -from pytorch_lightning.metrics.metric import TensorMetric +from pytorch_lightning.metrics.metric import TensorMetric, TensorCollectionMetric class Accuracy(TensorMetric): @@ -123,7 +123,7 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: normalize=self.normalize) -class PrecisionRecall(TensorMetric): +class PrecisionRecall(TensorCollectionMetric): """ Computes the precision recall curve @@ -515,7 +515,7 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: reduction=self.reduction) -class ROC(TensorMetric): +class ROC(TensorCollectionMetric): """ Computes the Receiver Operator Characteristic (ROC) @@ -576,7 +576,7 @@ def forward( pos_label=self.pos_label) -class MulticlassROC(TensorMetric): +class MulticlassROC(TensorCollectionMetric): """ Computes the multiclass ROC @@ -642,7 +642,7 @@ def forward( num_classes=self.num_classes) -class MulticlassPrecisionRecall(TensorMetric): +class MulticlassPrecisionRecall(TensorCollectionMetric): """Computes the multiclass PR Curve Example: diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index a58dad6e1be8a..074e93274d922 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -1,12 +1,11 @@ from abc import ABC, abstractmethod -import numbers from typing import Any, Optional -import numpy as np import torch import torch.distributed from pytorch_lightning.metrics.converters import ( + tensor_metric, numpy_metric, tensor_collection_metric, sync_ddp_if_available, convert_to_tensor, convert_to_numpy) from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin @@ -32,9 +31,9 @@ def __init__(self, name: str): self._dtype = torch.get_default_dtype() self._device = torch.device('cpu') self.register_forward_pre_hook(self.input_convert) - self.register_forward_hook(self.output_convert) self.register_forward_hook(self.ddp_sync) self.register_forward_hook(self.compute) + 
self.register_forward_hook(self.output_convert) @abstractmethod def forward(self, *args, **kwargs) -> torch.Tensor: @@ -71,8 +70,7 @@ class TensorMetric(Metric): def __init__(self, name: str, reduce_group: Optional[Any] = None, - reduce_op: Optional[Any] = None, - ddp_normalize: bool = False): + reduce_op: Optional[Any] = None): """ Args: @@ -85,35 +83,40 @@ def __init__(self, name: str, super().__init__(name) self.reduce_group = reduce_group self.reduce_op = reduce_op - self.ddp_normalize = ddp_normalize def input_convert(self, module, input): return apply_to_collection(input, (torch.Tensor, np.ndarray, numbers.Number), - convert_to_tensor, self.dtype, self.device) - - def output_convert(self, module, input, output): - return apply_to_collection(output, - (torch.Tensor, np.ndarray, numbers.Number), - convert_to_tensor, self.dtype, self.device) + convert_to_tensor) def ddp_sync(self, module, input, output): return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, - self.reduce_group, self.reduce_op, self.ddp_normalize) + self.reduce_group, self.reduce_op) + + def output_convert(self, module, input, output): + return apply_to_collection(output, torch.Tensor, convert_to_tensor) -class NumpyMetric(Metric): +class TensorCollectionMetric(Metric): """ - Base class for metric implementation operating on numpy arrays. - All inputs will be casted to numpy if necessary and all outputs will - be casted to tensors if necessary. - Already handles DDP sync and input/output conversions. + Base class for metric implementation operating directly on tensors. + All inputs will be casted to tensors if necessary. Outputs won't be casted. + Already handles DDP sync and input conversions. + + This class differs from :class:`TensorMetric`, as it assumes all outputs to + be collections of tensors and does not explicitly convert them. This is + necessary, since some collections (like for ROC, Precision-Recall Curve etc.) + cannot be converted to tensors at the highest level. + All numpy arrays and numbers occuring in these outputs will still be converted. + + Use this class as a baseclass, whenever you want to ensure inputs are + tensors and outputs cannot be converted to tensors automatically + """ def __init__(self, name: str, reduce_group: Optional[Any] = None, - reduce_op: Optional[Any] = None, - ddp_normalize: bool = False): + reduce_op: Optional[Any] = None): """ Args: @@ -126,22 +129,62 @@ def __init__(self, name: str, super().__init__(name) self.reduce_group = reduce_group self.reduce_op = reduce_op - self.ddp_normalize = ddp_normalize def input_convert(self, module, input): return apply_to_collection(input, (torch.Tensor, np.ndarray, numbers.Number), - convert_to_numpy) + convert_to_tensor) + + def ddp_sync(self, module, input, output): + return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, + self.reduce_group, self.reduce_op) + def output_convert(self, module, input, output): return apply_to_collection(output, (torch.Tensor, np.ndarray, numbers.Number), - convert_to_tensor, self.dtype, self.device) + convert_to_tensor) + +class NumpyMetric(Metric): + """ + Base class for metric implementation operating on numpy arrays. + All inputs will be casted to numpy if necessary and all outputs will + be casted to tensors if necessary. + Already handles DDP sync and input/output conversions. 
+ """ + def __init__(self, name: str, + reduce_group: Optional[Any] = None, + reduce_op: Optional[Any] = None): + """ + + Args: + name: the metric's name + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__(name) + self._orig_call = numpy_metric(group=reduce_group, + reduce_op=reduce_op)(super().__call__) + + def input_convert(self, module, input): + return apply_to_collection(input, + (torch.Tensor, np.ndarray, numbers.Number), + convert_to_numpy) def ddp_sync(self, module, input, output): + # For numpy we need to convert the output of forward before ddp sync + output = apply_to_collection(output, + (torch.Tensor, np.ndarray, numbers.Number), + convert_to_tensor) return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, - self.reduce_group, self.reduce_op, self.ddp_normalize) + self.reduce_group, self.reduce_op) + def output_convert(self, module, input, output): + return apply_to_collection(output, + (torch.Tensor, np.ndarray, numbers.Number), + convert_to_tensor) \ No newline at end of file diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index ada9e8cbd2c68..2176c94dfb925 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -2,7 +2,7 @@ import pytest import torch -from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric +from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric, TensorCollectionMetric class DummyTensorMetric(TensorMetric): @@ -25,6 +25,64 @@ def forward(self, input1, input2): return 1. +class DummyTensorCollectionMetric(TensorCollectionMetric): + def __init__(self): + super().__init__('dummy') + + def forward(self, input1, input2): + assert isinstance(input1, torch.Tensor) + assert isinstance(input2, torch.Tensor) + return 1., 2., 3., 4. 
+ + +@pytest.mark.parametrize('metric', [DummyTensorCollectionMetric()]) +def test_collection_metric(metric: Metric): + """ Test that metric.device, metric.dtype works for metric collection """ + input1, input2 = torch.tensor([1.]), torch.tensor([2.]) + + def change_and_check_device_dtype(device, dtype): + metric.to(device=device, dtype=dtype) + + metric_val = metric(input1, input2) + assert not isinstance(metric_val, torch.Tensor) + + if device is not None: + assert metric.device in [device, torch.device(device)] + + if dtype is not None: + assert metric.dtype == dtype + + devices = [None, 'cpu'] + if torch.cuda.is_available(): + devices += ['cuda:0'] + + for device in devices: + for dtype in [None, torch.float32, torch.float64]: + change_and_check_device_dtype(device=device, dtype=dtype) + + if torch.cuda.is_available(): + metric.cuda(0) + assert metric.device == torch.device('cuda', index=0) + + metric.cpu() + assert metric.device == torch.device('cpu') + + metric.type(torch.int8) + assert metric.dtype == torch.int8 + + metric.float() + assert metric.dtype == torch.float32 + + metric.double() + assert metric.dtype == torch.float64 + assert all(out.dtype == torch.float64 for out in metric(input1, input2)) + + if torch.cuda.is_available(): + metric.cuda() + metric.half() + assert metric.dtype == torch.float16 + + @pytest.mark.parametrize('metric', [ DummyTensorMetric(), DummyNumpyMetric(), From 5d7652853b2b9ba0c052458b9e64bb5b41b585e3 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 6 Jul 2020 13:09:05 +0200 Subject: [PATCH 05/33] fixes --- pytorch_lightning/metrics/converters.py | 10 +- pytorch_lightning/metrics/metric.py | 127 +++++++++++++++--------- tests/metrics/test_metrics.py | 2 +- 3 files changed, 87 insertions(+), 52 deletions(-) diff --git a/pytorch_lightning/metrics/converters.py b/pytorch_lightning/metrics/converters.py index e8d68f6d7d531..4713151f71f8f 100644 --- a/pytorch_lightning/metrics/converters.py +++ b/pytorch_lightning/metrics/converters.py @@ -70,6 +70,10 @@ def convert_to_tensor(data: Any, dtype=None, device=None) -> Any: Args: data: the data to convert to tensor + dtype: data type to convert to + + device: device to cast to + Return: the converted data """ @@ -218,7 +222,7 @@ def _tensor_collection_metric_conversion(func_to_decorate: Callable) -> Callable def sync_ddp_if_available(result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Any] = None, - ddp_normalize = False, + ddp_normalize=False, ) -> torch.Tensor: """ Function to reduce the tensors from several ddp processes to one master process @@ -243,10 +247,10 @@ def sync_ddp_if_available(result: Union[torch.Tensor], torch.distributed.barrier(group=group) torch.distributed.all_reduce(result, op=reduce_op, group=group, async_op=False) - + if ddp_normalize: result / torch.distributed.get_world_size(group) - + return result diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 074e93274d922..922c1a6c4fd94 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -1,11 +1,12 @@ from abc import ABC, abstractmethod from typing import Any, Optional +import numbers import torch import torch.distributed +import numpy as np from pytorch_lightning.metrics.converters import ( - tensor_metric, numpy_metric, tensor_collection_metric, sync_ddp_if_available, convert_to_tensor, convert_to_numpy) from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.device_dtype_mixin import 
DeviceDtypeModuleMixin @@ -18,6 +19,16 @@ class Metric(DeviceDtypeModuleMixin, torch.nn.Module, ABC): Should be used to implement metrics that 1. Return multiple Outputs 2. Handle their own DDP sync + + Metric hooks that can be implemented are: + input_convert: pre-forward hook that takes care of input conversion + output_convert: post-forward hook that takes care of output convertion + ddp_sync: implementation of ddp sync + compute: post-ddp sync for additional metric computations + + Call order: + input_convert -> forward -> output_convert -> ddp_sync -> compute + """ def __init__(self, name: str): @@ -31,36 +42,51 @@ def __init__(self, name: str): self._dtype = torch.get_default_dtype() self._device = torch.device('cpu') self.register_forward_pre_hook(self.input_convert) + self.register_forward_hook(self.output_convert) self.register_forward_hook(self.ddp_sync) self.register_forward_hook(self.compute) - self.register_forward_hook(self.output_convert) - + @abstractmethod def forward(self, *args, **kwargs) -> torch.Tensor: """ - Implements the actual metric computation. + Implements the actual metric computation. Returns: metric value or metric state """ raise NotImplementedError - + def compute(self, module, input, output) -> torch.Tensor: """ - Output contains the + Implement additionally metric computations to be done after the ddp sync + + Args: + module: current metric module + + input: input to forward method + + output: output from forward method + + Returns: + final metric value + """ return output - + def ddp_sync(self, module, input, output): + """ + + """ return output - + def input_convert(self, module, input): return input - + def output_convert(self, module, input, output): return output - + + class TensorMetric(Metric): """ Base class for metric implementation operating directly on tensors. 
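Note: the Metric docstring above documents the hook call order (input_convert -> forward -> output_convert -> ddp_sync -> compute), so a subclass only implements forward and inherits the casting and DDP sync behaviour. A minimal hypothetical example on top of TensorMetric (RMSE is illustrative, not part of this patch):

    import torch
    from pytorch_lightning.metrics.metric import TensorMetric

    class RMSE(TensorMetric):
        def __init__(self):
            super().__init__(name='rmse')

        def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
            # inputs already arrive as tensors via the input_convert pre-hook;
            # the result is cast and ddp-synced by the registered forward hooks
            return torch.sqrt(torch.mean((pred - target) ** 2))

    metric = RMSE()
    metric(torch.tensor([1., 2.]), torch.tensor([1., 0.]))  # -> 0-dim tensor on metric.device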
@@ -70,7 +96,8 @@ class TensorMetric(Metric): def __init__(self, name: str, reduce_group: Optional[Any] = None, - reduce_op: Optional[Any] = None): + reduce_op: Optional[Any] = None, + ddp_normalize: bool = False): """ Args: @@ -83,18 +110,21 @@ def __init__(self, name: str, super().__init__(name) self.reduce_group = reduce_group self.reduce_op = reduce_op - + self.ddp_normalize = ddp_normalize + def input_convert(self, module, input): - return apply_to_collection(input, - (torch.Tensor, np.ndarray, numbers.Number), - convert_to_tensor) - + return apply_to_collection(input, + (torch.Tensor, np.ndarray, numbers.Number), + convert_to_tensor, + self.dtype, self.device) + + def output_convert(self, module, input, output): + return apply_to_collection(output, torch.Tensor, convert_to_tensor, + self.dtype, self.device) + def ddp_sync(self, module, input, output): return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, - self.reduce_group, self.reduce_op) - - def output_convert(self, module, input, output): - return apply_to_collection(output, torch.Tensor, convert_to_tensor) + self.reduce_group, self.reduce_op, self.ddp_normalize) class TensorCollectionMetric(Metric): @@ -116,7 +146,8 @@ class TensorCollectionMetric(Metric): def __init__(self, name: str, reduce_group: Optional[Any] = None, - reduce_op: Optional[Any] = None): + reduce_op: Optional[Any] = None, + ddp_normalize: bool = False): """ Args: @@ -129,21 +160,24 @@ def __init__(self, name: str, super().__init__(name) self.reduce_group = reduce_group self.reduce_op = reduce_op + self.ddp_normalize = ddp_normalize def input_convert(self, module, input): - return apply_to_collection(input, - (torch.Tensor, np.ndarray, numbers.Number), - convert_to_tensor) + return apply_to_collection(input, + (torch.Tensor, np.ndarray, numbers.Number), + convert_to_tensor, + self.dtype, self.device) + + def output_convert(self, module, input, output): + return apply_to_collection(output, + (torch.Tensor, np.ndarray, numbers.Number), + convert_to_tensor, + self.dtype, self.device) def ddp_sync(self, module, input, output): return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, - self.reduce_group, self.reduce_op) + self.reduce_group, self.reduce_op, self.ddp_normalize) - - def output_convert(self, module, input, output): - return apply_to_collection(output, - (torch.Tensor, np.ndarray, numbers.Number), - convert_to_tensor) class NumpyMetric(Metric): """ @@ -155,7 +189,8 @@ class NumpyMetric(Metric): def __init__(self, name: str, reduce_group: Optional[Any] = None, - reduce_op: Optional[Any] = None): + reduce_op: Optional[Any] = None, + ddp_normalize: bool = False): """ Args: @@ -166,25 +201,21 @@ def __init__(self, name: str, Defaults to sum. 
""" super().__init__(name) - self._orig_call = numpy_metric(group=reduce_group, - reduce_op=reduce_op)(super().__call__) + self.reduce_group = reduce_group + self.reduce_op = reduce_op + self.ddp_normalize = ddp_normalize def input_convert(self, module, input): - return apply_to_collection(input, - (torch.Tensor, np.ndarray, numbers.Number), + return apply_to_collection(input, + (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy) - + + def output_convert(self, module, input, output): + return apply_to_collection(output, + (torch.Tensor, np.ndarray, numbers.Number), + convert_to_tensor, + self.dtype, self.device) + def ddp_sync(self, module, input, output): - # For numpy we need to convert the output of forward before ddp sync - output = apply_to_collection(output, - (torch.Tensor, np.ndarray, numbers.Number), - convert_to_tensor) return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, - self.reduce_group, self.reduce_op) - - def output_convert(self, module, input, output): - return apply_to_collection(output, - (torch.Tensor, np.ndarray, numbers.Number), - convert_to_tensor) - - \ No newline at end of file + self.reduce_group, self.reduce_op, self.ddp_normalize) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 2176c94dfb925..d6a82a22af8c9 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -12,7 +12,7 @@ def __init__(self): def forward(self, input1, input2): assert isinstance(input1, torch.Tensor) assert isinstance(input2, torch.Tensor) - return 1. + return torch.tensor([1.]) class DummyNumpyMetric(NumpyMetric): From 4958cb288edba809b7c9c4996fab3eddc2978e08 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 6 Jul 2020 16:46:26 +0200 Subject: [PATCH 06/33] add changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index eb9fdfbc5ba06..4d257ae90c17d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Added +- Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528/)) ### Changed From 0a3849ee26da05ba7b85b9602159f8bf475a8811 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 6 Jul 2020 16:47:01 +0200 Subject: [PATCH 07/33] fix bug --- pytorch_lightning/metrics/converters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/converters.py b/pytorch_lightning/metrics/converters.py index 3a6d6db924fda..cb3189260f468 100644 --- a/pytorch_lightning/metrics/converters.py +++ b/pytorch_lightning/metrics/converters.py @@ -249,7 +249,7 @@ def sync_ddp_if_available(result: Union[torch.Tensor], async_op=False) if ddp_normalize: - result / torch.distributed.get_world_size(group) + result = result / torch.distributed.get_world_size(group) return result From 883efa94221ef0448055d976081e35b1c209a35d Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 6 Jul 2020 16:57:21 +0200 Subject: [PATCH 08/33] added description --- pytorch_lightning/metrics/metric.py | 70 ++++++++++++++++++++++------- 1 file changed, 54 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 922c1a6c4fd94..553ba2941bf34 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -3,7 +3,7 @@ import numbers import torch -import torch.distributed +from torch import nn import numpy as np from pytorch_lightning.metrics.converters import ( @@ -12,7 +12,7 @@ from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin -class Metric(DeviceDtypeModuleMixin, torch.nn.Module, ABC): +class Metric(DeviceDtypeModuleMixin, nn.Module, ABC): """ Abstract base class for metric implementation. @@ -47,7 +47,7 @@ def __init__(self, name: str): self.register_forward_hook(self.compute) @abstractmethod - def forward(self, *args, **kwargs) -> torch.Tensor: + def forward(self, *args, **kwargs): """ Implements the actual metric computation. 
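Note: the normalization fixed in patch 07 above (the division result is now actually assigned) implements a mean over processes on top of a SUM reduction. A small simulation of that arithmetic without a process group (the two per-process values are hypothetical):

    import torch

    world_size = 2                                         # hypothetical DDP world
    per_process = [torch.tensor(0.9), torch.tensor(0.7)]   # e.g. per-rank accuracy
    # all_reduce with the default SUM op leaves the same sum on every rank
    summed = torch.stack(per_process).sum()                # tensor(1.6000)
    # ddp_normalize=True then divides by the world size, yielding the mean
    mean = summed / world_size                             # tensor(0.8000)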
@@ -57,7 +57,7 @@ def forward(self, *args, **kwargs) -> torch.Tensor: """ raise NotImplementedError - def compute(self, module, input, output) -> torch.Tensor: + def compute(self, module: nn.Module, input: Any, output: Any): """ Implement additionally metric computations to be done after the ddp sync @@ -74,16 +74,51 @@ def compute(self, module, input, output) -> torch.Tensor: """ return output - def ddp_sync(self, module, input, output): + def ddp_sync(self, module: nn.Module, input: Any, output: Any): """ + Implement how the outputs from forward should be synced + + Args: + module: current metric module + + input: input to forward method + + output: output from forward method + + Returns: + synced output """ return output - def input_convert(self, module, input): + def input_convert(self, module: nn.Module, input: Any): + """ + Implement how the inputs should be casted before calling forward + + Args: + module: current metric module + + input: input to forward method + + Returns: + casted input + """ return input - def output_convert(self, module, input, output): + def output_convert(self, module: nn.Module, input: Any, output: Any): + """ + Implement how outputs from forward should be casted + + Args: + module: current metric module + + input: input to forward method + + output: output from forward method + + Returns: + casted outputs + """ return output @@ -106,23 +141,24 @@ def __init__(self, name: str, Defaults to all processes (world) reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). Defaults to sum. + ddp_normalize: if true, will divide the DDP reduce result by the world rank """ super().__init__(name) self.reduce_group = reduce_group self.reduce_op = reduce_op self.ddp_normalize = ddp_normalize - def input_convert(self, module, input): + def input_convert(self, module: nn.Module, input: Any): return apply_to_collection(input, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor, self.dtype, self.device) - def output_convert(self, module, input, output): + def output_convert(self, module: nn.Module, input: Any, output: Any): return apply_to_collection(output, torch.Tensor, convert_to_tensor, self.dtype, self.device) - def ddp_sync(self, module, input, output): + def ddp_sync(self, module: nn.Module, input: Any, output: Any): return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, self.reduce_group, self.reduce_op, self.ddp_normalize) @@ -156,25 +192,26 @@ def __init__(self, name: str, Defaults to all processes (world) reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). Defaults to sum. 
+ ddp_normalize: if true, will divide the DDP reduce result by the world rank """ super().__init__(name) self.reduce_group = reduce_group self.reduce_op = reduce_op self.ddp_normalize = ddp_normalize - def input_convert(self, module, input): + def input_convert(self, module: nn.Module, input: Any): return apply_to_collection(input, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor, self.dtype, self.device) - def output_convert(self, module, input, output): + def output_convert(self, module: nn.Module, input: Any, output: Any): return apply_to_collection(output, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor, self.dtype, self.device) - def ddp_sync(self, module, input, output): + def ddp_sync(self, module: nn.Module, input: Any, output: Any): return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, self.reduce_group, self.reduce_op, self.ddp_normalize) @@ -199,23 +236,24 @@ def __init__(self, name: str, Defaults to all processes (world) reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). Defaults to sum. + ddp_normalize: if true, will divide the DDP reduce result by the world rank """ super().__init__(name) self.reduce_group = reduce_group self.reduce_op = reduce_op self.ddp_normalize = ddp_normalize - def input_convert(self, module, input): + def input_convert(self, module: nn.Module, input: Any): return apply_to_collection(input, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy) - def output_convert(self, module, input, output): + def output_convert(self, module: nn.Module, input: Any, output: Any): return apply_to_collection(output, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor, self.dtype, self.device) - def ddp_sync(self, module, input, output): + def ddp_sync(self, module: nn.Module, input: Any, output: Any): return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, self.reduce_group, self.reduce_op, self.ddp_normalize) From d99821ec7562084092de3f9343547ea0fbc6096a Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 7 Aug 2020 13:49:35 +0200 Subject: [PATCH 09/33] test for pickable --- pytorch_lightning/core/step_result.py | 4 ++-- tests/metrics/test_metrics.py | 29 +++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 172930fd4ad9a..f4e9767eeb5f0 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -3,7 +3,7 @@ from torch import Tensor import torch from copy import copy -from pytorch_lightning.metrics.converters import _sync_ddp_if_available +from pytorch_lightning.metrics.converters import sync_ddp_if_available class Result(Dict): @@ -101,7 +101,7 @@ def log( # sync across ddp if sync_ddp and isinstance(value, (torch.Tensor, numbers.Number)): - value = _sync_ddp_if_available(value, group=sync_ddp_group, reduce_op=sync_ddp_op) + value = sync_ddp_if_available(value, group=sync_ddp_group, reduce_op=sync_ddp_op) if 'meta' not in self: self.__setitem__('meta', {}) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index d6a82a22af8c9..9dc4d6ebc24f3 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -2,6 +2,9 @@ import pytest import torch +import tests.base.develop_utils as tutils +import tests.base.develop_pipelines as tpipes +from tests.base import EvalModelTemplate from pytorch_lightning.metrics.metric import Metric, TensorMetric, 
NumpyMetric, TensorCollectionMetric @@ -139,3 +142,29 @@ def change_and_check_device_dtype(device, dtype): metric.half() assert metric.dtype == torch.float16 assert metric(input1, input2).dtype == torch.float16 + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.parametrize("metric", [ + DummyTensorMetric(), + DummyNumpyMetric(), +]) +def test_pickable(tmpdir, metric: Metric): + """Make sure that metrics are pickable by including into a model and running in multi-gpu mode""" + tutils.set_random_master_port() + + trainer_options = dict( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=10, + gpus=[0, 1], + distributed_backend='ddp_spawn', + ) + + class ModelwithMetrics(EvalModelTemplate): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.metric = metric + + model = ModelwithMetrics() + tpipes.run_model_test(trainer_options, model) From e8a6a7b264881d1fec983f6803226c7ed671a46a Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 7 Aug 2020 14:53:33 +0200 Subject: [PATCH 10/33] fixing test --- tests/metrics/test_metrics.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 9dc4d6ebc24f3..465c678d21bb3 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -144,7 +144,7 @@ def change_and_check_device_dtype(device, dtype): assert metric(input1, input2).dtype == torch.float16 -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +#@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.parametrize("metric", [ DummyTensorMetric(), DummyNumpyMetric(), @@ -157,14 +157,14 @@ def test_pickable(tmpdir, metric: Metric): default_root_dir=tmpdir, max_epochs=1, limit_train_batches=10, - gpus=[0, 1], - distributed_backend='ddp_spawn', + gpus=[0]#, 1], + #distributed_backend='ddp_spawn', ) class ModelwithMetrics(EvalModelTemplate): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.metric = metric + #self.metric = metric model = ModelwithMetrics() tpipes.run_model_test(trainer_options, model) From 9626b22b0c0fb2c97b1db5a9abb4941897db9aee Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 7 Aug 2020 14:55:55 +0200 Subject: [PATCH 11/33] fixing test --- tests/metrics/test_metrics.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 465c678d21bb3..3d77fcc9c7996 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -144,7 +144,7 @@ def change_and_check_device_dtype(device, dtype): assert metric(input1, input2).dtype == torch.float16 -#@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.parametrize("metric", [ DummyTensorMetric(), DummyNumpyMetric(), @@ -157,14 +157,13 @@ def test_pickable(tmpdir, metric: Metric): default_root_dir=tmpdir, max_epochs=1, limit_train_batches=10, - gpus=[0]#, 1], - #distributed_backend='ddp_spawn', + gpus=[0, 1], + distributed_backend='ddp_spawn', ) class ModelwithMetrics(EvalModelTemplate): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - #self.metric = metric model = ModelwithMetrics() tpipes.run_model_test(trainer_options, model) From 
1f458f668416fe106e11c449936a6e8c1338fe75 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 7 Aug 2020 15:49:49 +0200 Subject: [PATCH 12/33] fix pickle issue --- tests/base/model_train_steps.py | 18 ++++++++++++++++++ tests/metrics/test_metrics.py | 9 ++++----- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/tests/base/model_train_steps.py b/tests/base/model_train_steps.py index af9f662508eec..69c022bfb6a4a 100644 --- a/tests/base/model_train_steps.py +++ b/tests/base/model_train_steps.py @@ -133,3 +133,21 @@ def eval_epoch_end_full_loop_result_obj_dp(self, result): setattr(result, f'{eval_name}_step_end_metric', reduced) return result + + def training_step_using_metrics(self, batch, batch_idx, optimizer_idx=None): + """Lightning calls this inside the training loop""" + # forward pass + x, y = batch + x = x.view(x.size(0), -1) + y_hat = self(x) + + # calculate loss + loss_val = self.loss(y, y_hat) + + # call metric + val = self.metric(x, y) + + output = OrderedDict({'loss': loss_val, + 'progress_bar': {'metric_val': val}, + 'log': {'metric_val': val}}) + return output diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 3d77fcc9c7996..bf5966fa4db4a 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -146,7 +146,7 @@ def change_and_check_device_dtype(device, dtype): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.parametrize("metric", [ - DummyTensorMetric(), + DummyTensorMetric, DummyNumpyMetric(), ]) def test_pickable(tmpdir, metric: Metric): @@ -161,9 +161,8 @@ def test_pickable(tmpdir, metric: Metric): distributed_backend='ddp_spawn', ) - class ModelwithMetrics(EvalModelTemplate): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + model = EvalModelTemplate() + model.metric = metric() + model.training_step = model.training_step_using_metrics - model = ModelwithMetrics() tpipes.run_model_test(trainer_options, model) From be70ac8f59c31a12ba2b5810a48ff8a7091fbff0 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 10 Aug 2020 10:01:27 +0200 Subject: [PATCH 13/33] reduceop typehints back --- pytorch_lightning/metrics/converters.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/metrics/converters.py b/pytorch_lightning/metrics/converters.py index 6bbfc6d8a6943..8a0a0258f0884 100644 --- a/pytorch_lightning/metrics/converters.py +++ b/pytorch_lightning/metrics/converters.py @@ -14,6 +14,14 @@ from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection +try: + from torch.distributed import ReduceOp +except ImportError: + class ReduceOp: + SUM = None + + rank_zero_warn('Unsupported `ReduceOp` for distributed computing') + def _apply_to_inputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable: """ @@ -222,7 +230,8 @@ def _tensor_collection_metric_conversion(func_to_decorate: Callable) -> Callable def sync_ddp_if_available(result: Union[torch.Tensor], group: Optional[Any] = None, - reduce_op: Optional[Any] = None) -> torch.Tensor: + reduce_op: Optional[ReduceOp] = None + ) -> torch.Tensor: """ Function to reduce the tensors from several ddp processes to one master process @@ -260,7 +269,7 @@ def sync_ddp_if_available(result: Union[torch.Tensor], def sync_ddp(group: Optional[Any] = None, - reduce_op: Optional[Any] = None) -> Callable: + reduce_op: Optional[ReduceOp] = None) -> Callable: """ This decorator 
syncs a functions outputs across different processes for DDP. @@ -282,7 +291,7 @@ def decorator_fn(func_to_decorate): def numpy_metric(group: Optional[Any] = None, - reduce_op: Optional[Any] = None) -> Callable: + reduce_op: Optional[ReduceOp] = None) -> Callable: """ This decorator shall be used on all function metrics working on numpy arrays. It handles the argument conversion and DDP reduction for metrics working on numpy. @@ -305,7 +314,7 @@ def decorator_fn(func_to_decorate): def tensor_metric(group: Optional[Any] = None, - reduce_op: Optional[Any] = None) -> Callable: + reduce_op: Optional[ReduceOp] = None) -> Callable: """ This decorator shall be used on all function metrics working on tensors. It handles the argument conversion and DDP reduction for metrics working on tensors. @@ -327,7 +336,7 @@ def decorator_fn(func_to_decorate): def tensor_collection_metric(group: Optional[Any] = None, - reduce_op: Optional[Any] = None) -> Callable: + reduce_op: Optional[ReduceOp] = None) -> Callable: """ This decorator shall be used on all function metrics working on tensors and returning collections that cannot be converted to tensors. From fd6e719c845872ea24df4f4ed99d64116cc529d1 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 11 Aug 2020 13:30:29 +0200 Subject: [PATCH 14/33] remove redundant module arg --- pytorch_lightning/metrics/metric.py | 43 +++++++++++++++++------------ 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index f70dedb6ce14c..4dfc356ae7580 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -59,12 +59,12 @@ def forward(self, *args, **kwargs): """ raise NotImplementedError - def compute(self, module: nn.Module, data: Any, output: Any): + @staticmethod + def compute(self, data: Any, output: Any): """ Implement additionally metric computations to be done after the ddp sync Args: - module: current metric module data: input to forward method @@ -76,12 +76,12 @@ def compute(self, module: nn.Module, data: Any, output: Any): """ return output - def ddp_sync(self, module: nn.Module, data: Any, output: Any): + @staticmethod + def ddp_sync(self, data: Any, output: Any): """ Implement how the outputs from forward should be synced Args: - module: current metric module data: input to forward method @@ -93,12 +93,12 @@ def ddp_sync(self, module: nn.Module, data: Any, output: Any): """ return output - def input_convert(self, module: nn.Module, data: Any): + @staticmethod + def input_convert(self, data: Any): """ Implement how the inputs should be casted before calling forward Args: - module: current metric module data: input to forward method @@ -107,12 +107,12 @@ def input_convert(self, module: nn.Module, data: Any): """ return data - def output_convert(self, module: nn.Module, data: Any, output: Any): + @staticmethod + def output_convert(self, data: Any, output: Any): """ Implement how outputs from forward should be casted Args: - module: current metric module data: input to forward method @@ -147,17 +147,20 @@ def __init__(self, name: str, self.reduce_group = reduce_group self.reduce_op = reduce_op - def input_convert(self, module: nn.Module, data: Any): + @staticmethod + def input_convert(self, data: Any): return apply_to_collection(data, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor, self.dtype, self.device) - def output_convert(self, module: nn.Module, data: Any, output: Any): + @staticmethod + def output_convert(self, data: Any, 
output: Any): return apply_to_collection(output, torch.Tensor, convert_to_tensor, self.dtype, self.device) - def ddp_sync(self, module: nn.Module, data: Any, output: Any): + @staticmethod + def ddp_sync(self, data: Any, output: Any): return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, self.reduce_group, self.reduce_op) @@ -195,19 +198,22 @@ def __init__(self, name: str, self.reduce_group = reduce_group self.reduce_op = reduce_op - def input_convert(self, module: nn.Module, data: Any): + @staticmethod + def input_convert(self, data: Any): return apply_to_collection(data, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor, self.dtype, self.device) - def output_convert(self, module: nn.Module, data: Any, output: Any): + @staticmethod + def output_convert(self, data: Any, output: Any): return apply_to_collection(output, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor, self.dtype, self.device) - def ddp_sync(self, module: nn.Module, data: Any, output: Any): + @staticmethod + def ddp_sync(self, data: Any, output: Any): return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, self.reduce_group, self.reduce_op) @@ -236,17 +242,20 @@ def __init__(self, name: str, self.reduce_group = reduce_group self.reduce_op = reduce_op - def input_convert(self, module: nn.Module, data: Any): + @staticmethod + def input_convert(self, data: Any): return apply_to_collection(data, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy) - def output_convert(self, module: nn.Module, data: Any, output: Any): + @staticmethod + def output_convert(self, data: Any, output: Any): return apply_to_collection(output, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor, self.dtype, self.device) - def ddp_sync(self, module: nn.Module, data: Any, output: Any): + @staticmethod + def ddp_sync(self, data: Any, output: Any): return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, self.reduce_group, self.reduce_op) From 9d16c692c3c2d8c6f0e7408b430bca16daef1bc3 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 11 Aug 2020 13:46:22 +0200 Subject: [PATCH 15/33] add save/load test --- tests/metrics/test_metrics.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index bf5966fa4db4a..7ad77864443c0 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -1,3 +1,4 @@ +import os import numpy as np import pytest import torch @@ -149,7 +150,7 @@ def change_and_check_device_dtype(device, dtype): DummyTensorMetric, DummyNumpyMetric(), ]) -def test_pickable(tmpdir, metric: Metric): +def test_model_pickable(tmpdir, metric: Metric): """Make sure that metrics are pickable by including into a model and running in multi-gpu mode""" tutils.set_random_master_port() @@ -166,3 +167,21 @@ def test_pickable(tmpdir, metric: Metric): model.training_step = model.training_step_using_metrics tpipes.run_model_test(trainer_options, model) + + +@pytest.mark.parametrize("metric", [DummyTensorMetric(), DummyNumpyMetric()]) +def test_saving_pickable(tmpdir, metric: Metric): + """ Make sure that metrics are pickable by saving and loading them using torch """ + x, y = torch.randn(10,), torch.randn(10,) + results_before_save = metric(x,y) + + # save metric + save_path = os.path.join(tmpdir, 'save_test.ckpt') + torch.save(metric, save_path) + + # load metric + new_metric = torch.load(save_path) + results_after_load = new_metric(x,y) + + # Check metric 
value is the same + assert results_before_save == results_after_load From 810d3dd98cfc570749206bfc1f91c6973a139e1e Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 11 Aug 2020 14:08:20 +0200 Subject: [PATCH 16/33] add aggregate method --- pytorch_lightning/metrics/converters.py | 32 ++++++++++++++++++++++ pytorch_lightning/metrics/metric.py | 35 ++++++++++++++++++++++++- 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/converters.py b/pytorch_lightning/metrics/converters.py index 8a0a0258f0884..490eca51703bd 100644 --- a/pytorch_lightning/metrics/converters.py +++ b/pytorch_lightning/metrics/converters.py @@ -268,6 +268,38 @@ def sync_ddp_if_available(result: Union[torch.Tensor], return result +def gather_all_tensors_if_available(result: Union[torch.Tensor], + group: Optional[Any] = None): + """ + Function to gather all tensors from several ddp processes onto a list that + is broadcastet all all processes + + Args: + result: the value to sync + group: the process group to gather results from. Defaults to all processes (world) + + Return: + gathered_result: list with size equal to the process group where + gathered_result[i] corresponds to result tensor from process i + + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + if group is None: + group = torch.distributed.group.WORLD + + world_size = torch.distributed.get_world_size(group) + + gathered_result = world_size * [torch.zeros_like(result)] + + # sync and broadcast all + torch.distributed.barrier(group=group) + torch.distributed.all_gather(gathered_result, result, group) + + result = gathered_result + + return result + + def sync_ddp(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable: """ diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 4dfc356ae7580..59f37aa2664b3 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -7,7 +7,8 @@ import numpy as np from pytorch_lightning.metrics.converters import ( - sync_ddp_if_available, convert_to_tensor, convert_to_numpy) + sync_ddp_if_available, gather_all_tensors_if_available, + convert_to_tensor, convert_to_numpy) from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin @@ -46,6 +47,7 @@ def __init__(self, name: str): self.register_forward_pre_hook(self.input_convert) self.register_forward_hook(self.output_convert) self.register_forward_hook(self.ddp_sync) + self.register_forward_hook(self.aggregate) self.register_forward_hook(self.compute) @abstractmethod @@ -90,6 +92,24 @@ def ddp_sync(self, data: Any, output: Any): Returns: synced output + """ + return apply_to_collection(output, torch.Tensor, gather_all_tensors_if_available, + self.reduce_group) + + @staticmethod + def aggregate(self, data: Any, output: Any): + """ + Implement aggregation of values on the same device + + Args: + + data: input to forward method + + output: output from forward method + + Returns: + aggregated values + """ return output @@ -164,6 +184,10 @@ def ddp_sync(self, data: Any, output: Any): return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, self.reduce_group, self.reduce_op) + @staticmethod + def aggregate(self, data: Any, output: Any): + return output + class TensorCollectionMetric(Metric): """ @@ -218,6 +242,11 @@ def ddp_sync(self, data: Any, output: Any): self.reduce_group, self.reduce_op) + @staticmethod + def 
aggregate(self, data: Any, output: Any): + return output + + class NumpyMetric(Metric): """ Base class for metric implementation operating on numpy arrays. @@ -259,3 +288,7 @@ def output_convert(self, data: Any, output: Any): def ddp_sync(self, data: Any, output: Any): return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, self.reduce_group, self.reduce_op) + + @staticmethod + def aggregate(self, data: Any, output: Any): + return output From 4923e342036ac69f2c117eeb72ea942c67a1199a Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 11 Aug 2020 14:27:22 +0200 Subject: [PATCH 17/33] text clarification --- pytorch_lightning/metrics/metric.py | 66 ++++++++++++----------------- 1 file changed, 27 insertions(+), 39 deletions(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 59f37aa2664b3..12252a0e7d4db 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -24,11 +24,12 @@ class Metric(DeviceDtypeModuleMixin, nn.Module, ABC): Metric hooks that can be implemented are: input_convert: pre-forward hook that takes care of input conversion output_convert: post-forward hook that takes care of output convertion - ddp_sync: implementation of ddp sync + ddp_sync: implementation of ddp sync, default is gather all + aggregate: implement how values should be aggregated compute: post-ddp sync for additional metric computations Call order: - input_convert -> forward -> output_convert -> ddp_sync -> compute + input_convert -> forward -> output_convert -> ddp_sync -> aggregate -> compute """ @@ -50,6 +51,20 @@ def __init__(self, name: str): self.register_forward_hook(self.aggregate) self.register_forward_hook(self.compute) + @staticmethod + def input_convert(self, data: Any): + """ + Implement how the inputs should be casted before calling forward + + Args: + + data: input to forward method + + Returns: + casted data + """ + return data + @abstractmethod def forward(self, *args, **kwargs): """ @@ -62,9 +77,9 @@ def forward(self, *args, **kwargs): raise NotImplementedError @staticmethod - def compute(self, data: Any, output: Any): + def output_convert(self, data: Any, output: Any): """ - Implement additionally metric computations to be done after the ddp sync + Implement how outputs from forward should be casted Args: @@ -73,8 +88,7 @@ def compute(self, data: Any, output: Any): output: output from forward method Returns: - final metric value - + casted outputs """ return output @@ -87,7 +101,7 @@ def ddp_sync(self, data: Any, output: Any): data: input to forward method - output: output from forward method + output: output from the `output_convert` hook Returns: synced output @@ -105,7 +119,7 @@ def aggregate(self, data: Any, output: Any): data: input to forward method - output: output from forward method + output: output from the `ddp_sync` hook Returns: aggregated values @@ -114,32 +128,19 @@ def aggregate(self, data: Any, output: Any): return output @staticmethod - def input_convert(self, data: Any): - """ - Implement how the inputs should be casted before calling forward - - Args: - - data: input to forward method - - Returns: - casted data - """ - return data - - @staticmethod - def output_convert(self, data: Any, output: Any): + def compute(self, data: Any, output: Any): """ - Implement how outputs from forward should be casted + Implement additionally metric computations to be done after the ddp sync Args: data: input to forward method - output: output from forward method + output: output from the 
`aggregate` hook Returns: - casted outputs + final metric value + """ return output @@ -184,10 +185,6 @@ def ddp_sync(self, data: Any, output: Any): return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, self.reduce_group, self.reduce_op) - @staticmethod - def aggregate(self, data: Any, output: Any): - return output - class TensorCollectionMetric(Metric): """ @@ -242,11 +239,6 @@ def ddp_sync(self, data: Any, output: Any): self.reduce_group, self.reduce_op) - @staticmethod - def aggregate(self, data: Any, output: Any): - return output - - class NumpyMetric(Metric): """ Base class for metric implementation operating on numpy arrays. @@ -288,7 +280,3 @@ def output_convert(self, data: Any, output: Any): def ddp_sync(self, data: Any, output: Any): return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, self.reduce_group, self.reduce_op) - - @staticmethod - def aggregate(self, data: Any, output: Any): - return output From ae60d3dae88eb3f446a6291629797a6d2c40199d Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 11 Aug 2020 14:34:20 +0200 Subject: [PATCH 18/33] fix doctest --- pytorch_lightning/metrics/metric.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 12252a0e7d4db..f60421739dbc8 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -107,8 +107,7 @@ def ddp_sync(self, data: Any, output: Any): synced output """ - return apply_to_collection(output, torch.Tensor, gather_all_tensors_if_available, - self.reduce_group) + return output @staticmethod def aggregate(self, data: Any, output: Any): From 796d913ae4ae105ee3e3cdd7af82e6019e6c68b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 22 Aug 2020 11:28:47 +0200 Subject: [PATCH 19/33] Apply suggestions from code review --- pytorch_lightning/metrics/converters.py | 2 +- pytorch_lightning/metrics/metric.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/metrics/converters.py b/pytorch_lightning/metrics/converters.py index 490eca51703bd..4e92b5efb67c7 100644 --- a/pytorch_lightning/metrics/converters.py +++ b/pytorch_lightning/metrics/converters.py @@ -272,7 +272,7 @@ def gather_all_tensors_if_available(result: Union[torch.Tensor], group: Optional[Any] = None): """ Function to gather all tensors from several ddp processes onto a list that - is broadcastet all all processes + is broadcasted to all processes Args: result: the value to sync diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index f60421739dbc8..9d73ce370a495 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -18,17 +18,18 @@ class Metric(DeviceDtypeModuleMixin, nn.Module, ABC): Abstract base class for metric implementation. Should be used to implement metrics that - 1. Return multiple Outputs + 1. Return multiple outputs 2. 
Handle their own DDP sync Metric hooks that can be implemented are: input_convert: pre-forward hook that takes care of input conversion - output_convert: post-forward hook that takes care of output convertion + output_convert: post-forward hook that takes care of output conversion ddp_sync: implementation of ddp sync, default is gather all aggregate: implement how values should be aggregated compute: post-ddp sync for additional metric computations Call order: + input_convert -> forward -> output_convert -> ddp_sync -> aggregate -> compute """ From 8a3a12870f28afcea4c616a395405443a49124a2 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 24 Aug 2020 16:10:36 +0200 Subject: [PATCH 20/33] change test to results obj --- pytorch_lightning/metrics/metric.py | 10 +++++----- tests/base/model_train_steps.py | 7 +++---- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index f60421739dbc8..2d40adf2c9c44 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -22,11 +22,11 @@ class Metric(DeviceDtypeModuleMixin, nn.Module, ABC): 2. Handle their own DDP sync Metric hooks that can be implemented are: - input_convert: pre-forward hook that takes care of input conversion - output_convert: post-forward hook that takes care of output convertion - ddp_sync: implementation of ddp sync, default is gather all - aggregate: implement how values should be aggregated - compute: post-ddp sync for additional metric computations + * input_convert: pre-forward hook that takes care of input conversion + * output_convert: post-forward hook that takes care of output convertion + * ddp_sync: implementation of ddp sync, default is gather all + * aggregate: implement how values should be aggregated + * compute: post-ddp sync for additional metric computations Call order: input_convert -> forward -> output_convert -> ddp_sync -> aggregate -> compute diff --git a/tests/base/model_train_steps.py b/tests/base/model_train_steps.py index 67faa1c24c36d..04343dd24aec1 100644 --- a/tests/base/model_train_steps.py +++ b/tests/base/model_train_steps.py @@ -189,7 +189,6 @@ def training_step_using_metrics(self, batch, batch_idx, optimizer_idx=None): # call metric val = self.metric(x, y) - output = OrderedDict({'loss': loss_val, - 'progress_bar': {'metric_val': val}, - 'log': {'metric_val': val}}) - return output + result = TrainResult(minimize=loss_val) + result.log('metric_val', val) + return result \ No newline at end of file From 26b4bb412666bd102fc985d77870a89e6e2156c2 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 24 Aug 2020 16:22:24 +0200 Subject: [PATCH 21/33] fix docs --- pytorch_lightning/metrics/metric.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 2d40adf2c9c44..5a9f58b5ce0a4 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -18,17 +18,20 @@ class Metric(DeviceDtypeModuleMixin, nn.Module, ABC): Abstract base class for metric implementation. Should be used to implement metrics that - 1. Return multiple Outputs - 2. Handle their own DDP sync - Metric hooks that can be implemented are: + 1. Return multiple Outputs + 2. 
Handle their own DDP sync + + Metric hooks that can be implemented are + * input_convert: pre-forward hook that takes care of input conversion * output_convert: post-forward hook that takes care of output convertion * ddp_sync: implementation of ddp sync, default is gather all * aggregate: implement how values should be aggregated * compute: post-ddp sync for additional metric computations - Call order: + Call order + input_convert -> forward -> output_convert -> ddp_sync -> aggregate -> compute """ From 8aae0be23d7acb807b73c0eddac70250bd046192 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 25 Aug 2020 20:46:44 +0200 Subject: [PATCH 22/33] formatting Co-authored-by: Rohit Gupta --- pytorch_lightning/metrics/metric.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 341949674f1f6..9b1bd4bafd158 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -99,9 +99,7 @@ def output_convert(self, data: Any, output: Any): Implement how outputs from forward should be casted Args: - data: input to forward method - output: output from forward method Returns: From b783ec5fc15f4b37af4a3ed4defdf0868aa78db1 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Wed, 26 Aug 2020 00:50:28 +0530 Subject: [PATCH 23/33] formatting --- pytorch_lightning/metrics/metric.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 9b1bd4bafd158..72d95c2a6d0e1 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -147,9 +147,7 @@ def compute(self, data: Any, output: Any): Implement additionally metric computations to be done after the ddp sync Args: - data: input to forward method - output: output from the `aggregate` hook Returns: From e0ba5578d6d43b4a1ec5945c1ccedfaca54c3c58 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Wed, 26 Aug 2020 00:50:43 +0530 Subject: [PATCH 24/33] formatting --- pytorch_lightning/metrics/metric.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 72d95c2a6d0e1..85ef6321b5919 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -130,9 +130,7 @@ def aggregate(self, data: Any, output: Any): Implement aggregation of values on the same device Args: - data: input to forward method - output: output from the `ddp_sync` hook Returns: From da4e94d3e012686fe5d5c27caa16921772031d58 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Wed, 26 Aug 2020 00:51:01 +0530 Subject: [PATCH 25/33] formatting --- pytorch_lightning/metrics/metric.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 85ef6321b5919..f4f6fdde1daf7 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -113,9 +113,7 @@ def ddp_sync(self, data: Any, output: Any): Implement how the outputs from forward should be synced Args: - data: input to forward method - output: output from the `output_convert` hook Returns: From a27aebddbe525ed4a45a064c743a4dfda2ff7b77 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Wed, 26 Aug 2020 00:51:15 +0530 Subject: [PATCH 26/33] formatting --- pytorch_lightning/metrics/metric.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index f4f6fdde1daf7..5f61a50e6cd25 
100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -74,7 +74,6 @@ def input_convert(self, data: Any): Implement how the inputs should be casted before calling forward Args: - data: input to forward method Returns: From 9641f84b73dda0110d38bf7757abe654d438d0d0 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Wed, 26 Aug 2020 00:51:27 +0530 Subject: [PATCH 27/33] formatting --- pytorch_lightning/metrics/converters.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/metrics/converters.py b/pytorch_lightning/metrics/converters.py index 94c16ca7d59f9..a41a621c905da 100644 --- a/pytorch_lightning/metrics/converters.py +++ b/pytorch_lightning/metrics/converters.py @@ -92,9 +92,7 @@ def convert_to_tensor(data: Any, dtype=None, device=None) -> Any: Args: data: the data to convert to tensor - dtype: data type to convert to - device: device to cast to Return: From 9fabc89f21e07971cdb2d4610e47850ada78f398 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Wed, 26 Aug 2020 00:54:36 +0530 Subject: [PATCH 28/33] pep --- tests/metrics/test_metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 7ad77864443c0..538b43eddd3fa 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -173,7 +173,7 @@ def test_model_pickable(tmpdir, metric: Metric): def test_saving_pickable(tmpdir, metric: Metric): """ Make sure that metrics are pickable by saving and loading them using torch """ x, y = torch.randn(10,), torch.randn(10,) - results_before_save = metric(x,y) + results_before_save = metric(x, y) # save metric save_path = os.path.join(tmpdir, 'save_test.ckpt') @@ -181,7 +181,7 @@ def test_saving_pickable(tmpdir, metric: Metric): # load metric new_metric = torch.load(save_path) - results_after_load = new_metric(x,y) + results_after_load = new_metric(x, y) # Check metric value is the same assert results_before_save == results_after_load From 619ffc242316a3312820d189dd5273e856ff89ab Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 25 Aug 2020 21:53:24 +0200 Subject: [PATCH 29/33] Update CHANGELOG.md --- CHANGELOG.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f8c4b00e7357d..27c77cafd4db7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528/)) ### Changed @@ -55,8 +56,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added warning when changing monitor and using results obj ([#3014](https://github.com/PyTorchLightning/pytorch-lightning/pull/3014)) - Added a hook `transfer_batch_to_device` to the `LightningDataModule` ([#3038](https://github.com/PyTorchLightning/pytorch-lightning/pull/3038)) -- Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528/)) - ### Changed - Truncated long version numbers in progress bar ([#2594](https://github.com/PyTorchLightning/pytorch-lightning/pull/2594)) From b03ce8f82f4ab3c01c6e3876b0ec4d0238153148 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 26 Aug 2020 08:10:09 +0200 Subject: [PATCH 30/33] suggestions --- tests/base/model_train_steps.py | 2 +- tests/metrics/test_metrics.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/base/model_train_steps.py b/tests/base/model_train_steps.py index 04343dd24aec1..2995001a777fd 100644 --- a/tests/base/model_train_steps.py +++ b/tests/base/model_train_steps.py @@ -176,7 +176,7 @@ def eval_epoch_end_full_loop_result_obj_dp(self, result): return result - def training_step_using_metrics(self, batch, batch_idx, optimizer_idx=None): + def training_step__using_metrics(self, batch, batch_idx, optimizer_idx=None): """Lightning calls this inside the training loop""" # forward pass x, y = batch diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 538b43eddd3fa..35512d9c4efa9 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -148,7 +148,7 @@ def change_and_check_device_dtype(device, dtype): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.parametrize("metric", [ DummyTensorMetric, - DummyNumpyMetric(), + DummyNumpyMetric, ]) def test_model_pickable(tmpdir, metric: Metric): """Make sure that metrics are pickable by including into a model and running in multi-gpu mode""" @@ -164,7 +164,7 @@ def test_model_pickable(tmpdir, metric: Metric): model = EvalModelTemplate() model.metric = metric() - model.training_step = model.training_step_using_metrics + model.training_step = model.training_step__using_metrics tpipes.run_model_test(trainer_options, model) From ac169dfbd92528a41adcec7eb8d35f3a80f6fd6a Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 26 Aug 2020 09:47:30 +0200 Subject: [PATCH 31/33] fix tests --- tests/metrics/test_metrics.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 35512d9c4efa9..a9878abd0941e 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -4,9 +4,9 @@ import torch import tests.base.develop_utils as tutils -import tests.base.develop_pipelines as tpipes from tests.base import EvalModelTemplate from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric, TensorCollectionMetric +from pytorch_lightning import Trainer class DummyTensorMetric(TensorMetric): @@ -146,11 +146,9 @@ def change_and_check_device_dtype(device, dtype): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -@pytest.mark.parametrize("metric", [ - DummyTensorMetric, - DummyNumpyMetric, -]) -def test_model_pickable(tmpdir, metric: Metric): +@pytest.mark.parametrize("distributed_backend", ['ddp', 'ddp_spawn']) +@pytest.mark.parametrize("metric", [DummyTensorMetric, DummyNumpyMetric]) +def test_model_pickable(tmpdir, distributed_backend: str, metric: 
Metric): """Make sure that metrics are pickable by including into a model and running in multi-gpu mode""" tutils.set_random_master_port() @@ -159,14 +157,18 @@ def test_model_pickable(tmpdir, metric: Metric): max_epochs=1, limit_train_batches=10, gpus=[0, 1], - distributed_backend='ddp_spawn', + distributed_backend=distributed_backend, ) model = EvalModelTemplate() model.metric = metric() model.training_step = model.training_step__using_metrics - tpipes.run_model_test(trainer_options, model) + trainer = Trainer(**trainer_options) + result = trainer.fit(model) + + # correct result and ok accuracy + assert result == 1, 'amp + ddp model failed to complete' @pytest.mark.parametrize("metric", [DummyTensorMetric(), DummyNumpyMetric()]) From c1a13892519efeac1dbbd0ccb1d360d133dc1706 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 26 Aug 2020 09:52:46 +0200 Subject: [PATCH 32/33] fix pep8 --- tests/base/model_train_steps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/base/model_train_steps.py b/tests/base/model_train_steps.py index 2995001a777fd..feb08fddda467 100644 --- a/tests/base/model_train_steps.py +++ b/tests/base/model_train_steps.py @@ -191,4 +191,4 @@ def training_step__using_metrics(self, batch, batch_idx, optimizer_idx=None): result = TrainResult(minimize=loss_val) result.log('metric_val', val) - return result \ No newline at end of file + return result From 847d0edc6962ee147e26c0d4cfbc1e3c5837ad7c Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 26 Aug 2020 10:54:26 +0200 Subject: [PATCH 33/33] fix tests --- tests/metrics/test_metrics.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index a9878abd0941e..5985745bfa070 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -146,9 +146,8 @@ def change_and_check_device_dtype(device, dtype): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -@pytest.mark.parametrize("distributed_backend", ['ddp', 'ddp_spawn']) @pytest.mark.parametrize("metric", [DummyTensorMetric, DummyNumpyMetric]) -def test_model_pickable(tmpdir, distributed_backend: str, metric: Metric): +def test_model_pickable(tmpdir, metric: Metric): """Make sure that metrics are pickable by including into a model and running in multi-gpu mode""" tutils.set_random_master_port() @@ -157,7 +156,7 @@ def test_model_pickable(tmpdir, distributed_backend: str, metric: Metric): max_epochs=1, limit_train_batches=10, gpus=[0, 1], - distributed_backend=distributed_backend, + distributed_backend='ddp_spawn', ) model = EvalModelTemplate() @@ -168,7 +167,7 @@ def test_model_pickable(tmpdir, distributed_backend: str, metric: Metric): result = trainer.fit(model) # correct result and ok accuracy - assert result == 1, 'amp + ddp model failed to complete' + assert result == 1, 'ddp model failed to complete' @pytest.mark.parametrize("metric", [DummyTensorMetric(), DummyNumpyMetric()])
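
Taken together, this series leaves `Metric` as an `nn.Module` whose whole pipeline is wired through forward hooks — input_convert -> forward -> output_convert -> ddp_sync -> aggregate -> compute — with every hook a `@staticmethod` that receives the module as an explicit `self` argument, so instances hold no bound-method references and survive pickling for `ddp_spawn` and `torch.save`. Below is a minimal, self-contained sketch of that design, not the library code; `SketchMetric` and its mean-absolute-error body are invented for illustration:

import io

import torch
import torch.nn as nn


class SketchMetric(nn.Module):
    """Illustrative stand-in for the hook-based Metric built in this series."""

    def __init__(self, name: str):
        super().__init__()
        self.name = name
        # Forward hooks fire in registration order; each hook's return
        # value replaces `output` for the next hook in the chain.
        self.register_forward_pre_hook(self.input_convert)
        self.register_forward_hook(self.output_convert)
        self.register_forward_hook(self.ddp_sync)
        self.register_forward_hook(self.aggregate)
        self.register_forward_hook(self.compute)

    @staticmethod
    def input_convert(self, data):
        # pre-hook: `data` is the tuple of positional args to forward
        return tuple(torch.as_tensor(d, dtype=torch.float32) for d in data)

    def forward(self, x, y):
        # metric "state"; a summed absolute error, invented for the demo
        return (x - y).abs().sum()

    @staticmethod
    def output_convert(self, data, output):
        return output.float()

    @staticmethod
    def ddp_sync(self, data, output):
        return output  # an all_reduce / all_gather would go here under DDP

    @staticmethod
    def aggregate(self, data, output):
        return output

    @staticmethod
    def compute(self, data, output):
        return output / data[0].numel()  # reduce state to the final value


metric = SketchMetric('mae')
print(metric(torch.ones(4), torch.zeros(4)))  # tensor(1.)

# The save/load round trip that test_saving_pickable exercises: because the
# hooks are stored as plain functions, the module pickles cleanly.
buf = io.BytesIO()
torch.save(metric, buf)
buf.seek(0)
restored = torch.load(buf)  # newer torch versions may need weights_only=False
assert restored(torch.ones(4), torch.zeros(4)) == 1.0

Note the staticmethod trick: registering `self.input_convert` stores the bare function, and PyTorch then calls it as `hook(module, input)`, so the parameter named `self` is simply the module being passed back in — which is why patch 14 could drop the redundant `module` argument.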
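
On the converters side, the series ends up with two synchronization styles: `sync_ddp_if_available`, which reduces a tensor across processes under a `ReduceOp` (with a stub `ReduceOp` fallback class when `torch.distributed` cannot provide one), and the new `gather_all_tensors_if_available`, which all-gathers one tensor per rank into a list broadcast to every process. The following is a runnable two-process demo of the underlying collectives, assuming a CPU gloo backend and a free local port; the file layout and constants are invented for illustration:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    # in essence what sync_ddp_if_available performs: an in-place reduction
    state = torch.tensor([float(rank + 1)])
    dist.all_reduce(state, op=dist.ReduceOp.SUM)
    assert state.item() == 3.0  # 1 + 2, identical on every rank

    # in essence what gather_all_tensors_if_available performs: one slot per rank
    mine = torch.tensor([float(rank)])
    slots = [torch.zeros_like(mine) for _ in range(world_size)]
    dist.all_gather(slots, mine)
    assert [t.item() for t in slots] == [0.0, 1.0]

    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(worker, args=(2,), nprocs=2)

The trade-off between the two styles is what makes the `aggregate` hook added in patch 16 useful: all_reduce keeps the synced state constant-size but restricts aggregation to the available `ReduceOp` operations, while all_gather preserves the per-rank values so that `aggregate`/`compute` can apply reductions (a median, say) that no `ReduceOp` expresses.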