WIP: Sklearn metric #1320

Closed
wants to merge 14 commits into from
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,9 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## Metrics Package (will be added to [unreleased] once it is finished)
- Added base-metric ([#1293](https://github.com/PyTorchLightning/pytorch-lightning/pull/1232))

## [unreleased] - YYYY-MM-DD

### Added
1 change: 1 addition & 0 deletions environment.yml
@@ -27,6 +27,7 @@ dependencies:
- check-manifest
- twine==1.13.0
- pillow<7.0.0
- scikit-learn>=0.16.1

- pip:
- test-tube>=0.7.5
Empty file added pytorch_lightning/metrics/__init__.py
234 changes: 234 additions & 0 deletions pytorch_lightning/metrics/metric.py
@@ -0,0 +1,234 @@
import numbers
from abc import ABC, abstractmethod
from typing import Union, Any, Optional

import torch
import torch.distributed

from pytorch_lightning.utilities.apply_to_collection import apply_to_collection

__all__ = ['BaseMetric']


class BaseMetric(torch.nn.Module, ABC):
def __init__(self, name: str,
reduce_group: Optional[Any] = torch.distributed.group.WORLD,
reduce_op: Optional[Any] = torch.distributed.ReduceOp.SUM):
"""
Abstract Base Class for metric implementation.

        Automatically handles the synchronization of the computed metric value across processes (DDP reduce).

        Args:
name: the metric's name
reduce_group: the process group for DDP reduces (only needed for DDP training).
Defaults to all processes (world)
reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
Defaults to sum.
"""
super().__init__()
self.name = name
self.reduce_op = reduce_op
self.reduce_group = reduce_group
self._dtype = torch.get_default_dtype()
self._device = torch.device('cpu')

@property
def dtype(self):
return self._dtype

@property
def device(self):
return self._device

@abstractmethod
def forward(self, *args, **kwargs) -> torch.Tensor:
"""
Implements the actual metric computation.

Returns:
metric value

"""
raise NotImplementedError

def __call__(self, *args, **kwargs) -> torch.Tensor:
return apply_to_collection(
super().__call__(*args, **kwargs),
torch.Tensor, _sync_ddp_to_device_type,
device=self.device, dtype=self.dtype,
group=self.reduce_group,
reduce_op=self.reduce_op)

def to(self, *args, **kwargs):
"""Moves and/or casts the parameters and buffers.

This can be called as

.. function:: to(device=None, dtype=None, non_blocking=False)

.. function:: to(dtype, non_blocking=False)

.. function:: to(tensor, non_blocking=False)

Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
floating point desired :attr:`dtype` s. In addition, this method will
only cast the floating point parameters and buffers to :attr:`dtype`
(if given). The integral parameters and buffers will be moved
:attr:`device`, if that is given, but with dtypes unchanged. When
:attr:`non_blocking` is set, it tries to convert/move asynchronously
with respect to the host if possible, e.g., moving CPU Tensors with
pinned memory to CUDA devices.

See below for examples.

.. note::
This method modifies the module in-place.

Args:
device: the desired device of the parameters
and buffers in this module
dtype: the desired floating point type of
the floating point parameters and buffers in this module
tensor: Tensor whose dtype and device are the desired
dtype and device for all parameters and buffers in this module

Returns:
Module: self

Example::

>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
[-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
[-0.5113, -0.2325]], dtype=torch.float64)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
[-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
[-0.5112, -0.2324]], dtype=torch.float16)

"""
device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs)
if device is not None:
self._device = device

if dtype is not None:
self._dtype = dtype

return super().to(*args, **kwargs)

def cuda(self, device=None):
"""Moves all model parameters and buffers to the GPU.

This also makes associated parameters and buffers different objects. So
        it should be called before constructing the optimizer if the module will
live on GPU while being optimized.

Arguments:
device (int, optional): if specified, all parameters will be
copied to that device

Returns:
            Module: self
"""

self._device = torch.device('cuda', index=device)
return super().cuda(device=device)

def cpu(self):
"""Moves all model parameters and buffers to the CPU.

Returns:
Module: self
"""
self._device = torch.device('cpu')
return super().cpu()

def type(self, dst_type):
"""Casts all parameters and buffers to :attr:`dst_type`.

Arguments:
dst_type (type or string): the desired type

Returns:
Module: self
"""
self._dtype = dst_type
return super().type(dst_type=dst_type)

def float(self):
"""Casts all floating point parameters and buffers to float datatype.

Returns:
Module: self
"""
self._dtype = torch.float
return super().float()

def double(self):
"""Casts all floating point parameters and buffers to ``double`` datatype.

Returns:
Module: self
"""
self._dtype = torch.double
return super().double()

def half(self):
"""Casts all floating point parameters and buffers to ``half`` datatype.

Returns:
Module: self
"""
self._dtype = torch.half
return super().half()


def _sync_ddp_to_device_type(result: Union[torch.Tensor, numbers.Number],
device: Union[str, torch.device],
dtype: Union[str, torch.dtype],
group: Any = torch.distributed.group.WORLD,
reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM,
) -> torch.Tensor:
"""
Function to reduce the tensors from several ddp processes to one master process

Args:
result: the value to sync and reduce (typically tensor or number)
device: the device to put the synced and reduced value to
dtype: the datatype to convert the synced and reduced value to
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to sum

Returns:
reduced value

"""

# convert to tensor if necessary
if not isinstance(result, torch.Tensor):
result = torch.tensor(result)

if torch.distributed.is_available() and torch.distributed.is_initialized():
# sync all processes before reduction
torch.distributed.barrier(group=group)
torch.distributed.all_reduce(result, op=reduce_op, group=group,
async_op=False)

return result.to(device=device, dtype=dtype)
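Below is a minimal sketch of how a concrete metric could subclass `BaseMetric`; the `DummyAccuracy` name and its forward logic are purely illustrative and not part of this PR. The subclass only implements `forward`, while `BaseMetric.__call__` takes care of casting the result to the tracked device/dtype and of the DDP reduction:

```python
import torch

from pytorch_lightning.metrics.metric import BaseMetric


class DummyAccuracy(BaseMetric):
    """Hypothetical accuracy metric, shown only to illustrate the BaseMetric API."""

    def __init__(self):
        super().__init__(name='accuracy')

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # fraction of correct argmax predictions; BaseMetric.__call__ handles
        # the device/dtype cast and the DDP all-reduce of this value
        return (pred.argmax(dim=-1) == target).float().mean()


metric = DummyAccuracy()
acc = metric(torch.tensor([[0.1, 0.9], [0.8, 0.2]]), torch.tensor([1, 0]))  # tensor(1.)
```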
93 changes: 93 additions & 0 deletions pytorch_lightning/metrics/sklearn.py
@@ -0,0 +1,93 @@
import numbers
from typing import Any, Union

import numpy as np

import torch
from torch.utils.data._utils.collate import default_convert

from pytorch_lightning import _logger as lightning_logger
from pytorch_lightning.metrics.metric import BaseMetric
from pytorch_lightning.utilities.apply_to_collection import apply_to_collection


class SklearnMetric(BaseMetric):
def __init__(self, metric_name: str,
reduce_group: Any = torch.distributed.group.WORLD,
reduce_op: Any = torch.distributed.ReduceOp.SUM, **kwargs):
"""
Bridge between PyTorch Lightning and scikit-learn metrics

.. warning::
Every metric call will cause a GPU synchronization, which may slow down your code

Args:
            metric_name: the name of the metric to import and compute from sklearn.metrics
reduce_group: the process group for DDP reduces (only needed for DDP training).
Defaults to all processes (world)
reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
Defaults to sum.
            **kwargs: additional keyword arguments (will be forwarded to the metric call)
"""
super().__init__(name=metric_name, reduce_group=reduce_group,
reduce_op=reduce_op)

self.metric_kwargs = kwargs

lightning_logger.debug(
'Every metric call will cause a GPU synchronization, which may slow down your code')

@property
def metric_fn(self):
import sklearn.metrics
return getattr(sklearn.metrics, self.name)

def forward(self, *args, **kwargs) -> torch.Tensor:
"""
        Carries out the actual metric computation: tensor inputs are converted to numpy arrays for scikit-learn and the result is converted back to a tensor.

        Args:
*args: Positional arguments forwarded to metric call
**kwargs: keyword arguments forwarded to metric call

Returns:
the metric value

"""
# we need to include numpy arrays here, since otherwise they will also be treated as sequences
args = apply_to_collection(args, (torch.Tensor, np.ndarray), _convert_to_numpy)
kwargs = apply_to_collection(kwargs, (torch.Tensor, np.ndarray), _convert_to_numpy)

return _convert_to_tensor(self.metric_fn(*args, **kwargs, **self.metric_kwargs))


def _convert_to_tensor(data: Any) -> Any:
"""
    Maps all kinds of collections and numbers to tensors

Args:
data: the data to convert to tensor

Returns:
the converted data

"""
if isinstance(data, numbers.Number):
return torch.tensor([data])

else:
return default_convert(data)


def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
"""
    Converts tensors to numpy arrays; numpy arrays are passed through unchanged.

    Args:
data: the tensor or array to convert to numpy

Returns:
the resulting numpy array

"""
if isinstance(data, torch.Tensor):
return data.cpu().detach().numpy()
return data
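A quick usage sketch, assuming scikit-learn is installed (`accuracy_score` is an existing function in `sklearn.metrics`; the variable names are illustrative). Tensor arguments are converted to numpy before the scikit-learn call, and the scalar result comes back wrapped in a tensor:

```python
import torch

from pytorch_lightning.metrics.sklearn import SklearnMetric

accuracy = SklearnMetric('accuracy_score')

y_true = torch.tensor([0, 1, 0, 0])
y_pred = torch.tensor([0, 1, 1, 0])
print(accuracy(y_true, y_pred))  # tensor([0.7500]) -- 3 of 4 predictions correct
```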
37 changes: 37 additions & 0 deletions pytorch_lightning/utilities/apply_to_collection.py
@@ -0,0 +1,37 @@
from collections.abc import Mapping, Sequence
from typing import Any, Callable, Union


def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any:
"""
Recursively applies a function to all elements of a certain dtype.


Args:
data: the collection to apply the function to
dtype: the given function will be applied to all elements of this dtype
function: the function to apply
*args: positional arguments (will be forwarded to calls of ``function``)
**kwargs: keyword arguments (will be forwarded to calls of ``function``)

Returns:
the resulting collection

"""
elem_type = type(data)

# Breaking condition
if isinstance(data, dtype):
return function(data, *args, **kwargs)

# Recursively apply to collection items
elif isinstance(data, Mapping):
return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs)
for k, v in data.items()})
    elif isinstance(data, tuple) and hasattr(data, '_fields'):  # named tuple
        return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data))
elif isinstance(data, Sequence) and not isinstance(data, str):
return elem_type([apply_to_collection(d, dtype, function, *args, **kwargs) for d in data])

# data is neither of dtype, nor a collection
return data
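A small usage sketch of the helper (the nested `batch` dict here is just an example): it walks dicts, (named) tuples and other sequences, applies the function to every element of the requested dtype, and leaves everything else untouched.

```python
import torch

from pytorch_lightning.utilities.apply_to_collection import apply_to_collection

batch = {'x': torch.zeros(2, 3), 'meta': {'idx': torch.arange(2)}, 'name': 'sample'}

# cast every tensor in the nested structure to float16; the string leaf is returned unchanged
half_batch = apply_to_collection(batch, torch.Tensor, lambda t: t.to(torch.float16))
print(half_batch['x'].dtype)  # torch.float16
```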
1 change: 1 addition & 0 deletions requirements-extra.txt
@@ -6,3 +6,4 @@ mlflow>=1.0.0
test_tube>=0.7.5
wandb>=0.8.21
trains>=0.14.1
scikit-learn>=0.16.1
Empty file added tests/metrics/__init__.py