Skip to content

Commit

Permalink
refactor and added hook
Browse files Browse the repository at this point in the history
variant a


variant b


add test


revert rename


add changelog


docs
  • Loading branch information
awaelchli committed May 14, 2020
1 parent c05077f commit cbdefc8
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 24 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added override for hparams in `load_from_ckpt` ([#1797](https://github.com/PyTorchLightning/pytorch-lightning/pull/1797))

- Added a model hook `transfer_batch_to_device` that enables moving custom data structures to the target device ([1756](https://github.com/PyTorchLightning/pytorch-lightning/pull/1756)).

### Changed

- Enable `non-blocking` for device transfers to GPU ([#1843](https://github.com/PyTorchLightning/pytorch-lightning/pull/1843))
Expand Down
38 changes: 38 additions & 0 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,41 @@ def backward(self, use_amp, loss, optimizer):
scaled_loss.backward()
else:
loss.backward()

def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
"""
Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
wrapped in a custom data structure.
Lightning only calls the hook if it does not recognize the data type of your batch as one of
- :class:`torch.Tensor`
- :class:`list`
- :class:`dict`
- :class:`tuple`
- ``torchtext.data.Batch`` (COMING SOON)
These data types (and any arbitrary nesting of them) are supported out of the box.
For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).
Example::
def transfer_batch_to_device(self, batch, device)
if isinstance(batch, CustomBatch):
# move all tensors in your custom data structure to the device
batch.samples = batch.samples.to(device)
batch.targets = batch.targets.to(device)
return batch
Args:
batch: A batch of data that needs to be transferred to a new device.
device: The target device as defined in PyTorch.
Returns:
A reference to the data on the new device.
Note:
This hook should only transfer the data and not modify it, nor should it move the data to
any other device than the one passed in as argument.
The :class:`~pytorch_lightning.trainer.trainer.Trainer` already takes care of splitting the
batch and determines the target devices.
"""
78 changes: 54 additions & 24 deletions pytorch_lightning/trainer/distrib_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@
import time
import random
import torch
from typing import Union
from typing import Union, Any

from pytorch_lightning import _logger as log
from pytorch_lightning.loggers import LightningLoggerBase
Expand Down Expand Up @@ -434,52 +434,82 @@ def copy_trainer_model_properties(self, model):
m.tpu_global_core_rank = self.tpu_global_core_rank
m._device = self._device

def transfer_batch_to_tpu(self, batch):
return self.__transfer_data_to_device(batch, device='tpu')
def transfer_batch_to_tpu(self, batch: Any):
device = xm.xla_device() if XLA_AVAILABLE else torch.device('cpu')
return self.__transfer_data_to_device(batch, device)

def transfer_batch_to_gpu(self, batch, gpu_id):
return self.__transfer_data_to_device(batch, device='gpu', gpu_id=gpu_id)
def transfer_batch_to_gpu(self, batch: Any, gpu_id: int):
device = torch.device('cuda', gpu_id)
return self.__transfer_data_to_device(batch, device)

def __transfer_data_to_device(self, batch, device, gpu_id=None):
if device == 'tpu' and XLA_AVAILABLE:
# base case: object can be directly moved using `to`
if callable(getattr(batch, 'to', None)):
return batch.to(xm.xla_device())
def __transfer_data_to_device(self, batch: Any, device: torch.device):
if callable(getattr(batch, 'to', None)):
return batch.to(device)

if device == 'gpu':
# base case: object can be directly moved using `cuda` or `to`
if callable(getattr(batch, 'cuda', None)):
# non_blocking will be ignored if tensor is not pinned.
# so we can always set it to True
return batch.cuda(gpu_id, non_blocking=True)
# when list
if isinstance(batch, list):
for i, x in enumerate(batch):
batch[i] = self.__transfer_data_to_device(x, device)
return batch

# when tuple
if isinstance(batch, tuple):
# when namedtuple
if hasattr(batch, '_fields'):
elem_type = type(batch)
return elem_type(*(self.__transfer_data_to_device(x, device) for x in batch))
else:
batch = list(batch)
for i, x in enumerate(batch):
batch[i] = self.__transfer_data_to_device(x, device)
return tuple(batch)

# when dict
if isinstance(batch, dict):
for k, v in batch.items():
batch[k] = self.__transfer_data_to_device(v, device)

return batch

# check if the model hook can move the data
model = self.get_model()
if model is not None and self.is_overridden('transfer_batch_to_device', model):
batch = model.transfer_batch_to_device(batch, device)

# nothing matches, return the value as is without transform
return batch

def __transfer_data_to_device(self, batch: Any, device: torch.device):

if self.is_overriden('transfer_batch_to_device'):
return self.get_model().transfer_batch_to_device(batch, device)

if callable(getattr(batch, 'to', None)):
# non_blocking will be ignored if tensor is not pinned.
# so we can always set it to True
return batch.to(torch.device('cuda', gpu_id), non_blocking=True)
# base case: object can be directly moved using `to`
if callable(getattr(batch, 'to', None)):
return batch.to(device, non_blocking=True)

# when list
if isinstance(batch, list):
for i, x in enumerate(batch):
batch[i] = self.__transfer_data_to_device(x, device, gpu_id)
batch[i] = self.__transfer_data_to_device(x, device)
return batch

# when tuple
if isinstance(batch, tuple):
# when namedtuple
if hasattr(batch, '_fields'):
elem_type = type(batch)
return elem_type(*(self.__transfer_data_to_device(x, device, gpu_id) for x in batch))
return elem_type(*(self.__transfer_data_to_device(x, device) for x in batch))
else:
batch = list(batch)
for i, x in enumerate(batch):
batch[i] = self.__transfer_data_to_device(x, device, gpu_id)
batch[i] = self.__transfer_data_to_device(x, device)
return tuple(batch)

# when dict
if isinstance(batch, dict):
for k, v in batch.items():
batch[k] = self.__transfer_data_to_device(v, device, gpu_id)
batch[k] = self.__transfer_data_to_device(v, device)

return batch

Expand Down
32 changes: 32 additions & 0 deletions tests/models/test_hooks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from unittest.mock import MagicMock

import pytest
import torch

import tests.base.utils as tutils
from pytorch_lightning import Trainer
Expand Down Expand Up @@ -27,3 +30,32 @@ def on_before_zero_grad(self, optimizer):
model.on_before_zero_grad_called = 0
trainer.test(model)
assert 0 == model.on_before_zero_grad_called


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_transfer_batch_hook():

class CustomBatch:

def __init__(self, data):
self.samples = data[0]
self.targets = data[1]

class CurrentTestModel(EvalModelTemplate):

def transfer_batch_to_device(self, batch, device):
if isinstance(batch, CustomBatch):
batch.samples = batch.samples.to(device)
batch.targets = batch.targets.to(device)
return batch

model = CurrentTestModel(tutils.get_default_hparams())
batch = CustomBatch((torch.zeros(5, 28), torch.ones(5, 1, dtype=torch.long)))

trainer = Trainer()
# running .fit() would require us to implement custom data loaders, we mock the model reference instead
trainer.get_model = MagicMock(return_value=model)

batch_gpu = trainer.transfer_batch_to_gpu(batch, 0)
device = torch.device('cuda', 0)
assert batch_gpu.samples.device == batch_gpu.targets.device == device

0 comments on commit cbdefc8

Please sign in to comment.