Skip to content

Commit

Permalink
Fix lost compatibility with custom datatypes implementing .to (#2335)
Browse files Browse the repository at this point in the history
* generalize data transfer

* added test

* update docs

* fix spelling error

* changelog

* update docs
  • Loading branch information
awaelchli authored Jun 24, 2020
1 parent 598f514 commit aab9e77
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `average_precision` metric ([#2319](https://github.com/PyTorchLightning/pytorch-lightning/pull/2319))

- Fixed lost compatibility with custom datatypes implementing `.to` ([#2335](https://github.com/PyTorchLightning/pytorch-lightning/pull/2335))

## [0.8.1] - 2020-06-19

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
The data types listed below (and any arbitrary nesting of them) are supported out of the box:
- :class:`torch.Tensor`
- :class:`torch.Tensor` or anything that implements `.to(...)`
- :class:`list`
- :class:`dict`
- :class:`tuple`
Expand Down
44 changes: 37 additions & 7 deletions pytorch_lightning/utilities/apply_func.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from abc import ABC
from collections import Mapping, Sequence
from typing import Any, Callable, Union

Expand Down Expand Up @@ -38,14 +39,43 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable
return data


class TransferableDataType(ABC):
"""
A custom type for data that can be moved to a torch device via `.to(...)`.
Example:
>>> isinstance(dict, TransferableDataType)
False
>>> isinstance(torch.rand(2, 3), TransferableDataType)
True
>>> class CustomObject:
... def __init__(self):
... self.x = torch.rand(2, 2)
... def to(self, device):
... self.x = self.x.to(device)
... return self
>>> isinstance(CustomObject(), TransferableDataType)
True
"""

@classmethod
def __subclasshook__(cls, subclass):
if cls is TransferableDataType:
to = getattr(subclass, "to", None)
return callable(to)
return NotImplemented


def move_data_to_device(batch: Any, device: torch.device):
"""
Transfers a collection of tensors to the given device.
Transfers a collection of data to the given device. Any object that defines a method
``to(device)`` will be moved and all other objects in the collection will be left untouched.
Args:
batch: A tensor or collection of tensors. See :func:`apply_to_collection`
for a list of supported collection types.
device: The device to which tensors should be moved
batch: A tensor or collection of tensors or anything that has a method `.to(...)`.
See :func:`apply_to_collection` for a list of supported collection types.
device: The device to which the data should be moved
Return:
the same collection but with all contained tensors residing on the new device.
Expand All @@ -54,6 +84,6 @@ def move_data_to_device(batch: Any, device: torch.device):
- :meth:`torch.Tensor.to`
- :class:`torch.device`
"""
def to(tensor):
return tensor.to(device, non_blocking=True)
return apply_to_collection(batch, dtype=torch.Tensor, function=to)
def to(data):
return data.to(device, non_blocking=True)
return apply_to_collection(batch, dtype=TransferableDataType, function=to)
12 changes: 12 additions & 0 deletions tests/models/test_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,15 @@ def test_single_gpu_batch_parse():
batch = trainer.transfer_batch_to_gpu(batch, 0)
assert batch[0].a.device.index == 0
assert batch[0].a.type() == 'torch.cuda.FloatTensor'

# non-Tensor that has `.to()` defined
class CustomBatchType:
def __init__(self):
self.a = torch.rand(2, 2)

def to(self, *args, **kwargs):
self.a = self.a.to(*args, **kwargs)
return self

batch = trainer.transfer_batch_to_gpu(CustomBatchType())
assert batch.a.type() == 'torch.cuda.FloatTensor'

0 comments on commit aab9e77

Please sign in to comment.