
Fix lost compatibility with custom datatypes implementing .to #2335

Merged · 6 commits · Jun 24, 2020
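To make the regression concrete: before this fix, `move_data_to_device` dispatched on `dtype=torch.Tensor` only, so a user-defined batch object exposing `.to(...)` fell through untouched. A minimal sketch of the symptom — `MyBatch` and `old_move` are illustrative, not from the PR; `old_move` reconstructs the pre-fix dispatch shown in the removed lines of the `apply_func.py` diff below:

```python
import torch

from pytorch_lightning.utilities.apply_func import apply_to_collection


class MyBatch:
    """A custom datatype that implements .to(...), like the ones this PR fixes."""

    def __init__(self):
        self.x = torch.ones(2)

    def to(self, *args, **kwargs):
        self.x = self.x.to(*args, **kwargs)
        return self


def old_move(batch, device):
    # Pre-fix behaviour: only torch.Tensor leaves are matched, so a MyBatch
    # instance is returned unchanged and never reaches the target device.
    return apply_to_collection(
        batch, dtype=torch.Tensor, function=lambda t: t.to(device, non_blocking=True)
    )


batch = old_move(MyBatch(), torch.device("cuda"))  # silent no-op on MyBatch
```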
2 changes: 2 additions & 0 deletions CHANGELOG.md
```diff
@@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `average_precision` metric ([#2319](https://github.com/PyTorchLightning/pytorch-lightning/pull/2319))
 
+- Fixed lost compatibility with custom datatypes implementing `.to` ([#2335](https://github.com/PyTorchLightning/pytorch-lightning/pull/2335))
+
 ## [0.8.1] - 2020-06-19
 
 ### Fixed
```
2 changes: 1 addition & 1 deletion pytorch_lightning/core/hooks.py
```diff
@@ -204,7 +204,7 @@ def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
 
         The data types listed below (and any arbitrary nesting of them) are supported out of the box:
 
-        - :class:`torch.Tensor`
+        - :class:`torch.Tensor` or anything that implements `.to(...)`
         - :class:`list`
        - :class:`dict`
         - :class:`tuple`
```
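With this docstring change, anything implementing `.to(...)` is handled by the default behaviour; overriding `transfer_batch_to_device` remains the escape hatch for types that cannot. A sketch of such an override — the model and batch layout are hypothetical, and it assumes the default hook delegates to `move_data_to_device` as in this PR's utilities module:

```python
import torch
from pytorch_lightning import LightningModule


class MyModel(LightningModule):
    def transfer_batch_to_device(self, batch, device):
        if isinstance(batch, dict) and "text_ids" in batch:
            # Move only the tensor payload; leave the string ids untouched.
            batch["inputs"] = batch["inputs"].to(device)
            return batch
        # Defer to Lightning's default handling (tensors, collections,
        # and now anything with a .to(...) method).
        return super().transfer_batch_to_device(batch, device)
```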
44 changes: 37 additions & 7 deletions pytorch_lightning/utilities/apply_func.py
```diff
@@ -1,3 +1,4 @@
+from abc import ABC
 from collections import Mapping, Sequence
 from typing import Any, Callable, Union
@@ -38,14 +39,43 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable
     return data
 
 
+class TransferableDataType(ABC):
```
Inline comment · Member Author: in case someone has a suggestion for a better name, let me know :)
"""
A custom type for data that can be moved to a torch device via `.to(...)`.

Example:

>>> isinstance(dict, TransferableDataType)
False
>>> isinstance(torch.rand(2, 3), TransferableDataType)
True
>>> class CustomObject:
... def __init__(self):
... self.x = torch.rand(2, 2)
... def to(self, device):
... self.x = self.x.to(device)
... return self
>>> isinstance(CustomObject(), TransferableDataType)
True
"""

@classmethod
def __subclasshook__(cls, subclass):
if cls is TransferableDataType:
to = getattr(subclass, "to", None)
return callable(to)
return NotImplemented


def move_data_to_device(batch: Any, device: torch.device):
"""
Transfers a collection of tensors to the given device.
Transfers a collection of data to the given device. Any object that defines a method
``to(device)`` will be moved and all other objects in the collection will be left untouched.

Args:
batch: A tensor or collection of tensors. See :func:`apply_to_collection`
for a list of supported collection types.
device: The device to which tensors should be moved
batch: A tensor or collection of tensors or anything that has a method `.to(...)`.
See :func:`apply_to_collection` for a list of supported collection types.
device: The device to which the data should be moved

Return:
the same collection but with all contained tensors residing on the new device.
Expand All @@ -54,6 +84,6 @@ def move_data_to_device(batch: Any, device: torch.device):
- :meth:`torch.Tensor.to`
- :class:`torch.device`
"""
def to(tensor):
return tensor.to(device, non_blocking=True)
return apply_to_collection(batch, dtype=torch.Tensor, function=to)
def to(data):
return data.to(device, non_blocking=True)
return apply_to_collection(batch, dtype=TransferableDataType, function=to)
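The `__subclasshook__` turns `isinstance` into a structural (duck-typed) check: any object whose class exposes a callable `to` attribute passes, with no inheritance or registration, which lets `apply_to_collection` keep dispatching on a plain `dtype` argument. One caveat: `move_data_to_device` passes `non_blocking=True`, so a custom `to(self, device)` that accepts only a device (as in the docstring example above) would raise a `TypeError` when moved; accepting `*args, **kwargs` as in the test below avoids this. A CPU-only sketch of the mechanism, assuming the imports match this PR's module layout (`WrappedFeatures` is illustrative):

```python
import torch

from pytorch_lightning.utilities.apply_func import (
    TransferableDataType,
    move_data_to_device,
)


class WrappedFeatures:
    """Not a Tensor and not registered anywhere, yet transferable."""

    def __init__(self):
        self.features = torch.zeros(4, 8)

    def to(self, *args, **kwargs):
        # *args/**kwargs so the non_blocking=True kwarg is tolerated.
        self.features = self.features.to(*args, **kwargs)
        return self


# Structural isinstance via __subclasshook__: a callable .to is enough.
assert isinstance(WrappedFeatures(), TransferableDataType)
assert not isinstance("plain string", TransferableDataType)

# Custom objects are moved even when nested in supported collections.
batch = {"inputs": WrappedFeatures(), "meta": "left untouched"}
moved = move_data_to_device(batch, torch.device("cpu"))
assert moved["inputs"].features.device.type == "cpu"
```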
12 changes: 12 additions & 0 deletions tests/models/test_gpu.py
```diff
@@ -286,3 +286,15 @@ def test_single_gpu_batch_parse():
     batch = trainer.transfer_batch_to_gpu(batch, 0)
     assert batch[0].a.device.index == 0
     assert batch[0].a.type() == 'torch.cuda.FloatTensor'
+
+    # non-Tensor that has `.to()` defined
+    class CustomBatchType:
+        def __init__(self):
+            self.a = torch.rand(2, 2)
+
+        def to(self, *args, **kwargs):
+            self.a = self.a.to(*args, **kwargs)
+            return self
+
+    batch = trainer.transfer_batch_to_gpu(CustomBatchType())
+    assert batch.a.type() == 'torch.cuda.FloatTensor'
```
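The new assertions require a GPU because they go through `trainer.transfer_batch_to_gpu`; the same contract can be checked on CPU by calling the utility directly. A possible companion test, not part of this PR (name and placement hypothetical):

```python
import torch

from pytorch_lightning.utilities.apply_func import move_data_to_device


def test_move_custom_batch_type_to_cpu():
    class CustomBatchType:
        def __init__(self):
            self.a = torch.rand(2, 2)

        def to(self, *args, **kwargs):
            self.a = self.a.to(*args, **kwargs)
            return self

    batch = move_data_to_device(CustomBatchType(), torch.device("cpu"))
    assert batch.a.device.type == "cpu"
```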