Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New modular metric interface #2528

Merged
merged 42 commits into from
Aug 26, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
0ba7d63
new base structure
Jun 26, 2020
6481b6b
missing packages
Jun 26, 2020
f6a0a4d
updated interface
Jun 30, 2020
9368ac6
revert some changes
Jul 6, 2020
5d76528
fixes
Jul 6, 2020
6709952
Merge branch 'master' into new_metric_interface
SkafteNicki Jul 6, 2020
4958cb2
add changelog
Jul 6, 2020
0a3849e
fix bug
Jul 6, 2020
883efa9
added description
Jul 6, 2020
451d4b6
'merge'
SkafteNicki Aug 5, 2020
ea0dfc6
merge
SkafteNicki Aug 5, 2020
cdf2dbd
merge
SkafteNicki Aug 7, 2020
d99821e
test for pickable
SkafteNicki Aug 7, 2020
e8a6a7b
fixing test
SkafteNicki Aug 7, 2020
9626b22
fixing test
SkafteNicki Aug 7, 2020
1f458f6
fix pickle issue
SkafteNicki Aug 7, 2020
bdf9364
'merge'
SkafteNicki Aug 10, 2020
be70ac8
reduceop typehints back
SkafteNicki Aug 10, 2020
fd6e719
remove redundant module arg
SkafteNicki Aug 11, 2020
9d16c69
add save/load test
SkafteNicki Aug 11, 2020
810d3dd
add aggregate method
SkafteNicki Aug 11, 2020
4923e34
text clarification
SkafteNicki Aug 11, 2020
ae60d3d
fix doctest
SkafteNicki Aug 11, 2020
a30a033
merge
SkafteNicki Aug 13, 2020
ffb4ed7
merge
SkafteNicki Aug 18, 2020
796d913
Apply suggestions from code review
awaelchli Aug 22, 2020
f14ec8f
Merge branch 'master' into new_metric_interface
awaelchli Aug 22, 2020
8a3a128
change test to results obj
SkafteNicki Aug 24, 2020
26b4bb4
fix docs
SkafteNicki Aug 24, 2020
8289a41
merge
SkafteNicki Aug 24, 2020
8aae0be
formatting
Borda Aug 25, 2020
b783ec5
formatting
rohitgr7 Aug 25, 2020
e0ba557
formatting
rohitgr7 Aug 25, 2020
da4e94d
formatting
rohitgr7 Aug 25, 2020
a27aebd
formatting
rohitgr7 Aug 25, 2020
9641f84
formatting
rohitgr7 Aug 25, 2020
9fabc89
pep
rohitgr7 Aug 25, 2020
619ffc2
Update CHANGELOG.md
Borda Aug 25, 2020
b03ce8f
suggestions
Aug 26, 2020
ac169df
fix tests
Aug 26, 2020
c1a1389
fix pep8
Aug 26, 2020
847d0ed
fix tests
Aug 26, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 26 additions & 18 deletions pytorch_lightning/metrics/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,28 +63,32 @@ def new_func(*args, **kwargs):
return decorator_fn


def _convert_to_tensor(data: Any) -> Any:
def convert_to_tensor(data: Any, dtype=None, device=None) -> Any:
"""
Maps all kind of collections and numbers to tensors.

Args:
data: the data to convert to tensor

dtype: data type to convert to

device: device to cast to
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved

Return:
the converted data
"""
if isinstance(data, numbers.Number):
return torch.tensor([data])
return torch.tensor([data], dtype=dtype, device=device)
# is not array of object
elif isinstance(data, np.ndarray) and np_str_obj_array_pattern.search(data.dtype.str) is None:
return torch.from_numpy(data)
return torch.from_numpy(data).to(device=device, dtype=dtype)
elif isinstance(data, torch.Tensor):
return data
return data.to(device=device, dtype=dtype)

raise TypeError(f"The given type ('{type(data).__name__}') cannot be converted to a tensor!")


