diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index c7d5c2cc22dd01..6d35b2f0fb9748 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -155,4 +155,39 @@ def backward(self, use_amp, loss, optimizer):
             loss.backward()
 
     def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
-        pass
+        """
+        Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
+        wrapped in a custom data structure.
+        Lightning only calls the hook if it does not recognize the data type of your batch as one of:
+
+        - :class:`torch.Tensor`
+        - :class:`list`
+        - :class:`dict`
+        - :class:`tuple`
+        - ``torchtext.data.Batch`` (COMING SOON)
+
+        These data types (and any arbitrary nesting of them) are supported out of the box.
+        For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).
+
+        Example::
+
+            def transfer_batch_to_device(self, batch, device):
+                if isinstance(batch, CustomBatch):
+                    # move all tensors in your custom data structure to the device
+                    batch.samples = batch.samples.to(device)
+                    batch.targets = batch.targets.to(device)
+                return batch
+
+        Args:
+            batch: A batch of data that needs to be transferred to a new device.
+            device: The target device as defined in PyTorch.
+
+        Returns:
+            A reference to the data on the new device.
+
+        Note:
+            This hook should only transfer the data and not modify it, nor should it move the data
+            to any device other than the one passed as an argument.
+            The :class:`~pytorch_lightning.trainer.trainer.Trainer` already takes care of splitting
+            the batch and determining the target devices.
+        """
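
For reference, a minimal sketch of how this hook might be overridden once the change lands. `CustomBatch` and `LitModel` are hypothetical names used for illustration and are not part of this diff; the override simply mirrors the example from the docstring above.

```python
# Hypothetical usage sketch for the new hook (CustomBatch and LitModel are
# illustrative names, not part of this diff).
import torch
import pytorch_lightning as pl


class CustomBatch:
    """A container Lightning does not know how to move between devices."""

    def __init__(self, samples: torch.Tensor, targets: torch.Tensor):
        self.samples = samples
        self.targets = targets


class LitModel(pl.LightningModule):
    def transfer_batch_to_device(self, batch, device):
        if isinstance(batch, CustomBatch):
            # Move every tensor held by the custom structure to the target device.
            batch.samples = batch.samples.to(device)
            batch.targets = batch.targets.to(device)
        # Return a reference to the (possibly unchanged) batch on the new device.
        return batch
```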