From 4c6c4e02d49c02891a310643f0e1506c5b44bfc0 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Fri, 3 Apr 2020 21:10:40 +0200 Subject: [PATCH 01/15] New metric classes (#1326) * Create metrics package * Create metric.py * Create utils.py * Create __init__.py * add tests for metric utils * add docstrings for metrics utils * add function to recursively apply other function to collection * add tests for this function * update test * Update pytorch_lightning/metrics/metric.py Co-Authored-By: Jirka Borovec * update metric name * remove example docs * fix tests * add metric tests * fix to tensor conversion * fix apply to collection * Update CHANGELOG.md * Update pytorch_lightning/metrics/metric.py Co-Authored-By: Jirka Borovec * remove tests from init * add missing type annotations * rename utils to convertors * Create metrics.rst * Update index.rst * Update index.rst * Update pytorch_lightning/metrics/convertors.py Co-Authored-By: Jirka Borovec * Update pytorch_lightning/metrics/convertors.py Co-Authored-By: Jirka Borovec * Update pytorch_lightning/metrics/convertors.py Co-Authored-By: Jirka Borovec * Update pytorch_lightning/metrics/metric.py Co-Authored-By: Jirka Borovec * Update tests/utilities/test_apply_to_collection.py Co-Authored-By: Jirka Borovec * Update tests/utilities/test_apply_to_collection.py Co-Authored-By: Jirka Borovec * Update tests/metrics/convertors.py Co-Authored-By: Jirka Borovec * Apply suggestions from code review Co-Authored-By: Jirka Borovec * add doctest example * rename file and fix imports * added parametrized test * replace lambda with inlined function * rename apply_to_collection to apply_func * Separated class description from init args * Apply suggestions from code review Co-Authored-By: Jirka Borovec * adjust random values * suppress output when seeding * remove gpu from doctest * Add requested changes and add ellipsis for doctest * forgot to push these files... * add explicit check for dtype to convert to * fix ddp tests * remove explicit ddp destruction Co-authored-by: Jirka Borovec --- CHANGELOG.md | 3 + docs/source/index.rst | 4 +- docs/source/metrics.rst | 4 + pytorch_lightning/metrics/__init__.py | 5 + pytorch_lightning/metrics/converters.py | 223 +++++++++++++++++++ pytorch_lightning/metrics/metric.py | 260 ++++++++++++++++++++++ pytorch_lightning/utilities/apply_func.py | 36 +++ tests/metrics/__init__.py | 0 tests/metrics/test_converters.py | 214 ++++++++++++++++++ tests/metrics/test_metrics.py | 85 +++++++ tests/utilities/__init__.py | 0 tests/utilities/test_apply_func.py | 66 ++++++ 12 files changed, 899 insertions(+), 1 deletion(-) create mode 100644 docs/source/metrics.rst create mode 100644 pytorch_lightning/metrics/__init__.py create mode 100644 pytorch_lightning/metrics/converters.py create mode 100644 pytorch_lightning/metrics/metric.py create mode 100644 pytorch_lightning/utilities/apply_func.py create mode 100644 tests/metrics/__init__.py create mode 100644 tests/metrics/test_converters.py create mode 100644 tests/metrics/test_metrics.py create mode 100644 tests/utilities/__init__.py create mode 100644 tests/utilities/test_apply_func.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d7af014edecb..9a1802fa67809 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,9 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
+## Metrics (will be added to unreleased once the metric branch was finished) +- Add Metric Base Classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326)) + ## [unreleased] - YYYY-MM-DD ### Added diff --git a/docs/source/index.rst b/docs/source/index.rst index 1e11f7a0e9487..68b9b2abcb263 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -23,6 +23,7 @@ PyTorch Lightning Documentation hooks lightning-module loggers + metrics trainer .. toctree:: @@ -105,7 +106,8 @@ Indices and tables pytorch_lightning.core pytorch_lightning.callbacks pytorch_lightning.loggers + pytorch_lightning.metrics pytorch_lightning.overrides pytorch_lightning.profiler pytorch_lightning.trainer - pytorch_lightning.utilities \ No newline at end of file + pytorch_lightning.utilities diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst new file mode 100644 index 0000000000000..6f70a3c73f2d0 --- /dev/null +++ b/docs/source/metrics.rst @@ -0,0 +1,4 @@ +.. automodule:: pytorch_lightning.metrics + :members: + :noindex: + :exclude-members: diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py new file mode 100644 index 0000000000000..18522e0dda94b --- /dev/null +++ b/pytorch_lightning/metrics/__init__.py @@ -0,0 +1,5 @@ +""" +Metrics +======= +TODO +""" diff --git a/pytorch_lightning/metrics/converters.py b/pytorch_lightning/metrics/converters.py new file mode 100644 index 0000000000000..8162876fc3b00 --- /dev/null +++ b/pytorch_lightning/metrics/converters.py @@ -0,0 +1,223 @@ +""" +This file provides functions and decorators for automated input and output +conversion to/from numpy.ndarray and torch.Tensor as well as utilities to +sync tensors between different processes in a DDP scenario, when needed. +""" + +import numbers +from typing import Union, Any, Callable + +import numpy as np +import torch +from torch.utils.data._utils.collate import np_str_obj_array_pattern + +from pytorch_lightning.utilities.apply_func import apply_to_collection + + +def _apply_to_inputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable: + """ + Decorator function to apply a function to all inputs of a function. + Args: + func_to_apply: the function to apply to the inputs + *dec_args: positional arguments for the function to be applied + **dec_kwargs: keyword arguments for the function to be applied + + Returns: + the decorated function + """ + + def decorator_fn(func_to_decorate): + # actual function applying the give function to inputs + def new_func(*args, **kwargs): + args = func_to_apply(args, *dec_args, **dec_kwargs) + kwargs = func_to_apply(kwargs, *dec_args, **dec_kwargs) + return func_to_decorate(*args, **kwargs) + + return new_func + + return decorator_fn + + +def _apply_to_outputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable: + """ + Decorator function to apply a function to all outputs of a function. + Args: + func_to_apply: the function to apply to the outputs + *dec_args: positional arguments for the function to be applied + **dec_kwargs: keyword arguments for the function to be applied + + Returns: + the decorated function + """ + + def decorator_fn(function_to_decorate): + # actual function applying the give function to outputs + def new_func(*args, **kwargs): + result = function_to_decorate(*args, **kwargs) + return func_to_apply(result, *dec_args, **dec_kwargs) + + return new_func + + return decorator_fn + + +def _convert_to_tensor(data: Any) -> Any: + """ + Maps all kind of collections and numbers to tensors. 
+ + Args: + data: the data to convert to tensor + + Returns: + the converted data + + """ + if isinstance(data, numbers.Number): + return torch.tensor([data]) + # is not array of object + elif isinstance(data, np.ndarray) and np_str_obj_array_pattern.search(data.dtype.str) is None: + return torch.from_numpy(data) + elif isinstance(data, torch.Tensor): + return data + + raise TypeError("The given type ('%s') cannot be converted to a tensor!" % type(data).__name__) + + +def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray: + """Convert all tensors and numpy arrays to numpy arrays. + Args: + data: the tensor or array to convert to numpy + + Returns: + the resulting numpy array + + """ + if isinstance(data, torch.Tensor): + return data.cpu().detach().numpy() + elif isinstance(data, numbers.Number): + return np.array([data]) + elif isinstance(data, np.ndarray): + return data + + raise TypeError("The given type ('%s') cannot be converted to a numpy array!" % type(data).__name__) + + +def _numpy_metric_conversion(func_to_decorate: Callable) -> Callable: + """ + Decorator Handling the argument conversion for metrics working on numpy. + All inputs of the decorated function will be converted to numpy and all + outputs will be converted to Tensors + + Args: + func_to_decorate: the function whose inputs and outputs shall be converted + + Returns: + the decorated function + + """ + # applies collection conversion from tensor to numpy to all inputs + # we need to include numpy arrays here, since otherwise they will also be treated as sequences + func_convert_inputs = _apply_to_inputs( + apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)(func_to_decorate) + # converts all inputs back to tensors (device doesn't matter here, since this is handled by BaseMetric) + func_convert_in_out = _apply_to_outputs(_convert_to_tensor)(func_convert_inputs) + return func_convert_in_out + + +def _tensor_metric_conversion(func_to_decorate: Callable) -> Callable: + """ + Decorator Handling the argument conversion for metrics working on tensors. + All inputs and outputs of the decorated function will be converted to tensors + + Args: + func_to_decorate: the function whose inputs and outputs shall be converted + + Returns: + the decorated function + + """ + # converts all inputs to tensor if possible + # we need to include tensors here, since otherwise they will also be treated as sequences + func_convert_inputs = _apply_to_inputs( + apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_tensor)(func_to_decorate) + # convert all outputs to tensor if possible + return _apply_to_outputs(_convert_to_tensor)(func_convert_inputs) + + +def _sync_ddp_if_available(result: Union[torch.Tensor], + group: Any = torch.distributed.group.WORLD, + reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM, + ) -> torch.Tensor: + """ + Function to reduce the tensors from several ddp processes to one master process + + Args: + result: the value to sync and reduce (typically tensor or number) + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. 
Defaults to sum + + Returns: + reduced value + + """ + + if torch.distributed.is_available() and torch.distributed.is_initialized(): + # sync all processes before reduction + torch.distributed.barrier(group=group) + torch.distributed.all_reduce(result, op=reduce_op, group=group, + async_op=False) + + return result + + +def numpy_metric(group: Any = torch.distributed.group.WORLD, + reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM) -> Callable: + """ + This decorator shall be used on all function metrics working on numpy arrays. + + It handles the argument conversion and DDP reduction for metrics working on numpy. + All inputs of the decorated function will be converted to numpy and all + outputs will be converted to Tensors. + In DDP Training all output tensors will be reduced according to the given rules. + + Args: + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to sum + + Returns: + the decorated function + + """ + + def decorator_fn(func_to_decorate): + return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp_if_available, + group=group, + reduce_op=reduce_op)(_numpy_metric_conversion(func_to_decorate)) + + return decorator_fn + + +def tensor_metric(group: Any = torch.distributed.group.WORLD, + reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM) -> Callable: + """ + This decorator shall be used on all function metrics working on tensors. + + It handles the argument conversion and DDP reduction for metrics working on tensors. + All inputs and outputs of the decorated function will be converted to tensors . + In DDP Training all output tensors will be reduced according to the given rules. + + Args: + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to sum + + Returns: + the decorated function + + """ + + def decorator_fn(func_to_decorate): + return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp_if_available, + group=group, + reduce_op=reduce_op)(_tensor_metric_conversion(func_to_decorate)) + + return decorator_fn diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py new file mode 100644 index 0000000000000..50853105f94f9 --- /dev/null +++ b/pytorch_lightning/metrics/metric.py @@ -0,0 +1,260 @@ +from abc import ABC, abstractmethod +from typing import Any, Optional, Union + +import torch +import torch.distributed + +from pytorch_lightning.metrics.converters import tensor_metric, numpy_metric +from pytorch_lightning.utilities.apply_func import apply_to_collection + +__all__ = ['Metric', 'TensorMetric', 'NumpyMetric'] + + +class Metric(torch.nn.Module, ABC): + """ + Abstract Base Class for metric implementation. + + Should be used to implement metrics that + 1. Return multiple Outputs + 2. Handle their own DDP sync + """ + def __init__(self, name: str): + """ + Args: + name: the metric's name + + """ + super().__init__() + self.name = name + self._dtype = torch.get_default_dtype() + self._device = torch.device('cpu') + + @property + def dtype(self) -> Union[str, torch.dtype]: + return self._dtype + + @dtype.setter + def dtype(self, new_dtype: Union[str, torch.dtype]): + # necessary to avoid infinite recursion + raise RuntimeError('Cannot set the dtype explicitly. 
Please use metric.to(new_dtype).') + + @property + def device(self) -> Union[str, torch.device]: + return self._device + + @device.setter + def device(self, new_device: Union[str, torch.device]): + # Necessary to avoid infinite recursion + raise RuntimeError('Cannot set the device explicitly. Please use metric.to(new_device).') + + @abstractmethod + def forward(self, *args, **kwargs) -> torch.Tensor: + """ + Implements the actual metric computation. + + Returns: + metric value + + """ + raise NotImplementedError + + def to(self, *args, **kwargs) -> torch.nn.Module: + """Moves and/or casts the parameters and buffers. + + This can be called as + + .. function:: to(device=None, dtype=None, non_blocking=False) + + .. function:: to(dtype, non_blocking=False) + + .. function:: to(tensor, non_blocking=False) + + Its signature is similar to :meth:`torch.Tensor.to`, but only accepts + floating point desired :attr:`dtype` s. In addition, this method will + only cast the floating point parameters and buffers to :attr:`dtype` + (if given). The integral parameters and buffers will be moved + :attr:`device`, if that is given, but with dtypes unchanged. When + :attr:`non_blocking` is set, it tries to convert/move asynchronously + with respect to the host if possible, e.g., moving CPU Tensors with + pinned memory to CUDA devices. + + See below for examples. + + Note: + This method modifies the module in-place. + + Args: + device: the desired device of the parameters + and buffers in this module + dtype: the desired floating point type of + the floating point parameters and buffers in this module + tensor: Tensor whose dtype and device are the desired + dtype and device for all parameters and buffers in this module + + Returns: + Module: self + + Example:: + >>> class ExampleMetric(Metric): + ... def __init__(self, weight: torch.Tensor): + ... super().__init__('example') + ... self.register_buffer('weight', weight) + ... def forward(self, pred, target) -> torch.Tensor: + ... return (pred - target) * self.weight + >>> _ = torch.manual_seed(0) + >>> metric = ExampleMetric(torch.rand(3, 4)) + >>> metric.weight + tensor([[0.4963, 0.7682, 0.0885, 0.1320], + [0.3074, 0.6341, 0.4901, 0.8964], + [0.4556, 0.6323, 0.3489, 0.4017]]) + >>> metric.to(torch.double) #doctest: +ELLIPSIS + ExampleMetric() + >>> metric.weight + tensor([[...]], dtype=torch.float64) + >>> cpu = torch.device('cpu') + >>> metric.to(cpu, dtype=torch.half, non_blocking=True) + ExampleMetric() + >>> metric.weight #doctest: +ELLIPSIS + tensor([[...]], dtype=torch.float16) + >>> metric.to(cpu) + ExampleMetric() + >>> metric.weight #doctest: +ELLIPSIS + tensor([[...]], dtype=torch.float16) + + + """ + device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs) + if device is not None: + self._device = device + + if dtype is not None: + self._dtype = dtype + + return super().to(*args, **kwargs) + + def cuda(self, device: Optional[int] = None) -> torch.nn.Module: + """Moves all model parameters and buffers to the GPU. + + This also makes associated parameters and buffers different objects. So + it should be called before constructing optimizer if the module will + live on GPU while being optimized. + + Arguments: + device (int, optional): if specified, all parameters will be + copied to that device + + Returns: + Module: + """ + + self._device = torch.device('cuda', index=device) + return super().cuda(device=device) + + def cpu(self) -> torch.nn.Module: + """Moves all model parameters and buffers to the CPU. 
+ + Returns: + Module: self + """ + self._device = torch.device('cpu') + return super().cpu() + + def type(self, dst_type: Union[str, torch.dtype]) -> torch.nn.Module: + """Casts all parameters and buffers to :attr:`dst_type`. + + Arguments: + dst_type (type or string): the desired type + + Returns: + Module: self + """ + self._dtype = dst_type + return super().type(dst_type=dst_type) + + def float(self) -> torch.nn.Module: + """Casts all floating point parameters and buffers to float datatype. + + Returns: + Module: self + """ + self._dtype = torch.float + return super().float() + + def double(self) -> torch.nn.Module: + """Casts all floating point parameters and buffers to ``double`` datatype. + + Returns: + Module: self + """ + self._dtype = torch.double + return super().double() + + def half(self) -> torch.nn.Module: + """Casts all floating point parameters and buffers to ``half`` datatype. + + Returns: + Module: self + """ + self._dtype = torch.half + return super().half() + + +class TensorMetric(Metric): + """ + Base class for metric implementation operating directly on tensors. + All inputs and outputs will be casted to tensors if necessary. + Already handles DDP sync and input/output conversions. + """ + def __init__(self, name: str, + reduce_group: Optional[Any] = torch.distributed.group.WORLD, + reduce_op: Optional[Any] = torch.distributed.ReduceOp.SUM): + """ + + Args: + name: the metric's name + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__(name) + self._orig_call = tensor_metric(group=reduce_group, + reduce_op=reduce_op)(super().__call__) + + def __call__(self, *args, **kwargs) -> torch.Tensor: + def _to_device_dtype(x: torch.Tensor) -> torch.Tensor: + return x.to(device=self.device, dtype=self.dtype) + + return apply_to_collection(self._orig_call(*args, **kwargs), torch.Tensor, + _to_device_dtype) + + +class NumpyMetric(Metric): + """ + Base class for metric implementation operating on numpy arrays. + All inputs will be casted to numpy if necessary and all outputs will + be casted to tensors if necessary. + Already handles DDP sync and input/output conversions. + """ + def __init__(self, name: str, + reduce_group: Optional[Any] = torch.distributed.group.WORLD, + reduce_op: Optional[Any] = torch.distributed.ReduceOp.SUM): + """ + + Args: + name: the metric's name + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. 
+ """ + super().__init__(name) + self._orig_call = numpy_metric(group=reduce_group, + reduce_op=reduce_op)(super().__call__) + + def __call__(self, *args, **kwargs) -> torch.Tensor: + def _to_device_dtype(x: torch.Tensor) -> torch.Tensor: + return x.to(device=self.device, dtype=self.dtype) + + return apply_to_collection(self._orig_call(*args, **kwargs), torch.Tensor, + _to_device_dtype) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py new file mode 100644 index 0000000000000..724715c3d8607 --- /dev/null +++ b/pytorch_lightning/utilities/apply_func.py @@ -0,0 +1,36 @@ +from collections import Mapping, Sequence +from typing import Any, Callable, Union + + +def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any: + """ + Recursively applies a function to all elements of a certain dtype. + + Args: + data: the collection to apply the function to + dtype: the given function will be applied to all elements of this dtype + function: the function to apply + *args: positional arguments (will be forwarded to calls of ``function``) + **kwargs: keyword arguments (will be forwarded to calls of ``function``) + + Returns: + the resulting collection + + """ + elem_type = type(data) + + # Breaking condition + if isinstance(data, dtype): + return function(data, *args, **kwargs) + + # Recursively apply to collection items + elif isinstance(data, Mapping): + return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) + for k, v in data.items()}) + elif isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple + return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data)) + elif isinstance(data, Sequence) and not isinstance(data, str): + return elem_type([apply_to_collection(d, dtype, function, *args, **kwargs) for d in data]) + + # data is neither of dtype, nor a collection + return data diff --git a/tests/metrics/__init__.py b/tests/metrics/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/metrics/test_converters.py b/tests/metrics/test_converters.py new file mode 100644 index 0000000000000..9abc11d4b07ad --- /dev/null +++ b/tests/metrics/test_converters.py @@ -0,0 +1,214 @@ +import numpy as np +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import tests.base.utils as tutils +from pytorch_lightning.metrics.converters import ( + _apply_to_inputs, _apply_to_outputs, _convert_to_tensor, _convert_to_numpy, + _numpy_metric_conversion, _tensor_metric_conversion, _sync_ddp_if_available, tensor_metric, numpy_metric) + + +@pytest.mark.parametrize(['args', 'kwargs'], + [pytest.param([], {}), + pytest.param([1., 2.], {}), + pytest.param([], {'a': 1., 'b': 2.}), + pytest.param([1., 2.], {'a': 1., 'b': 2.})]) +def test_apply_to_inputs(args, kwargs): + def apply_fn(inputs, factor): + if isinstance(inputs, (float, int)): + return inputs * factor + elif isinstance(inputs, dict): + return {k: apply_fn(v, factor) for k, v in inputs.items()} + elif isinstance(inputs, (tuple, list)): + return [apply_fn(x, factor) for x in inputs] + + @_apply_to_inputs(apply_fn, factor=2.) 
+ def test_fn(*func_args, **func_kwargs): + return func_args, func_kwargs + + result_args, result_kwargs = test_fn(*args, **kwargs) + assert isinstance(result_args, (list, tuple)) + assert isinstance(result_kwargs, dict) + assert len(result_args) == len(args) + assert len(result_kwargs) == len(kwargs) + assert all([k in result_kwargs for k in kwargs.keys()]) + for arg, result_arg in zip(args, result_args): + assert arg * 2. == result_arg + + for key in kwargs.keys(): + arg = kwargs[key] + result_arg = result_kwargs[key] + assert arg * 2. == result_arg + + +def test_apply_to_outputs(): + def apply_fn(inputs, additional_str): + return str(inputs) + additional_str + + @_apply_to_outputs(apply_fn, additional_str='_str') + def test_fn(*args, **kwargs): + return 'dummy' + + assert test_fn() == 'dummy_str' + + +def test_convert_to_tensor(): + for test_item in [1., np.array([1.])]: + result_tensor = _convert_to_tensor(test_item) + assert isinstance(result_tensor, torch.Tensor) + assert result_tensor.item() == 1. + + +def test_convert_to_numpy(): + for test_item in [1., torch.tensor([1.])]: + result = _convert_to_numpy(test_item) + assert isinstance(result, np.ndarray) + assert result.item() == 1. + + +def test_numpy_metric_conversion(): + @_numpy_metric_conversion + def numpy_test_metric(*args, **kwargs): + for arg in args: + assert isinstance(arg, np.ndarray) + + for v in kwargs.values(): + assert isinstance(v, np.ndarray) + + return 5. + + result = numpy_test_metric(torch.tensor([1.]), dummy_kwarg=2.) + assert isinstance(result, torch.Tensor) + assert result.item() == 5. + + +def test_tensor_metric_conversion(): + @_tensor_metric_conversion + def tensor_test_metric(*args, **kwargs): + for arg in args: + assert isinstance(arg, torch.Tensor) + + for v in kwargs.values(): + assert isinstance(v, torch.Tensor) + + return 5. + + result = tensor_test_metric(np.array([1.]), dummy_kwarg=2.) + assert isinstance(result, torch.Tensor) + assert result.item() == 5. + + +def setup_ddp(rank, worldsize, ): + import os + + os.environ['MASTER_ADDR'] = 'localhost' + + # initialize the process group + dist.init_process_group("gloo", rank=rank, world_size=worldsize) + + +def ddp_test_fn(rank, worldsize): + setup_ddp(rank, worldsize) + tensor = torch.tensor([1.], device='cuda:0') + + reduced_tensor = _sync_ddp_if_available(tensor) + + assert reduced_tensor.item() == dist.get_world_size(), \ + 'Sync-Reduce does not work properly with DDP and Tensors' + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_sync_reduce_ddp(): + """Make sure sync-reduce works with DDP""" + tutils.reset_seed() + tutils.set_random_master_port() + + worldsize = 2 + mp.spawn(ddp_test_fn, args=(worldsize,), nprocs=worldsize) + + +def test_sync_reduce_simple(): + """Make sure sync-reduce works without DDP""" + tensor = torch.tensor([1.], device='cpu') + + reduced_tensor = _sync_ddp_if_available(tensor) + + assert torch.allclose(tensor, reduced_tensor), \ + 'Sync-Reduce does not work properly without DDP and Tensors' + + +def _test_tensor_metric(is_ddp: bool): + @tensor_metric() + def tensor_test_metric(*args, **kwargs): + for arg in args: + assert isinstance(arg, torch.Tensor) + + for v in kwargs.values(): + assert isinstance(v, torch.Tensor) + + return 5. + + if is_ddp: + factor = dist.get_world_size() + else: + factor = 1. + + result = tensor_test_metric(np.array([1.]), dummy_kwarg=2.) + assert isinstance(result, torch.Tensor) + assert result.item() == 5. 
* factor + + +def _ddp_test_tensor_metric(rank, worldsize): + setup_ddp(rank, worldsize) + _test_tensor_metric(True) + + +def test_tensor_metric_ddp(): + tutils.reset_seed() + tutils.set_random_master_port() + + world_size = 2 + mp.spawn(_ddp_test_tensor_metric, args=(world_size,), nprocs=world_size) + + +def test_tensor_metric_simple(): + _test_tensor_metric(False) + + +def _test_numpy_metric(is_ddp: bool): + @numpy_metric() + def numpy_test_metric(*args, **kwargs): + for arg in args: + assert isinstance(arg, np.ndarray) + + for v in kwargs.values(): + assert isinstance(v, np.ndarray) + + return 5. + + if is_ddp: + factor = dist.get_world_size() + else: + factor = 1. + + result = numpy_test_metric(torch.tensor([1.]), dummy_kwarg=2.) + assert isinstance(result, torch.Tensor) + assert result.item() == 5. * factor + + +def _ddp_test_numpy_metric(rank, worldsize): + setup_ddp(rank, worldsize) + _test_numpy_metric(True) + + +def test_numpy_metric_ddp(): + tutils.reset_seed() + tutils.set_random_master_port() + world_size = 2 + mp.spawn(_ddp_test_numpy_metric, args=(world_size,), nprocs=world_size) + + +def test_numpy_metric_simple(): + _test_tensor_metric(False) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py new file mode 100644 index 0000000000000..e83a9d97b7a6c --- /dev/null +++ b/tests/metrics/test_metrics.py @@ -0,0 +1,85 @@ +import numpy as np +import torch + +from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric + + +class DummyTensorMetric(TensorMetric): + def __init__(self): + super().__init__('dummy') + + def forward(self, input1, input2): + assert isinstance(input1, torch.Tensor) + assert isinstance(input2, torch.Tensor) + return 1. + + +class DummyNumpyMetric(NumpyMetric): + def __init__(self): + super().__init__('dummy') + + def forward(self, input1, input2): + assert isinstance(input1, np.ndarray) + assert isinstance(input2, np.ndarray) + return 1. 
+ + +def _test_metric(metric: Metric): + input1, input2 = torch.tensor([1.]), torch.tensor([2.]) + + def change_and_check_device_dtype(device, dtype): + metric.to(device=device, dtype=dtype) + + metric_val = metric(input1, input2) + assert isinstance(metric_val, torch.Tensor) + + if device is not None: + assert metric.device in [device, torch.device(device)] + assert metric_val.device in [device, torch.device(device)] + + if dtype is not None: + assert metric.dtype == dtype + assert metric_val.dtype == dtype + + devices = [None, 'cpu'] + if torch.cuda.is_available(): + devices += ['cuda:0'] + + for device in devices: + for dtype in [None, torch.float32, torch.float64]: + change_and_check_device_dtype(device=device, dtype=dtype) + + if torch.cuda.is_available(): + metric.cuda(0) + assert metric.device == torch.device('cuda', index=0) + assert metric(input1, input2).device == torch.device('cuda', index=0) + + metric.cpu() + assert metric.device == torch.device('cpu') + assert metric(input1, input2).device == torch.device('cpu') + + metric.type(torch.int8) + assert metric.dtype == torch.int8 + assert metric(input1, input2).dtype == torch.int8 + + metric.float() + assert metric.dtype == torch.float32 + assert metric(input1, input2).dtype == torch.float32 + + metric.double() + assert metric.dtype == torch.float64 + assert metric(input1, input2).dtype == torch.float64 + + if torch.cuda.is_available(): + metric.cuda() + metric.half() + assert metric.dtype == torch.float16 + assert metric(input1, input2).dtype == torch.float16 + + +def test_tensor_metric(): + _test_metric(DummyTensorMetric()) + + +def test_numpy_metric(): + _test_metric(DummyNumpyMetric()) diff --git a/tests/utilities/__init__.py b/tests/utilities/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py new file mode 100644 index 0000000000000..dce1e56e2b332 --- /dev/null +++ b/tests/utilities/test_apply_func.py @@ -0,0 +1,66 @@ +import numbers +from collections import namedtuple + +import numpy as np +import torch + +from pytorch_lightning.utilities.apply_func import apply_to_collection + + +def test_recursive_application_to_collection(): + ntc = namedtuple('Foo', ['bar']) + + to_reduce = { + 'a': torch.tensor([1.]), # Tensor + 'b': [torch.tensor([2.])], # list + 'c': (torch.tensor([100.]),), # tuple + 'd': ntc(bar=5.), # named tuple + 'e': np.array([10.]), # numpy array + 'f': 'this_is_a_dummy_str', # string + 'g': 12. # number + } + + expected_result = { + 'a': torch.tensor([2.]), + 'b': [torch.tensor([4.])], + 'c': (torch.tensor([200.]),), + 'd': ntc(bar=torch.tensor([10.])), + 'e': np.array([20.]), + 'f': 'this_is_a_dummy_str', + 'g': 24. 
+ } + + reduced = apply_to_collection(to_reduce, (torch.Tensor, numbers.Number, np.ndarray), + lambda x: x * 2) + + assert isinstance(reduced, dict), ' Type Consistency of dict not preserved' + assert all([x in reduced for x in to_reduce.keys()]), 'Not all entries of the dict were preserved' + assert all([isinstance(reduced[k], type(expected_result[k])) for k in to_reduce.keys()]), \ + 'At least one type was not correctly preserved' + + assert isinstance(reduced['a'], torch.Tensor), 'Reduction Result of a Tensor should be a Tensor' + assert torch.allclose(expected_result['a'], reduced['a']), \ + 'Reduction of a tensor does not yield the expected value' + + assert isinstance(reduced['b'], list), 'Reduction Result of a list should be a list' + assert all([torch.allclose(x, y) for x, y in zip(reduced['b'], expected_result['b'])]), \ + 'At least one value of list reduction did not come out as expected' + + assert isinstance(reduced['c'], tuple), 'Reduction Result of a tuple should be a tuple' + assert all([torch.allclose(x, y) for x, y in zip(reduced['c'], expected_result['c'])]), \ + 'At least one value of tuple reduction did not come out as expected' + + assert isinstance(reduced['d'], ntc), 'Type Consistency for named tuple not given' + assert isinstance(reduced['d'].bar, numbers.Number), \ + 'Failure in type promotion while reducing fields of named tuples' + assert reduced['d'].bar == expected_result['d'].bar + + assert isinstance(reduced['e'], np.ndarray), 'Type Promotion in reduction of numpy arrays failed' + assert reduced['e'] == expected_result['e'], \ + 'Reduction of numpy array did not yield the expected result' + + assert isinstance(reduced['f'], str), 'A string should not be reduced' + assert reduced['f'] == expected_result['f'], 'String not preserved during reduction' + + assert isinstance(reduced['g'], numbers.Number), 'Reduction of a number should result in a tensor' + assert reduced['g'] == expected_result['g'], 'Reduction of a number did not yield the desired result' From 24b1e218e94abc77f3c1e80f9c7ae3d672867017 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Mon, 18 May 2020 21:50:21 +0200 Subject: [PATCH 02/15] move dtype device mixin to more general place --- pytorch_lightning/core/lightning.py | 2 +- .../{core/properties.py => utilities/device_dtype_mixin.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename pytorch_lightning/{core/properties.py => utilities/device_dtype_mixin.py} (100%) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 784ae9c3a45fa..662c05a29338d 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -18,7 +18,7 @@ from pytorch_lightning.core.hooks import ModelHooks from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv, load_hparams_from_yaml, update_hparams -from pytorch_lightning.core.properties import DeviceDtypeModuleMixin +from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/core/properties.py b/pytorch_lightning/utilities/device_dtype_mixin.py similarity index 100% rename from pytorch_lightning/core/properties.py rename to pytorch_lightning/utilities/device_dtype_mixin.py From 
3598db9558a1ad9e3b401f7f036f33073bb2eb32 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Mon, 18 May 2020 21:54:13 +0200 Subject: [PATCH 03/15] refactor to general device dtype mixin --- pytorch_lightning/metrics/metric.py | 161 +--------------------------- 1 file changed, 2 insertions(+), 159 deletions(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 50853105f94f9..799a60ca215d4 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -6,11 +6,12 @@ from pytorch_lightning.metrics.converters import tensor_metric, numpy_metric from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin __all__ = ['Metric', 'TensorMetric', 'NumpyMetric'] -class Metric(torch.nn.Module, ABC): +class Metric(DeviceDtypeModuleMixin, torch.nn.Module, ABC): """ Abstract Base Class for metric implementation. @@ -29,24 +30,6 @@ def __init__(self, name: str): self._dtype = torch.get_default_dtype() self._device = torch.device('cpu') - @property - def dtype(self) -> Union[str, torch.dtype]: - return self._dtype - - @dtype.setter - def dtype(self, new_dtype: Union[str, torch.dtype]): - # necessary to avoid infinite recursion - raise RuntimeError('Cannot set the dtype explicitly. Please use metric.to(new_dtype).') - - @property - def device(self) -> Union[str, torch.device]: - return self._device - - @device.setter - def device(self, new_device: Union[str, torch.device]): - # Necessary to avoid infinite recursion - raise RuntimeError('Cannot set the device explicitly. Please use metric.to(new_device).') - @abstractmethod def forward(self, *args, **kwargs) -> torch.Tensor: """ @@ -58,146 +41,6 @@ def forward(self, *args, **kwargs) -> torch.Tensor: """ raise NotImplementedError - def to(self, *args, **kwargs) -> torch.nn.Module: - """Moves and/or casts the parameters and buffers. - - This can be called as - - .. function:: to(device=None, dtype=None, non_blocking=False) - - .. function:: to(dtype, non_blocking=False) - - .. function:: to(tensor, non_blocking=False) - - Its signature is similar to :meth:`torch.Tensor.to`, but only accepts - floating point desired :attr:`dtype` s. In addition, this method will - only cast the floating point parameters and buffers to :attr:`dtype` - (if given). The integral parameters and buffers will be moved - :attr:`device`, if that is given, but with dtypes unchanged. When - :attr:`non_blocking` is set, it tries to convert/move asynchronously - with respect to the host if possible, e.g., moving CPU Tensors with - pinned memory to CUDA devices. - - See below for examples. - - Note: - This method modifies the module in-place. - - Args: - device: the desired device of the parameters - and buffers in this module - dtype: the desired floating point type of - the floating point parameters and buffers in this module - tensor: Tensor whose dtype and device are the desired - dtype and device for all parameters and buffers in this module - - Returns: - Module: self - - Example:: - >>> class ExampleMetric(Metric): - ... def __init__(self, weight: torch.Tensor): - ... super().__init__('example') - ... self.register_buffer('weight', weight) - ... def forward(self, pred, target) -> torch.Tensor: - ... 
return (pred - target) * self.weight - >>> _ = torch.manual_seed(0) - >>> metric = ExampleMetric(torch.rand(3, 4)) - >>> metric.weight - tensor([[0.4963, 0.7682, 0.0885, 0.1320], - [0.3074, 0.6341, 0.4901, 0.8964], - [0.4556, 0.6323, 0.3489, 0.4017]]) - >>> metric.to(torch.double) #doctest: +ELLIPSIS - ExampleMetric() - >>> metric.weight - tensor([[...]], dtype=torch.float64) - >>> cpu = torch.device('cpu') - >>> metric.to(cpu, dtype=torch.half, non_blocking=True) - ExampleMetric() - >>> metric.weight #doctest: +ELLIPSIS - tensor([[...]], dtype=torch.float16) - >>> metric.to(cpu) - ExampleMetric() - >>> metric.weight #doctest: +ELLIPSIS - tensor([[...]], dtype=torch.float16) - - - """ - device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs) - if device is not None: - self._device = device - - if dtype is not None: - self._dtype = dtype - - return super().to(*args, **kwargs) - - def cuda(self, device: Optional[int] = None) -> torch.nn.Module: - """Moves all model parameters and buffers to the GPU. - - This also makes associated parameters and buffers different objects. So - it should be called before constructing optimizer if the module will - live on GPU while being optimized. - - Arguments: - device (int, optional): if specified, all parameters will be - copied to that device - - Returns: - Module: - """ - - self._device = torch.device('cuda', index=device) - return super().cuda(device=device) - - def cpu(self) -> torch.nn.Module: - """Moves all model parameters and buffers to the CPU. - - Returns: - Module: self - """ - self._device = torch.device('cpu') - return super().cpu() - - def type(self, dst_type: Union[str, torch.dtype]) -> torch.nn.Module: - """Casts all parameters and buffers to :attr:`dst_type`. - - Arguments: - dst_type (type or string): the desired type - - Returns: - Module: self - """ - self._dtype = dst_type - return super().type(dst_type=dst_type) - - def float(self) -> torch.nn.Module: - """Casts all floating point parameters and buffers to float datatype. - - Returns: - Module: self - """ - self._dtype = torch.float - return super().float() - - def double(self) -> torch.nn.Module: - """Casts all floating point parameters and buffers to ``double`` datatype. - - Returns: - Module: self - """ - self._dtype = torch.double - return super().double() - - def half(self) -> torch.nn.Module: - """Casts all floating point parameters and buffers to ``half`` datatype. - - Returns: - Module: self - """ - self._dtype = torch.half - return super().half() - class TensorMetric(Metric): """ From 3ee3fe99f6293869f28c87cdd84254585cc20e7a Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Mon, 18 May 2020 22:06:35 +0200 Subject: [PATCH 04/15] add initial metric package description --- pytorch_lightning/metrics/__init__.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index 18522e0dda94b..f587a6ca3d798 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -1,5 +1,24 @@ """ Metrics ======= -TODO + +Metrics are generally used to monitor model performance. + +The following package aism to provide the most convenient ones as well +as a structure to implement your custom metrics for all the fancy research +you want to do. + +For native PyTorch implementations of metrics, it is recommended to use +the :class:`TensorMetric` which handles automatted DDP syncing and conversions +to tensors for all inputs and outputs. 
+ +If your metrics implementation works on numpy, just use the +:class:`NumpyMetric`, which handles the automatted conversion of +inputs to and outputs from numpy as well as automatted ddp syncing. + +.. warning:: Employing numpy in your metric calculation might slow + down your training substantially, since every metric computation + requires a GPU sync to convert tensors to numpy. + + """ From cf9bf784372b2dfdb412167e6b659c3882a87dd5 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Mon, 18 May 2020 22:18:09 +0200 Subject: [PATCH 05/15] change default to none for mac os --- pytorch_lightning/metrics/converters.py | 19 +++++++++++++------ pytorch_lightning/metrics/metric.py | 8 ++++---- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/metrics/converters.py b/pytorch_lightning/metrics/converters.py index 8162876fc3b00..ff5cacd4d7793 100644 --- a/pytorch_lightning/metrics/converters.py +++ b/pytorch_lightning/metrics/converters.py @@ -4,6 +4,7 @@ sync tensors between different processes in a DDP scenario, when needed. """ +import sys import numbers from typing import Union, Any, Callable @@ -145,8 +146,8 @@ def _tensor_metric_conversion(func_to_decorate: Callable) -> Callable: def _sync_ddp_if_available(result: Union[torch.Tensor], - group: Any = torch.distributed.group.WORLD, - reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM, + group: Optional[Any] = None, + reduce_op: Optional[torch.distributed.ReduceOp] = None, ) -> torch.Tensor: """ Function to reduce the tensors from several ddp processes to one master process @@ -162,6 +163,12 @@ def _sync_ddp_if_available(result: Union[torch.Tensor], """ if torch.distributed.is_available() and torch.distributed.is_initialized(): + if group is None: + group = torch.distributed.group.WORLD + + if reduce_op is None: + reduce_op = torch.distributed.ReduceOp.SUM + # sync all processes before reduction torch.distributed.barrier(group=group) torch.distributed.all_reduce(result, op=reduce_op, group=group, @@ -170,8 +177,8 @@ def _sync_ddp_if_available(result: Union[torch.Tensor], return result -def numpy_metric(group: Any = torch.distributed.group.WORLD, - reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM) -> Callable: +def numpy_metric(group: Optional[Any] = None, + reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable: """ This decorator shall be used on all function metrics working on numpy arrays. @@ -197,8 +204,8 @@ def decorator_fn(func_to_decorate): return decorator_fn -def tensor_metric(group: Any = torch.distributed.group.WORLD, - reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM) -> Callable: +def tensor_metric(group: Optional[Any] = None, + reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable: """ This decorator shall be used on all function metrics working on tensors. diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 799a60ca215d4..8538af869782c 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -49,8 +49,8 @@ class TensorMetric(Metric): Already handles DDP sync and input/output conversions. """ def __init__(self, name: str, - reduce_group: Optional[Any] = torch.distributed.group.WORLD, - reduce_op: Optional[Any] = torch.distributed.ReduceOp.SUM): + reduce_group: Optional[Any] = None, + reduce_op: Optional[Any] = None): """ Args: @@ -80,8 +80,8 @@ class NumpyMetric(Metric): Already handles DDP sync and input/output conversions. 
""" def __init__(self, name: str, - reduce_group: Optional[Any] = torch.distributed.group.WORLD, - reduce_op: Optional[Any] = torch.distributed.ReduceOp.SUM): + reduce_group: Optional[Any] = None, + reduce_op: Optional[Any] = None): """ Args: From af6f598f1b4e6cf32c53d0ca8ee88c3f289bf600 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Mon, 18 May 2020 22:19:38 +0200 Subject: [PATCH 06/15] pep8 --- pytorch_lightning/metrics/converters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/converters.py b/pytorch_lightning/metrics/converters.py index ff5cacd4d7793..ad757a7dfb574 100644 --- a/pytorch_lightning/metrics/converters.py +++ b/pytorch_lightning/metrics/converters.py @@ -168,7 +168,7 @@ def _sync_ddp_if_available(result: Union[torch.Tensor], if reduce_op is None: reduce_op = torch.distributed.ReduceOp.SUM - + # sync all processes before reduction torch.distributed.barrier(group=group) torch.distributed.all_reduce(result, op=reduce_op, group=group, From cd772fde20505f6195a70b8d3dd4a097203701f4 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Mon, 18 May 2020 22:22:59 +0200 Subject: [PATCH 07/15] fix import --- pytorch_lightning/metrics/converters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/converters.py b/pytorch_lightning/metrics/converters.py index ad757a7dfb574..9d8b8303eed7f 100644 --- a/pytorch_lightning/metrics/converters.py +++ b/pytorch_lightning/metrics/converters.py @@ -6,7 +6,7 @@ import sys import numbers -from typing import Union, Any, Callable +from typing import Union, Any, Callable, Optional import numpy as np import torch From 5859bd6aa0ae3bf28bc0d5d205770b88a5ed4555 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 18 May 2020 20:37:32 -0400 Subject: [PATCH 08/15] Update index.rst --- docs/source/index.rst | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index a088065401c48..f7bee4395b637 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -113,11 +113,11 @@ Indices and tables .. 
toctree:: :hidden: - pytorch_lightning.core - pytorch_lightning.callbacks - pytorch_lightning.loggers - pytorch_lightning.metrics - pytorch_lightning.overrides - pytorch_lightning.profiler - pytorch_lightning.trainer - pytorch_lightning.utilities + api/pytorch_lightning.core + api/pytorch_lightning.callbacks + api/pytorch_lightning.loggers + api/pytorch_lightning.metrics + api/pytorch_lightning.overrides + api/pytorch_lightning.profiler + api/pytorch_lightning.trainer + api/pytorch_lightning.utilities From 65d8a05989c14ba5e281fd5238d1af1f7a3ffdad Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Tue, 19 May 2020 07:36:50 +0200 Subject: [PATCH 09/15] Update ci-testing.yml --- .github/workflows/ci-testing.yml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index ac24dcee0a1e1..d0ba2caa69d22 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -54,6 +54,11 @@ jobs: run: | python -c "req = open('requirements.txt').read().replace('torch>=1.1', 'torch<1.5') ; open('requirements.txt', 'w').write(req)" + # versions <= 1.3 may have issues on mac with some BLAS ops due to missing mkl (https://github.com/pytorch/pytorch/issues/18996) + - name: Setup MacOS Minimal + if: runner.os == 'macOS' && matrix.requires ='minimal' + run : | + python -c "req = open('requirements.txt').read().replace('torch>=1.1', 'torch>=1.4') ; open('requirements.txt', 'w').write(req)" - name: Set min. dependencies if: matrix.requires == 'minimal' run: | @@ -137,4 +142,4 @@ jobs: - name: Statistics if: success() run: | - coverage report \ No newline at end of file + coverage report From 781acf34fefe61fd3699b7f90e8a1722b4cfa20f Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 19 May 2020 12:19:10 +0200 Subject: [PATCH 10/15] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/metrics/__init__.py | 8 ++++---- pytorch_lightning/metrics/converters.py | 12 ++++++------ pytorch_lightning/metrics/metric.py | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index f587a6ca3d798..cd721851307df 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -4,17 +4,17 @@ Metrics are generally used to monitor model performance. -The following package aism to provide the most convenient ones as well +The following package aims to provide the most convenient ones as well as a structure to implement your custom metrics for all the fancy research you want to do. For native PyTorch implementations of metrics, it is recommended to use -the :class:`TensorMetric` which handles automatted DDP syncing and conversions +the :class:`TensorMetric` which handles automated DDP syncing and conversions to tensors for all inputs and outputs. If your metrics implementation works on numpy, just use the -:class:`NumpyMetric`, which handles the automatted conversion of -inputs to and outputs from numpy as well as automatted ddp syncing. +:class:`NumpyMetric`, which handles the automated conversion of +inputs to and outputs from numpy as well as automated ddp syncing. .. 
warning:: Employing numpy in your metric calculation might slow down your training substantially, since every metric computation diff --git a/pytorch_lightning/metrics/converters.py b/pytorch_lightning/metrics/converters.py index 9d8b8303eed7f..99315e5af3bfd 100644 --- a/pytorch_lightning/metrics/converters.py +++ b/pytorch_lightning/metrics/converters.py @@ -1,6 +1,6 @@ """ This file provides functions and decorators for automated input and output -conversion to/from numpy.ndarray and torch.Tensor as well as utilities to +conversion to/from :class:`numpy.ndarray` and :class:`torch.Tensor` as well as utilities to sync tensors between different processes in a DDP scenario, when needed. """ @@ -105,9 +105,9 @@ def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> def _numpy_metric_conversion(func_to_decorate: Callable) -> Callable: """ - Decorator Handling the argument conversion for metrics working on numpy. + Decorator handling the argument conversion for metrics working on numpy. All inputs of the decorated function will be converted to numpy and all - outputs will be converted to Tensors + outputs will be converted to tensors. Args: func_to_decorate: the function whose inputs and outputs shall be converted @@ -155,7 +155,7 @@ def _sync_ddp_if_available(result: Union[torch.Tensor], Args: result: the value to sync and reduce (typically tensor or number) group: the process group to gather results from. Defaults to all processes (world) - reduce_op: the reduction operation. Defaults to sum + reduce_op: the reduction operation. Defaults to :func:`torch.sum`. Returns: reduced value @@ -184,7 +184,7 @@ def numpy_metric(group: Optional[Any] = None, It handles the argument conversion and DDP reduction for metrics working on numpy. All inputs of the decorated function will be converted to numpy and all - outputs will be converted to Tensors. + outputs will be converted to tensors. In DDP Training all output tensors will be reduced according to the given rules. Args: @@ -210,7 +210,7 @@ def tensor_metric(group: Optional[Any] = None, This decorator shall be used on all function metrics working on tensors. It handles the argument conversion and DDP reduction for metrics working on tensors. - All inputs and outputs of the decorated function will be converted to tensors . + All inputs and outputs of the decorated function will be converted to tensors. In DDP Training all output tensors will be reduced according to the given rules. Args: diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 8538af869782c..2bdd520c91645 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -13,7 +13,7 @@ class Metric(DeviceDtypeModuleMixin, torch.nn.Module, ABC): """ - Abstract Base Class for metric implementation. + Abstract base class for metric implementation. Should be used to implement metrics that 1. Return multiple Outputs From 3784651da133ae9237818816b3a3cc90c3dc1575 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 19 May 2020 12:24:27 +0200 Subject: [PATCH 11/15] Update CHANGELOG.md --- CHANGELOG.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2c15da063230e..cd9389c6d7078 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,13 +4,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
-## Metrics (will be added to unreleased once the metric branch was finished) -- Add Metric Base Classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326)) - ## [unreleased] - YYYY-MM-DD ### Added +- Add Metric Base Classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877)) + - Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723)). ### Changed From f9c5a7590736dd0345505058ad1cbaca2f729f61 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 19 May 2020 12:29:27 +0200 Subject: [PATCH 12/15] Update pytorch_lightning/metrics/converters.py --- pytorch_lightning/metrics/converters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/converters.py b/pytorch_lightning/metrics/converters.py index 99315e5af3bfd..f50f54f1d82d9 100644 --- a/pytorch_lightning/metrics/converters.py +++ b/pytorch_lightning/metrics/converters.py @@ -81,7 +81,7 @@ def _convert_to_tensor(data: Any) -> Any: elif isinstance(data, torch.Tensor): return data - raise TypeError("The given type ('%s') cannot be converted to a tensor!" % type(data).__name__) + raise TypeError(f"The given type ('{type(data).__name__}') cannot be converted to a tensor!") def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray: From c057a2477d48ee97979a3d48f780829ee5b7bb48 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 19 May 2020 14:24:32 +0200 Subject: [PATCH 13/15] readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 02ca78f68dcbf..277396eca5d3a 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ removed until codecov badge isn't empy. 
likely a config error showing nothing on | Linux py3.6 [CPU] | [![CircleCI](https://circleci.com/gh/PyTorchLightning/pytorch-lightning.svg?style=svg)](https://circleci.com/gh/PyTorchLightning/pytorch-lightning) | [![CircleCI](https://circleci.com/gh/PyTorchLightning/pytorch-lightning.svg?style=svg)](https://circleci.com/gh/PyTorchLightning/pytorch-lightning) | [![CircleCI](https://circleci.com/gh/PyTorchLightning/pytorch-lightning.svg?style=svg)](https://circleci.com/gh/PyTorchLightning/pytorch-lightning) | [![CircleCI](https://circleci.com/gh/PyTorchLightning/pytorch-lightning.svg?style=svg)](https://circleci.com/gh/PyTorchLightning/pytorch-lightning) | [![CircleCI](https://circleci.com/gh/PyTorchLightning/pytorch-lightning.svg?style=svg)](https://circleci.com/gh/PyTorchLightning/pytorch-lightning) | | Linux py3.7 [GPU] | - | - | - | - | [![Build Status](http://35.192.60.23/api/badges/PyTorchLightning/pytorch-lightning/status.svg)](http://35.192.60.23/PyTorchLightning/pytorch-lightning) | | Linux py3.6 / py3.7 / py3.8 | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | - | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | -| OSX py3.6 / py3.7 / py3.8| [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | - | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | +| OSX py3.6 / py3.7 / py3.8| - | - | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | | Windows py3.6 / py3.7 / py3.8 | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | From d4542b2547090d7afb1e4603ccb8e8182596ebf4 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Tue, 19 May 2020 14:40:10 +0200 Subject: [PATCH 14/15] Update metric.py --- pytorch_lightning/metrics/metric.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 2bdd520c91645..5247084498559 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -66,7 +66,7 @@ def __init__(self, name: str, def __call__(self, *args, **kwargs) -> torch.Tensor: def _to_device_dtype(x: 
torch.Tensor) -> torch.Tensor: - return x.to(device=self.device, dtype=self.dtype) + return x.to(device=self.device, dtype=self.dtype, non_blocking=True) return apply_to_collection(self._orig_call(*args, **kwargs), torch.Tensor, _to_device_dtype) @@ -97,7 +97,7 @@ def __init__(self, name: str, def __call__(self, *args, **kwargs) -> torch.Tensor: def _to_device_dtype(x: torch.Tensor) -> torch.Tensor: - return x.to(device=self.device, dtype=self.dtype) + return x.to(device=self.device, dtype=self.dtype, non_blocking=True) return apply_to_collection(self._orig_call(*args, **kwargs), torch.Tensor, _to_device_dtype) From aabdc408d355ef4abbe83779c85decc4f3915c59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 19 May 2020 16:43:05 +0200 Subject: [PATCH 15/15] Update pytorch_lightning/metrics/converters.py --- pytorch_lightning/metrics/converters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/converters.py b/pytorch_lightning/metrics/converters.py index f50f54f1d82d9..7681be589f7fe 100644 --- a/pytorch_lightning/metrics/converters.py +++ b/pytorch_lightning/metrics/converters.py @@ -155,7 +155,7 @@ def _sync_ddp_if_available(result: Union[torch.Tensor], Args: result: the value to sync and reduce (typically tensor or number) group: the process group to gather results from. Defaults to all processes (world) - reduce_op: the reduction operation. Defaults to :func:`torch.sum`. + reduce_op: the reduction operation. Defaults to sum. Returns: reduced value
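
For context, a minimal usage sketch of the metric base classes introduced in this patch series, assuming the patches above are applied. The RMSE subclass, its name argument, and the sample tensors are illustrative only and are not part of any patch; DDP syncing and device/dtype handling come from TensorMetric as implemented above.

    import torch

    from pytorch_lightning.metrics.metric import TensorMetric


    class RMSE(TensorMetric):
        """Illustrative root-mean-square-error metric operating on tensors."""

        def __init__(self):
            # the single required constructor argument is the metric's name
            super().__init__(name='rmse')

        def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
            # implement only the computation; input/output conversion,
            # optional DDP reduction and device/dtype placement are handled
            # by TensorMetric.__call__
            return torch.sqrt(torch.mean((pred - target) ** 2))


    metric = RMSE()
    score = metric(torch.tensor([2.5, 0.0, 2.0]), torch.tensor([3.0, -0.5, 2.0]))
    # score is a torch.Tensor on the metric's device and dtype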