def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray:
def convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray:
"""Convert all tensors and numpy arrays to numpy arrays.

Args:
Expand Down Expand Up @@ -114,7 +118,7 @@ def _numpy_metric_input_conversion(func_to_decorate: Callable) -> Callable:
Callable: the decorated function
"""
return _apply_to_inputs(
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)(func_to_decorate)
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)(func_to_decorate)


def _tensor_metric_output_conversion(func_to_decorate: Callable) -> Callable:
Expand All @@ -127,7 +131,7 @@ def _tensor_metric_output_conversion(func_to_decorate: Callable) -> Callable:
Return:
Callable: the decorated function
"""
return _apply_to_outputs(_convert_to_tensor)(func_to_decorate)
return _apply_to_outputs(convert_to_tensor)(func_to_decorate)


def _numpy_metric_conversion(func_to_decorate: Callable) -> Callable:
Expand Down Expand Up @@ -161,7 +165,7 @@ def _tensor_metric_input_conversion(func_to_decorate: Callable) -> Callable:
Callable: the decorated function
"""
return _apply_to_inputs(
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_tensor)(func_to_decorate)
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor)(func_to_decorate)


def _tensor_collection_metric_output_conversion(func_to_decorate: Callable) -> Callable:
Expand All @@ -175,7 +179,7 @@ def _tensor_collection_metric_output_conversion(func_to_decorate: Callable) -> C
Callable: the decorated function
"""
return _apply_to_outputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number),
_convert_to_tensor)(func_to_decorate)
convert_to_tensor)(func_to_decorate)


def _tensor_metric_conversion(func_to_decorate: Callable) -> Callable:
Expand Down Expand Up @@ -215,10 +219,11 @@ def _tensor_collection_metric_conversion(func_to_decorate: Callable) -> Callable
return _tensor_collection_metric_output_conversion(func_convert_inputs)


def _sync_ddp_if_available(result: Union[torch.Tensor],
group: Optional[Any] = None,
reduce_op: Optional[torch.distributed.ReduceOp] = None,
) -> torch.Tensor:
def sync_ddp_if_available(result: Union[torch.Tensor],
group: Optional[Any] = None,
reduce_op: Optional[Any] = None,
ddp_normalize=False,
) -> torch.Tensor:
"""
Function to reduce the tensors from several ddp processes to one master process

