data transfer model hook (+ refactor) (#1756)
* refactor and added hook


variant a


variant b


add test


revert rename


add changelog


docs

* resolve merge duplication

* overridden typo

* fix test

* tpu id

* raise if TPU not available

* re-use apply_to_collection function for parsing collections

* comment

* make utility function available to user

* documentation

* move changelog entry to top

* fix tpu transfer call

* fix call

* remove hardcoded string

* improve test

* call model hook by default

* Apply suggestions from code review

* rename utility function

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
2 people authored and justusschock committed Jun 29, 2020
1 parent 74260a7 commit 9152383
Showing 9 changed files with 158 additions and 56 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Speed up single-core TPU training by loading data using `ParallelLoader` ([#2033](https://github.com/PyTorchLightning/pytorch-lightning/pull/2033))

- Added a model hook `transfer_batch_to_device` that enables moving custom data structures to the target device ([#1756](https://github.com/PyTorchLightning/pytorch-lightning/pull/1756))

### Changed

- Allow user to select individual TPU core to train on ([#1729](https://github.com/PyTorchLightning/pytorch-lightning/pull/1729))
4 changes: 2 additions & 2 deletions pytorch_lightning/__init__.py
@@ -60,8 +60,8 @@
'Trainer',
'LightningModule',
'Callback',
'data_loader'
'seed_everything'
'data_loader',
'seed_everything',
]

# necessary for regular bolts imports. Skip exception since bolts is not always installed
47 changes: 47 additions & 0 deletions pytorch_lightning/core/hooks.py
@@ -3,6 +3,8 @@
import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer
from pytorch_lightning.utilities import move_data_to_device


try:
from apex import amp
@@ -153,3 +155,48 @@ def backward(self, use_amp, loss, optimizer):
scaled_loss.backward()
else:
loss.backward()

    def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
        """
        Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
        wrapped in a custom data structure.

        The data types listed below (and any arbitrary nesting of them) are supported out of the box:

        - :class:`torch.Tensor`
        - :class:`list`
        - :class:`dict`
        - :class:`tuple`
        - ``torchtext.data.Batch`` (COMING SOON)

        For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).

        Example::

            def transfer_batch_to_device(self, batch, device):
                if isinstance(batch, CustomBatch):
                    # move all tensors in your custom data structure to the device
                    batch.samples = batch.samples.to(device)
                    batch.targets = batch.targets.to(device)
                else:
                    batch = super().transfer_batch_to_device(batch, device)
                return batch

        Args:
            batch: A batch of data that needs to be transferred to a new device.
            device: The target device as defined in PyTorch.

        Returns:
            A reference to the data on the new device.

        Note:
            This hook should only transfer the data and not modify it, nor should it move the data to
            any other device than the one passed in as argument (unless you know what you are doing).
            The :class:`~pytorch_lightning.trainer.trainer.Trainer` already takes care of splitting the
            batch and determines the target devices.

        See Also:
            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
            - :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection`
        """
        return move_data_to_device(batch, device)
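
For illustration (not part of this diff): a minimal sketch of the hook's default behaviour, assuming a hypothetical `LitModel` subclass of `LightningModule`. Without an override, the hook delegates to `move_data_to_device`, which recurses through lists, tuples and dicts and moves every tensor it finds:

import torch
from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    # hypothetical minimal module, only used to access the inherited hook
    def forward(self, x):
        return x


model = LitModel()
batch = {'x': torch.zeros(2, 3), 'meta': [torch.ones(2), torch.ones(2)]}
# no override -> falls back to move_data_to_device; non-tensors pass through unchanged
moved = model.transfer_batch_to_device(batch, torch.device('cpu'))
assert moved['x'].device.type == 'cpu'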
97 changes: 45 additions & 52 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -18,6 +18,7 @@
LightningDistributedDataParallel,
LightningDataParallel,
)
from pytorch_lightning.utilities import move_data_to_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.distributed import rank_zero_only

@@ -99,58 +100,50 @@ def copy_trainer_model_properties(self, model):
m.tpu_local_core_rank = self.tpu_local_core_rank
m.tpu_global_core_rank = self.tpu_global_core_rank

def transfer_batch_to_tpu(self, batch):
return self.__transfer_data_to_device(batch, device='tpu')

def transfer_batch_to_gpu(self, batch, gpu_id):
return self.__transfer_data_to_device(batch, device='gpu', gpu_id=gpu_id)

def __transfer_data_to_device(self, batch, device, gpu_id=None):
if device == 'tpu' and XLA_AVAILABLE:
# base case: object can be directly moved using `to`
if callable(getattr(batch, 'to', None)):
xla_device = xm.xla_device(self.tpu_id) if self.tpu_id is not None else xm.xla_device()
return batch.to(xla_device)

if device == 'gpu':
# base case: object can be directly moved using `cuda` or `to`
if callable(getattr(batch, 'cuda', None)):
# non_blocking will be ignored if tensor is not pinned.
# so we can always set it to True
return batch.cuda(gpu_id, non_blocking=True)

if callable(getattr(batch, 'to', None)):
# non_blocking will be ignored if tensor is not pinned.
# so we can always set it to True
return batch.to(torch.device('cuda', gpu_id), non_blocking=True)

# when list
if isinstance(batch, list):
for i, x in enumerate(batch):
batch[i] = self.__transfer_data_to_device(x, device, gpu_id)
return batch

# when tuple
if isinstance(batch, tuple):
# when namedtuple
if hasattr(batch, '_fields'):
elem_type = type(batch)
return elem_type(*(self.__transfer_data_to_device(x, device, gpu_id) for x in batch))
else:
batch = list(batch)
for i, x in enumerate(batch):
batch[i] = self.__transfer_data_to_device(x, device, gpu_id)
return tuple(batch)

# when dict
if isinstance(batch, dict):
for k, v in batch.items():
batch[k] = self.__transfer_data_to_device(v, device, gpu_id)

return batch

# nothing matches, return the value as is without transform
return batch
    def transfer_batch_to_tpu(self, batch: Any, tpu_id: Optional[int] = None):
        """
        Transfers the data to the TPU.

        Args:
            batch: A tensor or collection of tensors.
            tpu_id: The id of the TPU core. If omitted, the first available core is chosen.

        Return:
            the tensor on the TPU device.

        See Also:
            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
        """
        if not XLA_AVAILABLE:
            raise MisconfigurationException(
                'Requested to transfer batch to TPU but XLA is not available.'
                ' Are you sure this machine has TPUs?'
            )
        device = xm.xla_device(tpu_id)
        return self.__transfer_batch_to_device(batch, device)

    def transfer_batch_to_gpu(self, batch: Any, gpu_id: Optional[int] = None):
        """
        Transfers the data to the GPU.

        Args:
            batch: A tensor or collection of tensors.
            gpu_id: The id of the GPU device. If omitted, the first available GPU is chosen.

        Return:
            the tensor on the GPU device.

        See Also:
            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
        """
        device = torch.device('cuda', gpu_id)
        return self.__transfer_batch_to_device(batch, device)

    def __transfer_batch_to_device(self, batch: Any, device: torch.device):
        model = self.get_model()
        if model is not None:
            return model.transfer_batch_to_device(batch, device)
        return move_data_to_device(batch, device)

def single_gpu_train(self, model):
model.cuda(self.root_gpu)
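
As a side note (not part of this diff), a hedged sketch of how the new XLA guard in `transfer_batch_to_tpu` surfaces on a machine without TPUs, assuming a default-constructed `Trainer`:

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

trainer = Trainer()
try:
    trainer.transfer_batch_to_tpu(torch.zeros(2, 2))
except MisconfigurationException as err:
    # raised because XLA is not installed / no TPU cores are visible
    print(err)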
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/evaluation_loop.py
@@ -434,7 +434,7 @@ def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode:

# TPU data transfer
if self.use_tpu:
batch = self.transfer_batch_to_tpu(batch)
batch = self.transfer_batch_to_tpu(batch, self.tpu_id)
args[0] = batch

# CPU, TPU or gpu step
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -756,7 +756,7 @@ def training_forward(self, batch, batch_idx, opt_idx, hiddens):

# TPU support
elif self.use_tpu:
batch = self.transfer_batch_to_tpu(batch)
batch = self.transfer_batch_to_tpu(batch, self.tpu_id)
args[0] = batch
output = self.model.training_step(*args)

1 change: 1 addition & 0 deletions pytorch_lightning/utilities/__init__.py
@@ -1,3 +1,4 @@
"""General utilities"""

from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.apply_func import move_data_to_device
23 changes: 23 additions & 0 deletions pytorch_lightning/utilities/apply_func.py
@@ -1,6 +1,8 @@
from collections import Mapping, Sequence
from typing import Any, Callable, Union, Optional

import torch


def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args,
wrong_dtype: Optional[Union[type, tuple]] = None, **kwargs) -> Any:
@@ -37,3 +39,24 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable

# data is neither of dtype, nor a collection
return data


def move_data_to_device(batch: Any, device: torch.device):
    """
    Transfers a collection of tensors to the given device.

    Args:
        batch: A tensor or collection of tensors. See :func:`apply_to_collection`
            for a list of supported collection types.
        device: The device to which tensors should be moved

    Return:
        the same collection but with all contained tensors residing on the new device.

    See Also:
        - :meth:`torch.Tensor.to`
        - :class:`torch.device`
    """
    def to(tensor):
        return tensor.to(device, non_blocking=True)

    return apply_to_collection(batch, dtype=torch.Tensor, function=to)
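
For reference (not part of this diff), a short usage sketch of the two utilities: `move_data_to_device` moves every tensor in a nested collection, while `apply_to_collection` applies an arbitrary tensor-wise function using the same traversal:

import torch
from pytorch_lightning.utilities import move_data_to_device
from pytorch_lightning.utilities.apply_func import apply_to_collection

batch = {'x': torch.zeros(4, 8), 'y': (torch.ones(4), 3.14)}

# tensors are moved, the float passes through untouched
cpu_batch = move_data_to_device(batch, torch.device('cpu'))

# same traversal, different per-tensor function (here: cast to half precision)
half_batch = apply_to_collection(batch, dtype=torch.Tensor, function=lambda t: t.half())
assert half_batch['x'].dtype == torch.float16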
36 changes: 36 additions & 0 deletions tests/models/test_hooks.py
@@ -1,3 +1,5 @@
from unittest.mock import MagicMock

import pytest
import torch

@@ -68,3 +70,37 @@ def training_epoch_end(self, outputs):
# metrics are kept after each epoch
for i in range(num_epochs):
assert metrics[f'epoch_metric_{i}'] == i


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_transfer_batch_hook():

class CustomBatch:

def __init__(self, data):
self.samples = data[0]
self.targets = data[1]

class CurrentTestModel(EvalModelTemplate):

hook_called = False

def transfer_batch_to_device(self, data, device):
self.hook_called = True
if isinstance(data, CustomBatch):
data.samples = data.samples.to(device)
data.targets = data.targets.to(device)
else:
data = super().transfer_batch_to_device(data, device)
return data

model = CurrentTestModel()
batch = CustomBatch((torch.zeros(5, 28), torch.ones(5, 1, dtype=torch.long)))

trainer = Trainer()
# running .fit() would require us to implement custom data loaders, we mock the model reference instead
trainer.get_model = MagicMock(return_value=model)
batch_gpu = trainer.transfer_batch_to_gpu(batch, 0)
expected = torch.device('cuda', 0)
assert model.hook_called
assert batch_gpu.samples.device == batch_gpu.targets.device == expected
