Commit

docs
awaelchli committed May 9, 2020
1 parent 7361fb0 commit 06c1ac2
Showing 1 changed file with 36 additions and 1 deletion.
pytorch_lightning/core/hooks.py (37 changes: 36 additions & 1 deletion)
@@ -155,4 +155,39 @@ def backward(self, use_amp, loss, optimizer):
         loss.backward()
 
     def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
-        pass
+        """
+        Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
+        wrapped in a custom data structure.
+
+        Lightning only calls the hook if it does not recognize the data type of your batch as one of
+
+        - :class:`torch.Tensor`
+        - :class:`list`
+        - :class:`dict`
+        - :class:`tuple`
+        - ``torchtext.data.Batch`` (COMING SOON)
+
+        These data types (and any arbitrary nesting of them) are supported out of the box.
+        For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).
+
+        Example::
+
+            def transfer_batch_to_device(self, batch, device):
+                if isinstance(batch, CustomBatch):
+                    # move all tensors in your custom data structure to the device
+                    batch.samples = batch.samples.to(device)
+                    batch.targets = batch.targets.to(device)
+                return batch
+
+        Args:
+            batch: A batch of data that needs to be transferred to a new device.
+            device: The target device as defined in PyTorch.
+
+        Returns:
+            A reference to the data on the new device.
+
+        Note:
+            This hook should only transfer the data and not modify it, nor should it move the data to
+            any device other than the one passed in as an argument.
+            The :class:`~pytorch_lightning.trainer.trainer.Trainer` already takes care of splitting the
+            batch and determining the target devices.
+        """