Expand All @@ -243,11 +248,14 @@ def _sync_ddp_if_available(result: Union[torch.Tensor],
torch.distributed.all_reduce(result, op=reduce_op, group=group,
async_op=False)

if ddp_normalize:
result / torch.distributed.get_world_size(group)

return result


def sync_ddp(group: Optional[Any] = None,
reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable:
reduce_op: Optional[Any] = None) -> Callable:
"""
This decorator syncs a functions outputs across different processes for DDP.

Expand All @@ -262,14 +270,14 @@ def sync_ddp(group: Optional[Any] = None,

def decorator_fn(func_to_decorate):
return _apply_to_outputs(apply_to_collection, torch.Tensor,
_sync_ddp_if_available, group=group,
sync_ddp_if_available, group=group,
reduce_op=reduce_op)(func_to_decorate)

return decorator_fn


def numpy_metric(group: Optional[Any] = None,
reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable:
reduce_op: Optional[Any] = None) -> Callable:
"""
This decorator shall be used on all function metrics working on numpy arrays.
It handles the argument conversion and DDP reduction for metrics working on numpy.
Expand All @@ -292,7 +300,7 @@ def decorator_fn(func_to_decorate):


def tensor_metric(group: Optional[Any] = None,
reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable:
reduce_op: Optional[Any] = None) -> Callable:
"""
This decorator shall be used on all function metrics working on tensors.
It handles the argument conversion and DDP reduction for metrics working on tensors.
Expand All @@ -314,7 +322,7 @@ def decorator_fn(func_to_decorate):


def tensor_collection_metric(group: Optional[Any] = None,
reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable:
reduce_op: Optional[Any] = None) -> Callable:
"""
This decorator shall be used on all function metrics working on tensors and returning collections
that cannot be converted to tensors.
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def confusion_matrix(
"""
num_classes = get_num_classes(pred, target, None)

unique_labels = target.view(-1) * num_classes + pred.view(-1)
unique_labels = (target.view(-1) * num_classes + pred.view(-1)).to(torch.int)

bins = torch.bincount(unique_labels, minlength=num_classes ** 2)
cm = bins.reshape(num_classes, num_classes).squeeze().float()
Expand Down
131 changes: 103 additions & 28 deletions pytorch_lightning/metrics/metric.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from abc import ABC, abstractmethod
from typing import Any, Optional
import numbers

import torch
import torch.distributed
import numpy as np

from pytorch_lightning.metrics.converters import (
tensor_metric, numpy_metric, tensor_collection_metric)
sync_ddp_if_available, convert_to_tensor, convert_to_numpy)
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin

Expand All @@ -17,6 +19,16 @@ class Metric(DeviceDtypeModuleMixin, torch.nn.Module, ABC):
Should be used to implement metrics that
1. Return multiple Outputs
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
2. Handle their own DDP sync

Metric hooks that can be implemented are:
input_convert: pre-forward hook that takes care of input conversion
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
output_convert: post-forward hook that takes care of output conversion
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
ddp_sync: implementation of ddp sync
compute: post-ddp sync for additional metric computations

Call order:
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
input_convert -> forward -> output_convert -> ddp_sync -> compute
justusschock marked this conversation as resolved.
Show resolved Hide resolved

"""

def __init__(self, name: str):
Expand All @@ -29,18 +41,51 @@ def __init__(self, name: str):
self.name = name
self._dtype = torch.get_default_dtype()
self._device = torch.device('cpu')
self.register_forward_pre_hook(self.input_convert)
self.register_forward_hook(self.output_convert)
self.register_forward_hook(self.ddp_sync)
self.register_forward_hook(self.compute)

@abstractmethod
def forward(self, *args, **kwargs) -> torch.Tensor:
"""
Implements the actual metric computation.

Returns:
metric value
metric value or metric state

"""
raise NotImplementedError

def compute(self, module, input, output) -> torch.Tensor:
"""
Implement additional metric computations to be done after the ddp sync

Args:
module: current metric module

input: input to forward method

output: output from forward method
justusschock marked this conversation as resolved.
Show resolved Hide resolved

Returns:
final metric value

"""
return output

def ddp_sync(self, module, input, output):
"""

"""
return output

def input_convert(self, module, input):
return input

def output_convert(self, module, input, output):
return output

justusschock marked this conversation as resolved.
Show resolved Hide resolved

class TensorMetric(Metric):
"""
Expand All @@ -51,7 +96,8 @@ class TensorMetric(Metric):

def __init__(self, name: str,
reduce_group: Optional[Any] = None,
reduce_op: Optional[Any] = None):
reduce_op: Optional[Any] = None,
ddp_normalize: bool = False):
"""

Args:
Expand All @@ -62,15 +108,23 @@ def __init__(self, name: str,
Defaults to sum.
"""
super().__init__(name)
self._orig_call = tensor_metric(group=reduce_group,
reduce_op=reduce_op)(super().__call__)
self.reduce_group = reduce_group
self.reduce_op = reduce_op
self.ddp_normalize = ddp_normalize

def input_convert(self, module, input):
return apply_to_collection(input,
(torch.Tensor, np.ndarray, numbers.Number),
convert_to_tensor,
self.dtype, self.device)

def __call__(self, *args, **kwargs) -> torch.Tensor:
def _to_device_dtype(x: torch.Tensor) -> torch.Tensor:
return x.to(device=self.device, dtype=self.dtype, non_blocking=True)
def output_convert(self, module, input, output):
return apply_to_collection(output, torch.Tensor, convert_to_tensor,
Borda marked this conversation as resolved.
Show resolved Hide resolved
self.dtype, self.device)

return apply_to_collection(self._orig_call(*args, **kwargs), torch.Tensor,
_to_device_dtype)
def ddp_sync(self, module, input, output):
return apply_to_collection(output, torch.Tensor, sync_ddp_if_available,
self.reduce_group, self.reduce_op, self.ddp_normalize)


class TensorCollectionMetric(Metric):
Expand All @@ -92,7 +146,8 @@ class TensorCollectionMetric(Metric):

def __init__(self, name: str,
reduce_group: Optional[Any] = None,
reduce_op: Optional[Any] = None):
reduce_op: Optional[Any] = None,
ddp_normalize: bool = False):
"""

Args:
Expand All @@ -103,15 +158,25 @@ def __init__(self, name: str,
Defaults to sum.
"""
super().__init__(name)
self._orig_call = tensor_collection_metric(group=reduce_group,
reduce_op=reduce_op)(super().__call__)
self.reduce_group = reduce_group
self.reduce_op = reduce_op
self.ddp_normalize = ddp_normalize

def input_convert(self, module, input):
return apply_to_collection(input,
(torch.Tensor, np.ndarray, numbers.Number),
convert_to_tensor,
self.dtype, self.device)

def __call__(self, *args, **kwargs) -> torch.Tensor:
def _to_device_dtype(x: torch.Tensor) -> torch.Tensor:
return x.to(device=self.device, dtype=self.dtype, non_blocking=True)
def output_convert(self, module, input, output):
return apply_to_collection(output,
(torch.Tensor, np.ndarray, numbers.Number),
convert_to_tensor,
self.dtype, self.device)

return apply_to_collection(self._orig_call(*args, **kwargs), torch.Tensor,
_to_device_dtype)
def ddp_sync(self, module, input, output):
return apply_to_collection(output, torch.Tensor, sync_ddp_if_available,
self.reduce_group, self.reduce_op, self.ddp_normalize)


class NumpyMetric(Metric):
Expand All @@ -124,7 +189,8 @@ class NumpyMetric(Metric):

def __init__(self, name: str,
reduce_group: Optional[Any] = None,
reduce_op: Optional[Any] = None):
reduce_op: Optional[Any] = None,
ddp_normalize: bool = False):
"""

Args:
Expand All @@ -135,12 +201,21 @@ def __init__(self, name: str,
Defaults to sum.
"""
super().__init__(name)
self._orig_call = numpy_metric(group=reduce_group,
reduce_op=reduce_op)(super().__call__)

def __call__(self, *args, **kwargs) -> torch.Tensor:
def _to_device_dtype(x: torch.Tensor) -> torch.Tensor:
return x.to(device=self.device, dtype=self.dtype, non_blocking=True)

return apply_to_collection(self._orig_call(*args, **kwargs), torch.Tensor,
_to_device_dtype)
self.reduce_group = reduce_group
self.reduce_op = reduce_op
self.ddp_normalize = ddp_normalize

def input_convert(self, module, input):
return apply_to_collection(input,
(torch.Tensor, np.ndarray, numbers.Number),
convert_to_numpy)

def output_convert(self, module, input, output):
return apply_to_collection(output,
(torch.Tensor, np.ndarray, numbers.Number),
convert_to_tensor,
self.dtype, self.device)

def ddp_sync(self, module, input, output):
return apply_to_collection(output, torch.Tensor, sync_ddp_if_available,
self.reduce_group, self.reduce_op, self.ddp_normalize)
1 change: 0 additions & 1 deletion tests/metrics/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def test_confusion_matrix(normalize):

target = (torch.arange(120) % 3).view(-1, 1)
pred = target.clone()

cm = conf_matrix(pred, target)
assert isinstance(cm, torch.Tensor)

Expand Down
Loading