
New metric classes (#1326) #1877

Merged: 16 commits, May 19, 2020
7 changes: 6 additions & 1 deletion .github/workflows/ci-testing.yml
@@ -54,6 +54,11 @@ jobs:
run: |
python -c "req = open('requirements.txt').read().replace('torch>=1.1', 'torch<1.5') ; open('requirements.txt', 'w').write(req)"

# versions <= 1.3 may have issues on mac with some BLAS ops due to missing mkl (https://github.com/pytorch/pytorch/issues/18996)
- name: Setup MacOS Minimal
if: runner.os == 'macOS' && matrix.requires == 'minimal'
run: |
python -c "req = open('requirements.txt').read().replace('torch>=1.1', 'torch>=1.4') ; open('requirements.txt', 'w').write(req)"
- name: Set min. dependencies
if: matrix.requires == 'minimal'
run: |
@@ -137,4 +142,4 @@ jobs:
- name: Statistics
if: success()
run: |
coverage report
coverage report
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,9 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## Metrics (will be added to unreleased once the metrics branch is finished)
- Add Metric Base Classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326))

## [unreleased] - YYYY-MM-DD

### Added
2 changes: 2 additions & 0 deletions docs/source/index.rst
@@ -23,6 +23,7 @@ PyTorch Lightning Documentation
hooks
lightning-module
loggers
metrics
trainer

.. toctree::
@@ -115,6 +116,7 @@ Indices and tables
api/pytorch_lightning.core
api/pytorch_lightning.callbacks
api/pytorch_lightning.loggers
api/pytorch_lightning.metrics
api/pytorch_lightning.overrides
api/pytorch_lightning.profiler
api/pytorch_lightning.trainer
4 changes: 4 additions & 0 deletions docs/source/metrics.rst
@@ -0,0 +1,4 @@
.. automodule:: pytorch_lightning.metrics
:members:
:noindex:
:exclude-members:
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -18,7 +18,7 @@
from pytorch_lightning.core.hooks import ModelHooks
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv, load_hparams_from_yaml, update_hparams
from pytorch_lightning.core.properties import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import rank_zero_warn
24 changes: 24 additions & 0 deletions pytorch_lightning/metrics/__init__.py
@@ -0,0 +1,24 @@
"""
Metrics
=======

Metrics are generally used to monitor model performance.

The following package aims to provide the most convenient ones as well
as a structure to implement your custom metrics for all the fancy research
you want to do.

For native PyTorch implementations of metrics, it is recommended to use
the :class:`TensorMetric`, which handles automated DDP syncing and the
conversion of all inputs and outputs to tensors.

If your metric implementation works on numpy, just use the
:class:`NumpyMetric`, which handles the automated conversion of
inputs to and outputs from numpy as well as automated DDP syncing.

.. warning:: Employing numpy in your metric calculation might slow
down your training substantially, since every metric computation
requires a GPU sync to convert tensors to numpy.


"""
230 changes: 230 additions & 0 deletions pytorch_lightning/metrics/converters.py
@@ -0,0 +1,230 @@
"""
This file provides functions and decorators for automated input and output
conversion to/from numpy.ndarray and torch.Tensor as well as utilities to
sync tensors between different processes in a DDP scenario, when needed.
"""

import sys
import numbers
from typing import Union, Any, Callable, Optional

import numpy as np
import torch
from torch.utils.data._utils.collate import np_str_obj_array_pattern

from pytorch_lightning.utilities.apply_func import apply_to_collection


def _apply_to_inputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable:
"""
Decorator function to apply a function to all inputs of a function.
Args:
func_to_apply: the function to apply to the inputs
*dec_args: positional arguments for the function to be applied
**dec_kwargs: keyword arguments for the function to be applied

Returns:
the decorated function
"""

def decorator_fn(func_to_decorate):
# actual function applying the given function to inputs
def new_func(*args, **kwargs):
args = func_to_apply(args, *dec_args, **dec_kwargs)
kwargs = func_to_apply(kwargs, *dec_args, **dec_kwargs)
return func_to_decorate(*args, **kwargs)

return new_func

return decorator_fn


def _apply_to_outputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable:
"""
Decorator function to apply a function to all outputs of a function.
Args:
func_to_apply: the function to apply to the outputs
*dec_args: positional arguments for the function to be applied
**dec_kwargs: keyword arguments for the function to be applied

Returns:
the decorated function
"""

def decorator_fn(function_to_decorate):
# actual function applying the given function to outputs
def new_func(*args, **kwargs):
result = function_to_decorate(*args, **kwargs)
return func_to_apply(result, *dec_args, **dec_kwargs)

return new_func

return decorator_fn
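# Example (sketch): how the two helpers above compose. Using
# `apply_to_collection` (imported at the top of this file) as the function
# applied to the inputs, every string argument is upper-cased before the
# wrapped function runs, and the output is wrapped afterwards.
@_apply_to_outputs(lambda out: [out])
@_apply_to_inputs(apply_to_collection, str, str.upper)
def greet(name):
    return 'hello ' + name

assert greet('world') == ['hello WORLD']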


def _convert_to_tensor(data: Any) -> Any:
"""
Maps all kind of collections and numbers to tensors.

Args:
data: the data to convert to tensor

Returns:
the converted data

"""
if isinstance(data, numbers.Number):
return torch.tensor([data])
# is not array of object
elif isinstance(data, np.ndarray) and np_str_obj_array_pattern.search(data.dtype.str) is None:
return torch.from_numpy(data)
elif isinstance(data, torch.Tensor):
return data

raise TypeError("The given type ('%s') cannot be converted to a tensor!" % type(data).__name__)


def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray:
"""Convert all tensors and numpy arrays to numpy arrays.
Args:
data: the tensor or array to convert to numpy

Returns:
the resulting numpy array

"""
if isinstance(data, torch.Tensor):
return data.cpu().detach().numpy()
elif isinstance(data, numbers.Number):
return np.array([data])
elif isinstance(data, np.ndarray):
return data

raise TypeError("The given type ('%s') cannot be converted to a numpy array!" % type(data).__name__)


def _numpy_metric_conversion(func_to_decorate: Callable) -> Callable:
"""
Decorator handling the argument conversion for metrics working on numpy.
All inputs of the decorated function will be converted to numpy and all
outputs will be converted to tensors.

Args:
func_to_decorate: the function whose inputs and outputs shall be converted

Returns:
the decorated function

"""
# applies collection conversion from tensor to numpy to all inputs
# we need to include numpy arrays here, since otherwise they will also be treated as sequences
func_convert_inputs = _apply_to_inputs(
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)(func_to_decorate)
# converts all inputs back to tensors (device doesn't matter here, since this is handled by BaseMetric)
func_convert_in_out = _apply_to_outputs(_convert_to_tensor)(func_convert_inputs)
return func_convert_in_out
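# Example (sketch): a numpy-based metric wrapped by the decorator above;
# it can now be fed tensors and returns a tensor.
@_numpy_metric_conversion
def mae(pred: np.ndarray, target: np.ndarray) -> np.ndarray:
    return np.abs(pred - target).mean()

# mae(torch.tensor([1., 2.]), torch.tensor([1., 3.])) -> tensor([0.5000])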


def _tensor_metric_conversion(func_to_decorate: Callable) -> Callable:
"""
Decorator handling the argument conversion for metrics working on tensors.
All inputs and outputs of the decorated function will be converted to tensors

Args:
func_to_decorate: the function whose inputs and outputs shall be converted

Returns:
the decorated function

"""
# converts all inputs to tensor if possible
# we need to include tensors here, since otherwise they will also be treated as sequences
func_convert_inputs = _apply_to_inputs(
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_tensor)(func_to_decorate)
# convert all outputs to tensor if possible
return _apply_to_outputs(_convert_to_tensor)(func_convert_inputs)
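# Example (sketch): the tensor counterpart; numbers and numpy arrays
# passed in are promoted to tensors before the metric body runs.
@_tensor_metric_conversion
def squared_error(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    return (pred - target) ** 2

# squared_error(np.array([2.0]), 1.0) -> tensor([1.], dtype=torch.float64)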


def _sync_ddp_if_available(result: torch.Tensor,
group: Optional[Any] = None,
reduce_op: Optional[torch.distributed.ReduceOp] = None,
) -> torch.Tensor:
"""
Function to reduce the tensors from several ddp processes to one master process

Args:
result: the value to sync and reduce (typically tensor or number)
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to ``torch.distributed.ReduceOp.SUM``

Returns:
reduced value

"""

if torch.distributed.is_available() and torch.distributed.is_initialized():
if group is None:
group = torch.distributed.group.WORLD

if reduce_op is None:
reduce_op = torch.distributed.ReduceOp.SUM

# sync all processes before reduction
torch.distributed.barrier(group=group)
torch.distributed.all_reduce(result, op=reduce_op, group=group,
async_op=False)

return result
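# Example (sketch): without an initialized process group the tensor passes
# through unchanged; under DDP with N processes, every process would
# instead receive the (by default summed) value from all processes.
t = torch.tensor([1.0])
assert _sync_ddp_if_available(t) is t  # single-process: no-op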


def numpy_metric(group: Optional[Any] = None,
reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable:
"""
This decorator shall be used on all functional metrics working on numpy arrays.

It handles the argument conversion and DDP reduction for metrics working on numpy.
All inputs of the decorated function will be converted to numpy and all
outputs will be converted to tensors.
In DDP Training all output tensors will be reduced according to the given rules.

Args:
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to ``torch.distributed.ReduceOp.SUM``

Returns:
the decorated function

"""

def decorator_fn(func_to_decorate):
return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp_if_available,
group=group,
reduce_op=reduce_op)(_numpy_metric_conversion(func_to_decorate))

return decorator_fn
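# Example (sketch): a functional numpy metric with DDP-aware reduction;
# it is called with tensors and returns a tensor, and under DDP the
# per-process error counts are summed.
@numpy_metric(reduce_op=torch.distributed.ReduceOp.SUM)
def count_errors(pred: np.ndarray, target: np.ndarray) -> np.ndarray:
    return (pred != target).sum()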


def tensor_metric(group: Optional[Any] = None,
reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable:
"""
This decorator shall be used on all functional metrics working on tensors.

It handles the argument conversion and DDP reduction for metrics working on tensors.
All inputs and outputs of the decorated function will be converted to tensors.
In DDP Training all output tensors will be reduced according to the given rules.

Args:
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to ``torch.distributed.ReduceOp.SUM``

Returns:
the decorated function

"""

def decorator_fn(func_to_decorate):
return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp_if_available,
group=group,
reduce_op=reduce_op)(_tensor_metric_conversion(func_to_decorate))

return decorator_fn
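# Example (sketch): the tensor-based equivalent.
@tensor_metric()
def rmse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    return torch.sqrt(torch.mean((pred - target) ** 2))

# rmse(torch.tensor([1., 2.]), torch.tensor([1., 4.])) -> DDP-reduced tensor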