diff --git a/CHANGELOG.md b/CHANGELOG.md
index fe4ea78c0525f..27c77cafd4db7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528/))
 
 ### Changed
 
diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py
index 62533a878cb51..670d369020d31 100644
--- a/pytorch_lightning/core/step_result.py
+++ b/pytorch_lightning/core/step_result.py
@@ -20,7 +20,7 @@
 from torch import Tensor
 import os
 
-from pytorch_lightning.metrics.converters import _sync_ddp_if_available
+from pytorch_lightning.metrics.converters import sync_ddp_if_available
 
 
 class Result(Dict):
@@ -124,7 +124,7 @@ def log(
 
         # sync across ddp
         if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
-            value = _sync_ddp_if_available(value, group=sync_dist_group, reduce_op=sync_dist_op)
+            value = sync_ddp_if_available(value, group=sync_dist_group, reduce_op=sync_dist_op)
 
         if 'meta' not in self:
             self.__setitem__('meta', {})
diff --git a/pytorch_lightning/metrics/converters.py b/pytorch_lightning/metrics/converters.py
index 9f75ec3b43c7f..a41a621c905da 100644
--- a/pytorch_lightning/metrics/converters.py
+++ b/pytorch_lightning/metrics/converters.py
@@ -34,7 +34,7 @@
     class ReduceOp:
         SUM = None
 
-    rank_zero_warn('Unsupported `ReduceOp` for distributed computing.')
+    rank_zero_warn('Unsupported `ReduceOp` for distributed computing')
 
 
 def _apply_to_inputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable:
@@ -86,28 +86,30 @@ def new_func(*args, **kwargs):
     return decorator_fn
 
 
-def _convert_to_tensor(data: Any) -> Any:
+def convert_to_tensor(data: Any, dtype=None, device=None) -> Any:
     """
     Maps all kind of collections and numbers to tensors.
 
     Args:
         data: the data to convert to tensor
+        dtype: data type to convert to
+        device: device to cast to
 
     Return:
         the converted data
     """
     if isinstance(data, numbers.Number):
-        return torch.tensor([data])
+        return torch.tensor([data], dtype=dtype, device=device)
     # is not array of object
     elif isinstance(data, np.ndarray) and np_str_obj_array_pattern.search(data.dtype.str) is None:
-        return torch.from_numpy(data)
+        return torch.from_numpy(data).to(device=device, dtype=dtype)
     elif isinstance(data, torch.Tensor):
-        return data
+        return data.to(device=device, dtype=dtype)
 
     raise TypeError(f"The given type ('{type(data).__name__}') cannot be converted to a tensor!")
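# A minimal usage sketch of the widened converter above (illustration only,
# not part of the diff): numbers, numpy arrays and tensors all come back as
# tensors with the requested dtype/device.
import numpy as np
import torch

from pytorch_lightning.metrics.converters import convert_to_tensor

t1 = convert_to_tensor(3, dtype=torch.float32)              # -> tensor([3.])
t2 = convert_to_tensor(np.arange(4), dtype=torch.float64)   # numpy -> float64 tensor
t3 = convert_to_tensor(torch.ones(2), device='cpu')         # tensor -> cast/moved tensor
assert t1.dtype == torch.float32 and t2.dtype == torch.float64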
 
 
-def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray:
+def convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray:
     """Convert all tensors and numpy arrays to numpy arrays.
 
     Args:
@@ -137,7 +139,7 @@ def _numpy_metric_input_conversion(func_to_decorate: Callable) -> Callable:
         Callable: the decorated function
     """
     return _apply_to_inputs(
-        apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)(func_to_decorate)
+        apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)(func_to_decorate)
 
 
 def _tensor_metric_output_conversion(func_to_decorate: Callable) -> Callable:
@@ -150,7 +152,7 @@ def _tensor_metric_output_conversion(func_to_decorate: Callable) -> Callable:
     Return:
         Callable: the decorated function
     """
-    return _apply_to_outputs(_convert_to_tensor)(func_to_decorate)
+    return _apply_to_outputs(convert_to_tensor)(func_to_decorate)
 
 
 def _numpy_metric_conversion(func_to_decorate: Callable) -> Callable:
@@ -184,7 +186,7 @@ def _tensor_metric_input_conversion(func_to_decorate: Callable) -> Callable:
         Callable: the decorated function
     """
     return _apply_to_inputs(
-        apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_tensor)(func_to_decorate)
+        apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor)(func_to_decorate)
 
 
 def _tensor_collection_metric_output_conversion(func_to_decorate: Callable) -> Callable:
@@ -198,7 +200,7 @@ def _tensor_collection_metric_output_conversion(func_to_decorate: Callable) -> C
         Callable: the decorated function
     """
     return _apply_to_outputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number),
-                             _convert_to_tensor)(func_to_decorate)
+                             convert_to_tensor)(func_to_decorate)
 
 
 def _tensor_metric_conversion(func_to_decorate: Callable) -> Callable:
@@ -238,10 +240,10 @@ def _tensor_collection_metric_conversion(func_to_decorate: Callable
     return _tensor_collection_metric_output_conversion(func_convert_inputs)
 
 
-def _sync_ddp_if_available(result: Union[torch.Tensor],
-                           group: Optional[Any] = None,
-                           reduce_op: Optional[ReduceOp] = None,
-                           ) -> torch.Tensor:
+def sync_ddp_if_available(result: Union[torch.Tensor],
+                          group: Optional[Any] = None,
+                          reduce_op: Optional[ReduceOp] = None
+                          ) -> torch.Tensor:
     """
     Function to reduce the tensors from several ddp processes to one master process
 
@@ -278,6 +280,38 @@ def _sync_ddp_if_available(result: Union[torch.Tensor],
     return result
 
 
+def gather_all_tensors_if_available(result: Union[torch.Tensor],
+                                    group: Optional[Any] = None):
+    """
+    Function to gather all tensors from several ddp processes into a list that
+    is broadcast to all processes
+
+    Args:
+        result: the value to sync
+        group: the process group to gather results from.
+            Defaults to all processes (world)
+
+    Return:
+        gathered_result: list with size equal to the process group size,
+            where gathered_result[i] corresponds to the result tensor from process i
+
+    """
+    if torch.distributed.is_available() and torch.distributed.is_initialized():
+        if group is None:
+            group = torch.distributed.group.WORLD
+
+        world_size = torch.distributed.get_world_size(group)
+
+        gathered_result = world_size * [torch.zeros_like(result)]
+
+        # sync and broadcast all
+        torch.distributed.barrier(group=group)
+        torch.distributed.all_gather(gathered_result, result, group)
+
+        result = gathered_result
+
+    return result
+
+
 def sync_ddp(group: Optional[Any] = None,
              reduce_op: Optional[ReduceOp] = None) -> Callable:
     """
@@ -294,7 +328,7 @@ def sync_ddp(group: Optional[Any] = None,
 
     def decorator_fn(func_to_decorate):
        return _apply_to_outputs(apply_to_collection, torch.Tensor,
-                                 _sync_ddp_if_available, group=group,
+                                 sync_ddp_if_available, group=group,
                                  reduce_op=reduce_op)(func_to_decorate)
 
     return decorator_fn
diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py
index 7932483ff75aa..54060425019d2 100644
--- a/pytorch_lightning/metrics/functional/classification.py
+++ b/pytorch_lightning/metrics/functional/classification.py
@@ -304,7 +304,7 @@ def confusion_matrix(
     """
     num_classes = get_num_classes(pred, target, None)
 
-    unique_labels = target.view(-1) * num_classes + pred.view(-1)
+    unique_labels = (target.view(-1) * num_classes + pred.view(-1)).to(torch.int)
 
     bins = torch.bincount(unique_labels, minlength=num_classes ** 2)
     cm = bins.reshape(num_classes, num_classes).squeeze().float()
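# A runnable sketch (illustration only, assuming a CPU 'gloo' build of
# torch.distributed and a free local port) of the gather_all_tensors_if_available
# function added above: every rank receives a list of length world_size where
# entry i is the tensor contributed by rank i. Without an initialized process
# group the input is returned unchanged.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from pytorch_lightning.metrics.converters import gather_all_tensors_if_available


def _worker(rank, world_size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29501'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    gathered = gather_all_tensors_if_available(torch.tensor([float(rank)]))
    # identical on every rank: [tensor([0.]), tensor([1.])]
    assert [t.item() for t in gathered] == [0., 1.]

    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(_worker, args=(2,), nprocs=2)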
diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py
index 18a3df7819a40..5f61a50e6cd25 100644
--- a/pytorch_lightning/metrics/metric.py
+++ b/pytorch_lightning/metrics/metric.py
@@ -14,23 +14,40 @@
 from abc import ABC, abstractmethod
 from typing import Any, Optional
 
+import numbers
 import torch
-import torch.distributed
+from torch import nn
+import numpy as np
 
 from pytorch_lightning.metrics.converters import (
-    tensor_metric, numpy_metric, tensor_collection_metric)
+    sync_ddp_if_available, gather_all_tensors_if_available,
+    convert_to_tensor, convert_to_numpy)
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
 
 
-class Metric(DeviceDtypeModuleMixin, torch.nn.Module, ABC):
+class Metric(DeviceDtypeModuleMixin, nn.Module, ABC):
     """
     Abstract base class for metric implementation.
 
     Should be used to implement metrics that
-    1. Return multiple Outputs
-    2. Handle their own DDP sync
+
+    1. Return multiple Outputs
+    2. Handle their own DDP sync
+
+    Metric hooks that can be implemented are
+
+        * input_convert: pre-forward hook that takes care of input conversion
+        * output_convert: post-forward hook that takes care of output conversion
+        * ddp_sync: implementation of ddp sync, default is gather all
+        * aggregate: implement how values should be aggregated
+        * compute: post-ddp sync for additional metric computations
+
+    Call order
+
+        input_convert -> forward -> output_convert -> ddp_sync -> aggregate -> compute
+
     """
 
     def __init__(self, name: str):
@@ -41,18 +58,99 @@ def __init__(self, name: str):
         """
         super().__init__()
         self.name = name
+        self._dtype = torch.get_default_dtype()
+        self._device = torch.device('cpu')
+
+        # Register hooks
+        self.register_forward_pre_hook(self.input_convert)
+        self.register_forward_hook(self.output_convert)
+        self.register_forward_hook(self.ddp_sync)
+        self.register_forward_hook(self.aggregate)
+        self.register_forward_hook(self.compute)
+
+    @staticmethod
+    def input_convert(self, data: Any):
+        """
+        Implement how the inputs should be cast before calling forward
+
+        Args:
+            data: input to forward method
+
+        Returns:
+            cast data
+        """
+        return data
 
     @abstractmethod
-    def forward(self, *args, **kwargs) -> torch.Tensor:
+    def forward(self, *args, **kwargs):
         """
         Implements the actual metric computation.
 
         Returns:
-            metric value
+            metric value or metric state
 
         """
         raise NotImplementedError
 
+    @staticmethod
+    def output_convert(self, data: Any, output: Any):
+        """
+        Implement how outputs from forward should be cast
+
+        Args:
+            data: input to forward method
+            output: output from forward method
+
+        Returns:
+            cast outputs
+        """
+        return output
+
+    @staticmethod
+    def ddp_sync(self, data: Any, output: Any):
+        """
+        Implement how the outputs from forward should be synced
+
+        Args:
+            data: input to forward method
+            output: output from the `output_convert` hook
+
+        Returns:
+            synced output
+
+        """
+        return output
+
+    @staticmethod
+    def aggregate(self, data: Any, output: Any):
+        """
+        Implement aggregation of values on the same device
+
+        Args:
+            data: input to forward method
+            output: output from the `ddp_sync` hook
+
+        Returns:
+            aggregated values
+
+        """
+        return output
+
+    @staticmethod
+    def compute(self, data: Any, output: Any):
+        """
+        Implement additional metric computations to be done after the ddp sync
+
+        Args:
+            data: input to forward method
+            output: output from the `aggregate` hook
+
+        Returns:
+            final metric value
+
+        """
+        return output
+
 
 class TensorMetric(Metric):
     """
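# A minimal sketch (hypothetical RMSE class, illustration only) of the hook
# call order documented above: `forward` returns an intermediate state and
# the `compute` hook finishes the calculation after the ddp_sync and
# aggregate steps (both no-ops on a single process).
import torch

from pytorch_lightning.metrics.metric import TensorMetric


class RMSE(TensorMetric):
    def __init__(self):
        super().__init__(name='rmse')

    def forward(self, pred, target):
        # runs after `input_convert`, so both arguments are tensors
        return torch.mean((pred - target) ** 2)

    @staticmethod
    def compute(self, data, output):
        # runs last: input_convert -> forward -> output_convert
        # -> ddp_sync -> aggregate -> compute
        return torch.sqrt(output)


metric = RMSE()
print(metric(torch.tensor([1., 2.]), torch.tensor([1., 4.])))  # ~tensor(1.4142)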
""" super().__init__(name) - self._orig_call = tensor_metric(group=reduce_group, - reduce_op=reduce_op)(super().__call__) + self.reduce_group = reduce_group + self.reduce_op = reduce_op + + @staticmethod + def input_convert(self, data: Any): + return apply_to_collection(data, + (torch.Tensor, np.ndarray, numbers.Number), + convert_to_tensor, + self.dtype, self.device) - def __call__(self, *args, **kwargs) -> torch.Tensor: - def _to_device_dtype(x: torch.Tensor) -> torch.Tensor: - return x.to(device=self.device, dtype=self.dtype, non_blocking=True) + @staticmethod + def output_convert(self, data: Any, output: Any): + return apply_to_collection(output, torch.Tensor, convert_to_tensor, + self.dtype, self.device) - return apply_to_collection(self._orig_call(*args, **kwargs), torch.Tensor, - _to_device_dtype) + @staticmethod + def ddp_sync(self, data: Any, output: Any): + return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, + self.reduce_group, self.reduce_op) class TensorCollectionMetric(Metric): @@ -115,15 +223,27 @@ def __init__(self, name: str, Defaults to sum. """ super().__init__(name) - self._orig_call = tensor_collection_metric(group=reduce_group, - reduce_op=reduce_op)(super().__call__) + self.reduce_group = reduce_group + self.reduce_op = reduce_op + + @staticmethod + def input_convert(self, data: Any): + return apply_to_collection(data, + (torch.Tensor, np.ndarray, numbers.Number), + convert_to_tensor, + self.dtype, self.device) - def __call__(self, *args, **kwargs) -> torch.Tensor: - def _to_device_dtype(x: torch.Tensor) -> torch.Tensor: - return x.to(device=self.device, dtype=self.dtype, non_blocking=True) + @staticmethod + def output_convert(self, data: Any, output: Any): + return apply_to_collection(output, + (torch.Tensor, np.ndarray, numbers.Number), + convert_to_tensor, + self.dtype, self.device) - return apply_to_collection(self._orig_call(*args, **kwargs), torch.Tensor, - _to_device_dtype) + @staticmethod + def ddp_sync(self, data: Any, output: Any): + return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, + self.reduce_group, self.reduce_op) class NumpyMetric(Metric): @@ -147,12 +267,23 @@ def __init__(self, name: str, Defaults to sum. 
""" super().__init__(name) - self._orig_call = numpy_metric(group=reduce_group, - reduce_op=reduce_op)(super().__call__) - - def __call__(self, *args, **kwargs) -> torch.Tensor: - def _to_device_dtype(x: torch.Tensor) -> torch.Tensor: - return x.to(device=self.device, dtype=self.dtype, non_blocking=True) - - return apply_to_collection(self._orig_call(*args, **kwargs), torch.Tensor, - _to_device_dtype) + self.reduce_group = reduce_group + self.reduce_op = reduce_op + + @staticmethod + def input_convert(self, data: Any): + return apply_to_collection(data, + (torch.Tensor, np.ndarray, numbers.Number), + convert_to_numpy) + + @staticmethod + def output_convert(self, data: Any, output: Any): + return apply_to_collection(output, + (torch.Tensor, np.ndarray, numbers.Number), + convert_to_tensor, + self.dtype, self.device) + + @staticmethod + def ddp_sync(self, data: Any, output: Any): + return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, + self.reduce_group, self.reduce_op) diff --git a/tests/base/model_train_steps.py b/tests/base/model_train_steps.py index 6d7cd365d8c25..feb08fddda467 100644 --- a/tests/base/model_train_steps.py +++ b/tests/base/model_train_steps.py @@ -175,3 +175,20 @@ def eval_epoch_end_full_loop_result_obj_dp(self, result): setattr(result, f'{eval_name}_step_end_metric', reduced) return result + + def training_step__using_metrics(self, batch, batch_idx, optimizer_idx=None): + """Lightning calls this inside the training loop""" + # forward pass + x, y = batch + x = x.view(x.size(0), -1) + y_hat = self(x) + + # calculate loss + loss_val = self.loss(y, y_hat) + + # call metric + val = self.metric(x, y) + + result = TrainResult(minimize=loss_val) + result.log('metric_val', val) + return result diff --git a/tests/metrics/test_classification.py b/tests/metrics/test_classification.py index 52fffd4ad72d1..53ec95c907092 100644 --- a/tests/metrics/test_classification.py +++ b/tests/metrics/test_classification.py @@ -45,7 +45,6 @@ def test_confusion_matrix(normalize): target = (torch.arange(120) % 3).view(-1, 1) pred = target.clone() - cm = conf_matrix(pred, target) assert isinstance(cm, torch.Tensor) diff --git a/tests/metrics/test_converters.py b/tests/metrics/test_converters.py index e37dfa9c7c5a5..1bfc635f5529b 100644 --- a/tests/metrics/test_converters.py +++ b/tests/metrics/test_converters.py @@ -10,11 +10,11 @@ from pytorch_lightning.metrics.converters import ( _apply_to_inputs, _apply_to_outputs, - _convert_to_tensor, - _convert_to_numpy, + convert_to_tensor, + convert_to_numpy, _numpy_metric_conversion, _tensor_metric_conversion, - _sync_ddp_if_available, + sync_ddp_if_available, tensor_metric, numpy_metric ) @@ -63,14 +63,14 @@ def test_fn(*args, **kwargs): def test_convert_to_tensor(): for test_item in [1., np.array([1.])]: - result_tensor = _convert_to_tensor(test_item) + result_tensor = convert_to_tensor(test_item) assert isinstance(result_tensor, torch.Tensor) assert result_tensor.item() == 1. def test_convert_to_numpy(): for test_item in [1., torch.tensor([1.])]: - result = _convert_to_numpy(test_item) + result = convert_to_numpy(test_item) assert isinstance(result, np.ndarray) assert result.item() == 1. 
@@ -123,12 +123,12 @@ def _ddp_test_fn(rank, worldsize, add_offset: bool, reduction_mean=False):
     else:
         tensor = torch.tensor([1.], )
     if reduction_mean:
-        reduced_tensor = _sync_ddp_if_available(tensor, reduce_op='avg')
+        reduced_tensor = sync_ddp_if_available(tensor, reduce_op='avg')
 
         manual_reduction = sum([i for i in range(dist.get_world_size())]) / dist.get_world_size()
         assert reduced_tensor.item() == manual_reduction
     else:
-        reduced_tensor = _sync_ddp_if_available(tensor)
+        reduced_tensor = sync_ddp_if_available(tensor)
 
         assert reduced_tensor.item() == dist.get_world_size(), \
             'Sync-Reduce does not work properly with DDP and Tensors'
@@ -158,7 +158,7 @@
 def test_sync_reduce_simple():
     """Make sure sync-reduce works without DDP"""
     tensor = torch.tensor([1.], device='cpu')
 
-    reduced_tensor = _sync_ddp_if_available(tensor)
+    reduced_tensor = sync_ddp_if_available(tensor)
 
     assert torch.allclose(tensor, reduced_tensor), \
         'Sync-Reduce does not work properly without DDP and Tensors'
diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py
index 2176c94dfb925..5985745bfa070 100644
--- a/tests/metrics/test_metrics.py
+++ b/tests/metrics/test_metrics.py
@@ -1,8 +1,12 @@
+import os
 import numpy as np
 import pytest
 import torch
 
+import tests.base.develop_utils as tutils
+from tests.base import EvalModelTemplate
 from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric, TensorCollectionMetric
+from pytorch_lightning import Trainer
 
 
 class DummyTensorMetric(TensorMetric):
@@ -12,7 +16,7 @@ def __init__(self):
     def forward(self, input1, input2):
         assert isinstance(input1, torch.Tensor)
         assert isinstance(input2, torch.Tensor)
-        return 1.
+        return torch.tensor([1.])
 
 
 class DummyNumpyMetric(NumpyMetric):
@@ -139,3 +143,46 @@ def change_and_check_device_dtype(device, dtype):
     metric.half()
     assert metric.dtype == torch.float16
     assert metric(input1, input2).dtype == torch.float16
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.parametrize("metric", [DummyTensorMetric, DummyNumpyMetric])
+def test_model_picklable(tmpdir, metric: Metric):
+    """Make sure that metrics are picklable by including them in a model and running in multi-gpu mode"""
+    tutils.set_random_master_port()
+
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=10,
+        gpus=[0, 1],
+        distributed_backend='ddp_spawn',
+    )
+
+    model = EvalModelTemplate()
+    model.metric = metric()
+    model.training_step = model.training_step__using_metrics
+
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+
+    # correct result and ok accuracy
+    assert result == 1, 'ddp model failed to complete'
+
+
+@pytest.mark.parametrize("metric", [DummyTensorMetric(), DummyNumpyMetric()])
+def test_saving_picklable(tmpdir, metric: Metric):
+    """Make sure that metrics are picklable by saving and loading them using torch"""
+    x, y = torch.randn(10,), torch.randn(10,)
+    results_before_save = metric(x, y)
+
+    # save metric
+    save_path = os.path.join(tmpdir, 'save_test.ckpt')
+    torch.save(metric, save_path)
+
+    # load metric
+    new_metric = torch.load(save_path)
+    results_after_load = new_metric(x, y)
+
+    # Check metric value is the same
+    assert results_before_save == results_after_load
diff --git a/tests/metrics/test_sklearn.py b/tests/metrics/test_sklearn.py
index bef5a4ffe0ab0..10b57417411c4 100644
--- a/tests/metrics/test_sklearn.py
+++ b/tests/metrics/test_sklearn.py
@@ -33,7 +33,7 @@
     jaccard_score as sk_jaccard_score
 )
 
-from pytorch_lightning.metrics.converters import _convert_to_numpy
+from pytorch_lightning.metrics.converters import convert_to_numpy
 from pytorch_lightning.metrics.sklearns import (
     Accuracy,
     AUC,
@@ -163,17 +163,17 @@ def new_func(*args, **kwargs):
                      id='Jaccard')
                  ])
 def test_sklearn_metric(metric_class, sklearn_func, inputs):
-    numpy_inputs = apply_to_collection(inputs, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)
+    numpy_inputs = apply_to_collection(inputs, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)
 
     sklearn_result = sklearn_func(**numpy_inputs)
     lightning_result = metric_class(**inputs)
     assert np.allclose(sklearn_result, lightning_result, atol=1e-5)
 
     sklearn_result = apply_to_collection(
-        sklearn_result, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)
+        sklearn_result, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)
 
     lightning_result = apply_to_collection(
-        lightning_result, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)
+        lightning_result, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)
 
     assert np.allclose(sklearn_result, lightning_result, atol=1e-5)
     assert isinstance(lightning_result, type(sklearn_result))
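# Sketch mirroring test_saving_picklable above (hypothetical One class,
# illustration only): because metrics are now plain nn.Modules with hooks
# instead of objects holding a wrapped __call__ closure, they survive a
# torch.save / torch.load round trip.
import os
import tempfile

import torch

from pytorch_lightning.metrics.metric import TensorMetric


class One(TensorMetric):
    def __init__(self):
        super().__init__(name='one')

    def forward(self, x):
        return torch.tensor([1.])


metric = One()
path = os.path.join(tempfile.mkdtemp(), 'metric.ckpt')
torch.save(metric, path)
restored = torch.load(path)
assert restored(torch.zeros(3)) == metric(torch.zeros(3))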