diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index ac24dcee0a1e1..d0ba2caa69d22 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -54,6 +54,11 @@ jobs: run: | python -c "req = open('requirements.txt').read().replace('torch>=1.1', 'torch<1.5') ; open('requirements.txt', 'w').write(req)" + # versions <= 1.3 may have issues on mac with some BLAS ops due to missing mkl (https://github.com/pytorch/pytorch/issues/18996) + - name: Setup MacOS Minimal + if: runner.os == 'macOS' && matrix.requires ='minimal' + run : | + python -c "req = open('requirements.txt').read().replace('torch>=1.1', 'torch>=1.4') ; open('requirements.txt', 'w').write(req)" - name: Set min. dependencies if: matrix.requires == 'minimal' run: | @@ -137,4 +142,4 @@ jobs: - name: Statistics if: success() run: | - coverage report \ No newline at end of file + coverage report diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c35dcdca5034..cd9389c6d7078 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Add Metric Base Classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877)) + - Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723)). ### Changed diff --git a/README.md b/README.md index 02ca78f68dcbf..277396eca5d3a 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ removed until codecov badge isn't empy. likely a config error showing nothing on | Linux py3.6 [CPU] | [![CircleCI](https://circleci.com/gh/PyTorchLightning/pytorch-lightning.svg?style=svg)](https://circleci.com/gh/PyTorchLightning/pytorch-lightning) | [![CircleCI](https://circleci.com/gh/PyTorchLightning/pytorch-lightning.svg?style=svg)](https://circleci.com/gh/PyTorchLightning/pytorch-lightning) | [![CircleCI](https://circleci.com/gh/PyTorchLightning/pytorch-lightning.svg?style=svg)](https://circleci.com/gh/PyTorchLightning/pytorch-lightning) | [![CircleCI](https://circleci.com/gh/PyTorchLightning/pytorch-lightning.svg?style=svg)](https://circleci.com/gh/PyTorchLightning/pytorch-lightning) | [![CircleCI](https://circleci.com/gh/PyTorchLightning/pytorch-lightning.svg?style=svg)](https://circleci.com/gh/PyTorchLightning/pytorch-lightning) | | Linux py3.7 [GPU] | - | - | - | - | [![Build Status](http://35.192.60.23/api/badges/PyTorchLightning/pytorch-lightning/status.svg)](http://35.192.60.23/PyTorchLightning/pytorch-lightning) | | Linux py3.6 / py3.7 / py3.8 | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | - | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | -| OSX py3.6 / py3.7 / py3.8| [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | - | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | +| OSX py3.6 / py3.7 / py3.8| - | - | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | | Windows py3.6 / py3.7 / py3.8 | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | diff --git a/docs/source/index.rst b/docs/source/index.rst index 34df3af871b59..f7bee4395b637 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -23,6 +23,7 @@ PyTorch Lightning Documentation hooks lightning-module loggers + metrics trainer .. toctree:: @@ -115,6 +116,7 @@ Indices and tables api/pytorch_lightning.core api/pytorch_lightning.callbacks api/pytorch_lightning.loggers + api/pytorch_lightning.metrics api/pytorch_lightning.overrides api/pytorch_lightning.profiler api/pytorch_lightning.trainer diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst new file mode 100644 index 0000000000000..6f70a3c73f2d0 --- /dev/null +++ b/docs/source/metrics.rst @@ -0,0 +1,4 @@ +.. automodule:: pytorch_lightning.metrics + :members: + :noindex: + :exclude-members: diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 784ae9c3a45fa..662c05a29338d 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -18,7 +18,7 @@ from pytorch_lightning.core.hooks import ModelHooks from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv, load_hparams_from_yaml, update_hparams -from pytorch_lightning.core.properties import DeviceDtypeModuleMixin +from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py new file mode 100644 index 0000000000000..cd721851307df --- /dev/null +++ b/pytorch_lightning/metrics/__init__.py @@ -0,0 +1,24 @@ +""" +Metrics +======= + +Metrics are generally used to monitor model performance. + +The following package aims to provide the most convenient ones as well +as a structure to implement your custom metrics for all the fancy research +you want to do. + +For native PyTorch implementations of metrics, it is recommended to use +the :class:`TensorMetric` which handles automated DDP syncing and conversions +to tensors for all inputs and outputs. + +If your metrics implementation works on numpy, just use the +:class:`NumpyMetric`, which handles the automated conversion of +inputs to and outputs from numpy as well as automated ddp syncing. + +.. warning:: Employing numpy in your metric calculation might slow + down your training substantially, since every metric computation + requires a GPU sync to convert tensors to numpy. + + +""" diff --git a/pytorch_lightning/metrics/converters.py b/pytorch_lightning/metrics/converters.py new file mode 100644 index 0000000000000..7681be589f7fe --- /dev/null +++ b/pytorch_lightning/metrics/converters.py @@ -0,0 +1,230 @@ +""" +This file provides functions and decorators for automated input and output +conversion to/from :class:`numpy.ndarray` and :class:`torch.Tensor` as well as utilities to +sync tensors between different processes in a DDP scenario, when needed. +""" + +import sys +import numbers +from typing import Union, Any, Callable, Optional + +import numpy as np +import torch +from torch.utils.data._utils.collate import np_str_obj_array_pattern + +from pytorch_lightning.utilities.apply_func import apply_to_collection + + +def _apply_to_inputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable: + """ + Decorator function to apply a function to all inputs of a function. + Args: + func_to_apply: the function to apply to the inputs + *dec_args: positional arguments for the function to be applied + **dec_kwargs: keyword arguments for the function to be applied + + Returns: + the decorated function + """ + + def decorator_fn(func_to_decorate): + # actual function applying the give function to inputs + def new_func(*args, **kwargs): + args = func_to_apply(args, *dec_args, **dec_kwargs) + kwargs = func_to_apply(kwargs, *dec_args, **dec_kwargs) + return func_to_decorate(*args, **kwargs) + + return new_func + + return decorator_fn + + +def _apply_to_outputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable: + """ + Decorator function to apply a function to all outputs of a function. + Args: + func_to_apply: the function to apply to the outputs + *dec_args: positional arguments for the function to be applied + **dec_kwargs: keyword arguments for the function to be applied + + Returns: + the decorated function + """ + + def decorator_fn(function_to_decorate): + # actual function applying the give function to outputs + def new_func(*args, **kwargs): + result = function_to_decorate(*args, **kwargs) + return func_to_apply(result, *dec_args, **dec_kwargs) + + return new_func + + return decorator_fn + + +def _convert_to_tensor(data: Any) -> Any: + """ + Maps all kind of collections and numbers to tensors. + + Args: + data: the data to convert to tensor + + Returns: + the converted data + + """ + if isinstance(data, numbers.Number): + return torch.tensor([data]) + # is not array of object + elif isinstance(data, np.ndarray) and np_str_obj_array_pattern.search(data.dtype.str) is None: + return torch.from_numpy(data) + elif isinstance(data, torch.Tensor): + return data + + raise TypeError(f"The given type ('{type(data).__name__}') cannot be converted to a tensor!") + + +def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray: + """Convert all tensors and numpy arrays to numpy arrays. + Args: + data: the tensor or array to convert to numpy + + Returns: + the resulting numpy array + + """ + if isinstance(data, torch.Tensor): + return data.cpu().detach().numpy() + elif isinstance(data, numbers.Number): + return np.array([data]) + elif isinstance(data, np.ndarray): + return data + + raise TypeError("The given type ('%s') cannot be converted to a numpy array!" % type(data).__name__) + + +def _numpy_metric_conversion(func_to_decorate: Callable) -> Callable: + """ + Decorator handling the argument conversion for metrics working on numpy. + All inputs of the decorated function will be converted to numpy and all + outputs will be converted to tensors. + + Args: + func_to_decorate: the function whose inputs and outputs shall be converted + + Returns: + the decorated function + + """ + # applies collection conversion from tensor to numpy to all inputs + # we need to include numpy arrays here, since otherwise they will also be treated as sequences + func_convert_inputs = _apply_to_inputs( + apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)(func_to_decorate) + # converts all inputs back to tensors (device doesn't matter here, since this is handled by BaseMetric) + func_convert_in_out = _apply_to_outputs(_convert_to_tensor)(func_convert_inputs) + return func_convert_in_out + + +def _tensor_metric_conversion(func_to_decorate: Callable) -> Callable: + """ + Decorator Handling the argument conversion for metrics working on tensors. + All inputs and outputs of the decorated function will be converted to tensors + + Args: + func_to_decorate: the function whose inputs and outputs shall be converted + + Returns: + the decorated function + + """ + # converts all inputs to tensor if possible + # we need to include tensors here, since otherwise they will also be treated as sequences + func_convert_inputs = _apply_to_inputs( + apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_tensor)(func_to_decorate) + # convert all outputs to tensor if possible + return _apply_to_outputs(_convert_to_tensor)(func_convert_inputs) + + +def _sync_ddp_if_available(result: Union[torch.Tensor], + group: Optional[Any] = None, + reduce_op: Optional[torch.distributed.ReduceOp] = None, + ) -> torch.Tensor: + """ + Function to reduce the tensors from several ddp processes to one master process + + Args: + result: the value to sync and reduce (typically tensor or number) + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to sum. + + Returns: + reduced value + + """ + + if torch.distributed.is_available() and torch.distributed.is_initialized(): + if group is None: + group = torch.distributed.group.WORLD + + if reduce_op is None: + reduce_op = torch.distributed.ReduceOp.SUM + + # sync all processes before reduction + torch.distributed.barrier(group=group) + torch.distributed.all_reduce(result, op=reduce_op, group=group, + async_op=False) + + return result + + +def numpy_metric(group: Optional[Any] = None, + reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable: + """ + This decorator shall be used on all function metrics working on numpy arrays. + + It handles the argument conversion and DDP reduction for metrics working on numpy. + All inputs of the decorated function will be converted to numpy and all + outputs will be converted to tensors. + In DDP Training all output tensors will be reduced according to the given rules. + + Args: + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to sum + + Returns: + the decorated function + + """ + + def decorator_fn(func_to_decorate): + return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp_if_available, + group=group, + reduce_op=reduce_op)(_numpy_metric_conversion(func_to_decorate)) + + return decorator_fn + + +def tensor_metric(group: Optional[Any] = None, + reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable: + """ + This decorator shall be used on all function metrics working on tensors. + + It handles the argument conversion and DDP reduction for metrics working on tensors. + All inputs and outputs of the decorated function will be converted to tensors. + In DDP Training all output tensors will be reduced according to the given rules. + + Args: + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to sum + + Returns: + the decorated function + + """ + + def decorator_fn(func_to_decorate): + return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp_if_available, + group=group, + reduce_op=reduce_op)(_tensor_metric_conversion(func_to_decorate)) + + return decorator_fn diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py new file mode 100644 index 0000000000000..5247084498559 --- /dev/null +++ b/pytorch_lightning/metrics/metric.py @@ -0,0 +1,103 @@ +from abc import ABC, abstractmethod +from typing import Any, Optional, Union + +import torch +import torch.distributed + +from pytorch_lightning.metrics.converters import tensor_metric, numpy_metric +from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin + +__all__ = ['Metric', 'TensorMetric', 'NumpyMetric'] + + +class Metric(DeviceDtypeModuleMixin, torch.nn.Module, ABC): + """ + Abstract base class for metric implementation. + + Should be used to implement metrics that + 1. Return multiple Outputs + 2. Handle their own DDP sync + """ + def __init__(self, name: str): + """ + Args: + name: the metric's name + + """ + super().__init__() + self.name = name + self._dtype = torch.get_default_dtype() + self._device = torch.device('cpu') + + @abstractmethod + def forward(self, *args, **kwargs) -> torch.Tensor: + """ + Implements the actual metric computation. + + Returns: + metric value + + """ + raise NotImplementedError + + +class TensorMetric(Metric): + """ + Base class for metric implementation operating directly on tensors. + All inputs and outputs will be casted to tensors if necessary. + Already handles DDP sync and input/output conversions. + """ + def __init__(self, name: str, + reduce_group: Optional[Any] = None, + reduce_op: Optional[Any] = None): + """ + + Args: + name: the metric's name + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__(name) + self._orig_call = tensor_metric(group=reduce_group, + reduce_op=reduce_op)(super().__call__) + + def __call__(self, *args, **kwargs) -> torch.Tensor: + def _to_device_dtype(x: torch.Tensor) -> torch.Tensor: + return x.to(device=self.device, dtype=self.dtype, non_blocking=True) + + return apply_to_collection(self._orig_call(*args, **kwargs), torch.Tensor, + _to_device_dtype) + + +class NumpyMetric(Metric): + """ + Base class for metric implementation operating on numpy arrays. + All inputs will be casted to numpy if necessary and all outputs will + be casted to tensors if necessary. + Already handles DDP sync and input/output conversions. + """ + def __init__(self, name: str, + reduce_group: Optional[Any] = None, + reduce_op: Optional[Any] = None): + """ + + Args: + name: the metric's name + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__(name) + self._orig_call = numpy_metric(group=reduce_group, + reduce_op=reduce_op)(super().__call__) + + def __call__(self, *args, **kwargs) -> torch.Tensor: + def _to_device_dtype(x: torch.Tensor) -> torch.Tensor: + return x.to(device=self.device, dtype=self.dtype, non_blocking=True) + + return apply_to_collection(self._orig_call(*args, **kwargs), torch.Tensor, + _to_device_dtype) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py new file mode 100644 index 0000000000000..724715c3d8607 --- /dev/null +++ b/pytorch_lightning/utilities/apply_func.py @@ -0,0 +1,36 @@ +from collections import Mapping, Sequence +from typing import Any, Callable, Union + + +def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any: + """ + Recursively applies a function to all elements of a certain dtype. + + Args: + data: the collection to apply the function to + dtype: the given function will be applied to all elements of this dtype + function: the function to apply + *args: positional arguments (will be forwarded to calls of ``function``) + **kwargs: keyword arguments (will be forwarded to calls of ``function``) + + Returns: + the resulting collection + + """ + elem_type = type(data) + + # Breaking condition + if isinstance(data, dtype): + return function(data, *args, **kwargs) + + # Recursively apply to collection items + elif isinstance(data, Mapping): + return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) + for k, v in data.items()}) + elif isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple + return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data)) + elif isinstance(data, Sequence) and not isinstance(data, str): + return elem_type([apply_to_collection(d, dtype, function, *args, **kwargs) for d in data]) + + # data is neither of dtype, nor a collection + return data diff --git a/pytorch_lightning/core/properties.py b/pytorch_lightning/utilities/device_dtype_mixin.py similarity index 100% rename from pytorch_lightning/core/properties.py rename to pytorch_lightning/utilities/device_dtype_mixin.py diff --git a/tests/metrics/__init__.py b/tests/metrics/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/metrics/test_converters.py b/tests/metrics/test_converters.py new file mode 100644 index 0000000000000..9abc11d4b07ad --- /dev/null +++ b/tests/metrics/test_converters.py @@ -0,0 +1,214 @@ +import numpy as np +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import tests.base.utils as tutils +from pytorch_lightning.metrics.converters import ( + _apply_to_inputs, _apply_to_outputs, _convert_to_tensor, _convert_to_numpy, + _numpy_metric_conversion, _tensor_metric_conversion, _sync_ddp_if_available, tensor_metric, numpy_metric) + + +@pytest.mark.parametrize(['args', 'kwargs'], + [pytest.param([], {}), + pytest.param([1., 2.], {}), + pytest.param([], {'a': 1., 'b': 2.}), + pytest.param([1., 2.], {'a': 1., 'b': 2.})]) +def test_apply_to_inputs(args, kwargs): + def apply_fn(inputs, factor): + if isinstance(inputs, (float, int)): + return inputs * factor + elif isinstance(inputs, dict): + return {k: apply_fn(v, factor) for k, v in inputs.items()} + elif isinstance(inputs, (tuple, list)): + return [apply_fn(x, factor) for x in inputs] + + @_apply_to_inputs(apply_fn, factor=2.) + def test_fn(*func_args, **func_kwargs): + return func_args, func_kwargs + + result_args, result_kwargs = test_fn(*args, **kwargs) + assert isinstance(result_args, (list, tuple)) + assert isinstance(result_kwargs, dict) + assert len(result_args) == len(args) + assert len(result_kwargs) == len(kwargs) + assert all([k in result_kwargs for k in kwargs.keys()]) + for arg, result_arg in zip(args, result_args): + assert arg * 2. == result_arg + + for key in kwargs.keys(): + arg = kwargs[key] + result_arg = result_kwargs[key] + assert arg * 2. == result_arg + + +def test_apply_to_outputs(): + def apply_fn(inputs, additional_str): + return str(inputs) + additional_str + + @_apply_to_outputs(apply_fn, additional_str='_str') + def test_fn(*args, **kwargs): + return 'dummy' + + assert test_fn() == 'dummy_str' + + +def test_convert_to_tensor(): + for test_item in [1., np.array([1.])]: + result_tensor = _convert_to_tensor(test_item) + assert isinstance(result_tensor, torch.Tensor) + assert result_tensor.item() == 1. + + +def test_convert_to_numpy(): + for test_item in [1., torch.tensor([1.])]: + result = _convert_to_numpy(test_item) + assert isinstance(result, np.ndarray) + assert result.item() == 1. + + +def test_numpy_metric_conversion(): + @_numpy_metric_conversion + def numpy_test_metric(*args, **kwargs): + for arg in args: + assert isinstance(arg, np.ndarray) + + for v in kwargs.values(): + assert isinstance(v, np.ndarray) + + return 5. + + result = numpy_test_metric(torch.tensor([1.]), dummy_kwarg=2.) + assert isinstance(result, torch.Tensor) + assert result.item() == 5. + + +def test_tensor_metric_conversion(): + @_tensor_metric_conversion + def tensor_test_metric(*args, **kwargs): + for arg in args: + assert isinstance(arg, torch.Tensor) + + for v in kwargs.values(): + assert isinstance(v, torch.Tensor) + + return 5. + + result = tensor_test_metric(np.array([1.]), dummy_kwarg=2.) + assert isinstance(result, torch.Tensor) + assert result.item() == 5. + + +def setup_ddp(rank, worldsize, ): + import os + + os.environ['MASTER_ADDR'] = 'localhost' + + # initialize the process group + dist.init_process_group("gloo", rank=rank, world_size=worldsize) + + +def ddp_test_fn(rank, worldsize): + setup_ddp(rank, worldsize) + tensor = torch.tensor([1.], device='cuda:0') + + reduced_tensor = _sync_ddp_if_available(tensor) + + assert reduced_tensor.item() == dist.get_world_size(), \ + 'Sync-Reduce does not work properly with DDP and Tensors' + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_sync_reduce_ddp(): + """Make sure sync-reduce works with DDP""" + tutils.reset_seed() + tutils.set_random_master_port() + + worldsize = 2 + mp.spawn(ddp_test_fn, args=(worldsize,), nprocs=worldsize) + + +def test_sync_reduce_simple(): + """Make sure sync-reduce works without DDP""" + tensor = torch.tensor([1.], device='cpu') + + reduced_tensor = _sync_ddp_if_available(tensor) + + assert torch.allclose(tensor, reduced_tensor), \ + 'Sync-Reduce does not work properly without DDP and Tensors' + + +def _test_tensor_metric(is_ddp: bool): + @tensor_metric() + def tensor_test_metric(*args, **kwargs): + for arg in args: + assert isinstance(arg, torch.Tensor) + + for v in kwargs.values(): + assert isinstance(v, torch.Tensor) + + return 5. + + if is_ddp: + factor = dist.get_world_size() + else: + factor = 1. + + result = tensor_test_metric(np.array([1.]), dummy_kwarg=2.) + assert isinstance(result, torch.Tensor) + assert result.item() == 5. * factor + + +def _ddp_test_tensor_metric(rank, worldsize): + setup_ddp(rank, worldsize) + _test_tensor_metric(True) + + +def test_tensor_metric_ddp(): + tutils.reset_seed() + tutils.set_random_master_port() + + world_size = 2 + mp.spawn(_ddp_test_tensor_metric, args=(world_size,), nprocs=world_size) + + +def test_tensor_metric_simple(): + _test_tensor_metric(False) + + +def _test_numpy_metric(is_ddp: bool): + @numpy_metric() + def numpy_test_metric(*args, **kwargs): + for arg in args: + assert isinstance(arg, np.ndarray) + + for v in kwargs.values(): + assert isinstance(v, np.ndarray) + + return 5. + + if is_ddp: + factor = dist.get_world_size() + else: + factor = 1. + + result = numpy_test_metric(torch.tensor([1.]), dummy_kwarg=2.) + assert isinstance(result, torch.Tensor) + assert result.item() == 5. * factor + + +def _ddp_test_numpy_metric(rank, worldsize): + setup_ddp(rank, worldsize) + _test_numpy_metric(True) + + +def test_numpy_metric_ddp(): + tutils.reset_seed() + tutils.set_random_master_port() + world_size = 2 + mp.spawn(_ddp_test_numpy_metric, args=(world_size,), nprocs=world_size) + + +def test_numpy_metric_simple(): + _test_tensor_metric(False) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py new file mode 100644 index 0000000000000..e83a9d97b7a6c --- /dev/null +++ b/tests/metrics/test_metrics.py @@ -0,0 +1,85 @@ +import numpy as np +import torch + +from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric + + +class DummyTensorMetric(TensorMetric): + def __init__(self): + super().__init__('dummy') + + def forward(self, input1, input2): + assert isinstance(input1, torch.Tensor) + assert isinstance(input2, torch.Tensor) + return 1. + + +class DummyNumpyMetric(NumpyMetric): + def __init__(self): + super().__init__('dummy') + + def forward(self, input1, input2): + assert isinstance(input1, np.ndarray) + assert isinstance(input2, np.ndarray) + return 1. + + +def _test_metric(metric: Metric): + input1, input2 = torch.tensor([1.]), torch.tensor([2.]) + + def change_and_check_device_dtype(device, dtype): + metric.to(device=device, dtype=dtype) + + metric_val = metric(input1, input2) + assert isinstance(metric_val, torch.Tensor) + + if device is not None: + assert metric.device in [device, torch.device(device)] + assert metric_val.device in [device, torch.device(device)] + + if dtype is not None: + assert metric.dtype == dtype + assert metric_val.dtype == dtype + + devices = [None, 'cpu'] + if torch.cuda.is_available(): + devices += ['cuda:0'] + + for device in devices: + for dtype in [None, torch.float32, torch.float64]: + change_and_check_device_dtype(device=device, dtype=dtype) + + if torch.cuda.is_available(): + metric.cuda(0) + assert metric.device == torch.device('cuda', index=0) + assert metric(input1, input2).device == torch.device('cuda', index=0) + + metric.cpu() + assert metric.device == torch.device('cpu') + assert metric(input1, input2).device == torch.device('cpu') + + metric.type(torch.int8) + assert metric.dtype == torch.int8 + assert metric(input1, input2).dtype == torch.int8 + + metric.float() + assert metric.dtype == torch.float32 + assert metric(input1, input2).dtype == torch.float32 + + metric.double() + assert metric.dtype == torch.float64 + assert metric(input1, input2).dtype == torch.float64 + + if torch.cuda.is_available(): + metric.cuda() + metric.half() + assert metric.dtype == torch.float16 + assert metric(input1, input2).dtype == torch.float16 + + +def test_tensor_metric(): + _test_metric(DummyTensorMetric()) + + +def test_numpy_metric(): + _test_metric(DummyNumpyMetric()) diff --git a/tests/utilities/__init__.py b/tests/utilities/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py new file mode 100644 index 0000000000000..dce1e56e2b332 --- /dev/null +++ b/tests/utilities/test_apply_func.py @@ -0,0 +1,66 @@ +import numbers +from collections import namedtuple + +import numpy as np +import torch + +from pytorch_lightning.utilities.apply_func import apply_to_collection + + +def test_recursive_application_to_collection(): + ntc = namedtuple('Foo', ['bar']) + + to_reduce = { + 'a': torch.tensor([1.]), # Tensor + 'b': [torch.tensor([2.])], # list + 'c': (torch.tensor([100.]),), # tuple + 'd': ntc(bar=5.), # named tuple + 'e': np.array([10.]), # numpy array + 'f': 'this_is_a_dummy_str', # string + 'g': 12. # number + } + + expected_result = { + 'a': torch.tensor([2.]), + 'b': [torch.tensor([4.])], + 'c': (torch.tensor([200.]),), + 'd': ntc(bar=torch.tensor([10.])), + 'e': np.array([20.]), + 'f': 'this_is_a_dummy_str', + 'g': 24. + } + + reduced = apply_to_collection(to_reduce, (torch.Tensor, numbers.Number, np.ndarray), + lambda x: x * 2) + + assert isinstance(reduced, dict), ' Type Consistency of dict not preserved' + assert all([x in reduced for x in to_reduce.keys()]), 'Not all entries of the dict were preserved' + assert all([isinstance(reduced[k], type(expected_result[k])) for k in to_reduce.keys()]), \ + 'At least one type was not correctly preserved' + + assert isinstance(reduced['a'], torch.Tensor), 'Reduction Result of a Tensor should be a Tensor' + assert torch.allclose(expected_result['a'], reduced['a']), \ + 'Reduction of a tensor does not yield the expected value' + + assert isinstance(reduced['b'], list), 'Reduction Result of a list should be a list' + assert all([torch.allclose(x, y) for x, y in zip(reduced['b'], expected_result['b'])]), \ + 'At least one value of list reduction did not come out as expected' + + assert isinstance(reduced['c'], tuple), 'Reduction Result of a tuple should be a tuple' + assert all([torch.allclose(x, y) for x, y in zip(reduced['c'], expected_result['c'])]), \ + 'At least one value of tuple reduction did not come out as expected' + + assert isinstance(reduced['d'], ntc), 'Type Consistency for named tuple not given' + assert isinstance(reduced['d'].bar, numbers.Number), \ + 'Failure in type promotion while reducing fields of named tuples' + assert reduced['d'].bar == expected_result['d'].bar + + assert isinstance(reduced['e'], np.ndarray), 'Type Promotion in reduction of numpy arrays failed' + assert reduced['e'] == expected_result['e'], \ + 'Reduction of numpy array did not yield the expected result' + + assert isinstance(reduced['f'], str), 'A string should not be reduced' + assert reduced['f'] == expected_result['f'], 'String not preserved during reduction' + + assert isinstance(reduced['g'], numbers.Number), 'Reduction of a number should result in a tensor' + assert reduced['g'] == expected_result['g'], 'Reduction of a number did not yield the desired result'