Fix tests for native torch metrics #1962

Merged 13 commits on Jun 5, 2020

Changes from all commits
3 changes: 0 additions & 3 deletions CHANGELOG.md

@@ -4,9 +4,6 @@ All notable changes to this project will be documented in this file.

 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

-## Metrics (will be added to unreleased once the metric branch was finished)
-- Add Metric Base Classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326))
-
 ## [unreleased] - YYYY-MM-DD

 ### Added
2 binary files changed (content not shown).
46 changes: 20 additions & 26 deletions pytorch_lightning/metrics/converters.py

@@ -18,12 +18,13 @@
 def _apply_to_inputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable:
     """
     Decorator function to apply a function to all inputs of a function.
+
     Args:
         func_to_apply: the function to apply to the inputs
         *dec_args: positional arguments for the function to be applied
         **dec_kwargs: keyword arguments for the function to be applied
 
-    Returns:
+    Return:
         the decorated function
     """

@@ -42,12 +43,13 @@ def new_func(*args, **kwargs):
 
 def _apply_to_outputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable:
     """
     Decorator function to apply a function to all outputs of a function.
+
     Args:
         func_to_apply: the function to apply to the outputs
         *dec_args: positional arguments for the function to be applied
         **dec_kwargs: keyword arguments for the function to be applied
 
-    Returns:
+    Return:
         the decorated function
     """

@@ -64,40 +66,40 @@ def new_func(*args, **kwargs):
 
 def _convert_to_tensor(data: Any) -> Any:
     """
-    Maps all kind of collections and numbers to tensors
+    Maps all kind of collections and numbers to tensors.
 
     Args:
         data: the data to convert to tensor
 
-    Returns:
+    Return:
         the converted data
-
     """
     if isinstance(data, numbers.Number):
         return torch.tensor([data])
     # is not array of object
     elif isinstance(data, np.ndarray) and np_str_obj_array_pattern.search(data.dtype.str) is None:
         return torch.from_numpy(data)
+    elif isinstance(data, torch.Tensor):
+        return data
 
-    raise TypeError("The given type ('%s') cannot be converted to a tensor!" % type(data).__name__)
+    raise TypeError(f"The given type ('{type(data).__name__}') cannot be converted to a tensor!")


 def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray:
-    """
-    converts all tensors and numpy arrays to numpy arrays
+    """Convert all tensors and numpy arrays to numpy arrays.
 
     Args:
         data: the tensor or array to convert to numpy
 
-    Returns:
+    Return:
         the resulting numpy array
-
     """
     if isinstance(data, torch.Tensor):
         return data.cpu().detach().numpy()
     elif isinstance(data, numbers.Number):
         return np.array([data])
+    elif isinstance(data, np.ndarray):
+        return data
 
     raise TypeError("The given type ('%s') cannot be converted to a numpy array!" % type(data).__name__)

@@ -111,9 +113,8 @@ def _numpy_metric_conversion(func_to_decorate: Callable) -> Callable:
     Args:
         func_to_decorate: the function whose inputs and outputs shall be converted
 
-    Returns:
+    Return:
         the decorated function
-
     """
     # applies collection conversion from tensor to numpy to all inputs
     # we need to include numpy arrays here, since otherwise they will also be treated as sequences

@@ -132,9 +133,8 @@ def _tensor_metric_conversion(func_to_decorate: Callable) -> Callable:
     Args:
         func_to_decorate: the function whose inputs and outputs shall be converted
 
-    Returns:
+    Return:
         the decorated function
-
     """
     # converts all inputs to tensor if possible
     # we need to include tensors here, since otherwise they will also be treated as sequences

@@ -156,9 +156,8 @@ def _sync_ddp_if_available(result: Union[torch.Tensor],
         group: the process group to gather results from. Defaults to all processes (world)
         reduce_op: the reduction operation. Defaults to sum.
 
-    Returns:
+    Return:
         reduced value
-
     """
 
     if torch.distributed.is_available() and torch.distributed.is_initialized():
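The renamed _sync_ddp_if_available only reduces when a process group is actually initialized, so single-process runs pass through unchanged. A rough sketch of that guard (an assumption about the shape, not the full implementation):

import torch
import torch.distributed as dist

def _sync_ddp_if_available(result: torch.Tensor, group=None, reduce_op=None) -> torch.Tensor:
    if dist.is_available() and dist.is_initialized():
        # sum (or custom-reduce) the metric tensor across all processes
        # group handling is omitted in this sketch
        dist.all_reduce(result, op=reduce_op or dist.ReduceOp.SUM)
    return result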
@@ -180,7 +179,6 @@ def numpy_metric(group: Optional[Any] = None,
                  reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable:
     """
     This decorator shall be used on all function metrics working on numpy arrays.
-
     It handles the argument conversion and DDP reduction for metrics working on numpy.
     All inputs of the decorated function will be converted to numpy and all
     outputs will be converted to tensors.

@@ -190,13 +188,12 @@ def numpy_metric(group: Optional[Any] = None,
         group: the process group to gather results from. Defaults to all processes (world)
         reduce_op: the reduction operation. Defaults to sum
 
-    Returns:
+    Return:
         the decorated function
-
     """
 
     def decorator_fn(func_to_decorate):
-        return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp,
+        return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp_if_available,
                                  group=group,
                                  reduce_op=reduce_op)(_numpy_metric_conversion(func_to_decorate))
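Downstream, the decorator turns a plain numpy function into a Lightning-compatible metric. A minimal usage sketch (the metric name mae and its body are hypothetical):

import numpy as np
import torch
from pytorch_lightning.metrics.converters import numpy_metric

@numpy_metric()
def mae(pred: np.ndarray, target: np.ndarray) -> np.ndarray:
    # pred and target arrive as numpy arrays, even if the caller passed tensors
    return np.abs(pred - target).mean()

mae(torch.tensor([1., 2.]), torch.tensor([1., 3.]))  # returns a torch.Tensor; DDP-reduced when initialized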

@@ -207,7 +204,6 @@ def tensor_metric(group: Optional[Any] = None,
                   reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable:
     """
     This decorator shall be used on all function metrics working on tensors.
-
     It handles the argument conversion and DDP reduction for metrics working on tensors.
     All inputs and outputs of the decorated function will be converted to tensors.
     In DDP Training all output tensors will be reduced according to the given rules.

@@ -216,14 +212,12 @@ def tensor_metric(group: Optional[Any] = None,
         group: the process group to gather results from. Defaults to all processes (world)
         reduce_op: the reduction operation. Defaults to sum
 
-    Returns:
+    Return:
         the decorated function
-
     """
 
-
     def decorator_fn(func_to_decorate):
-        return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp,
+        return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp_if_available,
                                  group=group,
                                  reduce_op=reduce_op)(_tensor_metric_conversion(func_to_decorate))
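The tensor counterpart works the same way; again a hypothetical sketch, not code from this PR:

import torch
from pytorch_lightning.metrics.converters import tensor_metric

@tensor_metric()
def rmse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    return torch.sqrt(torch.mean((pred - target) ** 2))

rmse(torch.tensor([1., 2., 3.]), torch.tensor([1., 2., 4.]))
# tensor(0.5774); in DDP the output is synced via _sync_ddp_if_available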

2 changes: 2 additions & 0 deletions pytorch_lightning/metrics/functional/classification.py

@@ -1,5 +1,6 @@
 from collections import Sequence
 from typing import Optional, Tuple, Callable
+from functools import wraps
 
 import torch

@@ -326,6 +327,7 @@ def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = True):
 
 def auc_decorator(reorder: bool = True) -> Callable:
     def wrapper(func_to_decorate: Callable) -> Callable:
+        @wraps(func_to_decorate)
         def new_func(*args, **kwargs) -> torch.Tensor:
             x, y = func_to_decorate(*args, **kwargs)[:2]
25 changes: 10 additions & 15 deletions pytorch_lightning/metrics/functional/reduction.py

@@ -3,22 +3,17 @@
 
 def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor:
     """
-    reduces a given tensor by a given reduction method
-    Parameters
-    ----------
-    to_reduce : torch.Tensor
-        the tensor, which shall be reduced
-    reduction : str
-        a string specifying the reduction method.
-        should be one of 'elementwise_mean' | 'none' | 'sum'
-    Returns
-    -------
-    torch.Tensor
+    Reduces a given tensor by a given reduction method
+
+    Args:
+        to_reduce : the tensor, which shall be reduced
+        reduction : a string specifying the reduction method ('elementwise_mean', 'none', 'sum')
+
+    Return:
         reduced Tensor
-    Raises
-    ------
-    ValueError
-        if an invalid reduction parameter was given
+
+    Raise:
+        ValueError if an invalid reduction parameter was given
     """
     if reduction == 'elementwise_mean':
         return torch.mean(to_reduce)
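The rewritten docstring maps directly onto behavior; a few illustrative calls (assuming the import path below is correct):

import torch
from pytorch_lightning.metrics.functional.reduction import reduce

t = torch.tensor([1., 2., 3.])
reduce(t, 'sum')               # tensor(6.)
reduce(t, 'elementwise_mean')  # tensor(2.)
reduce(t, 'none')              # returns t unchanged
reduce(t, 'max')               # raises ValueError, invalid reduction method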