diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 1695e090f031b9..2bf9c18cf7593b 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -309,8 +309,13 @@ def transfer_batch_to_device(self, batch, device)
         Note:
             This hook should only transfer the data and not modify it, nor should it move the data to
             any other device than the one passed in as argument (unless you know what you are doing).
-            The :class:`~pytorch_lightning.trainer.trainer.Trainer` already takes care of splitting the
-            batch and determines the target devices.
+
+        Note:
+            This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support
+            for your custom batch objects, you need to define your custom
+            :class:`~torch.nn.parallel.DistributedDataParallel` or
+            :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
+            override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.

         See Also:
             - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py
index 59b73f0fced3c3..19673b8ecca6f5 100644
--- a/pytorch_lightning/utilities/apply_func.py
+++ b/pytorch_lightning/utilities/apply_func.py
@@ -1,4 +1,5 @@
 import importlib
+import inspect
 from abc import ABC
 from collections.abc import Mapping, Sequence
 from copy import copy
@@ -104,6 +105,11 @@ def batch_to(data):
                 setattr(device_data, field, device_field)
             return device_data

-        return data.to(device, non_blocking=True)
+        kwargs = dict()
+        signature = inspect.signature(data.to)
+        if "non_blocking" in signature.parameters:
+            kwargs.update(non_blocking=True)
+
+        return data.to(device, **kwargs)

     return apply_to_collection(batch, dtype=(TransferableDataType, Batch), function=batch_to)
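
For context, below is a minimal sketch of the signature-check pattern introduced in batch_to, shown in isolation from Lightning. The class names PinnedBatch and LegacyBatch and the standalone batch_to helper are hypothetical, used only to illustrate why probing inspect.signature avoids a TypeError for custom batch objects whose to() does not accept non_blocking, while still passing non_blocking=True to objects that support it:

    import inspect


    class PinnedBatch:
        """Hypothetical batch whose to() mirrors torch.Tensor.to and accepts non_blocking."""

        def to(self, device, non_blocking=False):
            print(f"PinnedBatch -> {device} (non_blocking={non_blocking})")
            return self


    class LegacyBatch:
        """Hypothetical batch whose to() only takes a device argument."""

        def to(self, device):
            print(f"LegacyBatch -> {device}")
            return self


    def batch_to(data, device):
        # Same pattern as the patched batch_to above: forward non_blocking=True
        # only when the object's to() signature actually accepts it.
        kwargs = dict()
        if "non_blocking" in inspect.signature(data.to).parameters:
            kwargs.update(non_blocking=True)
        return data.to(device, **kwargs)


    batch_to(PinnedBatch(), "cuda:0")  # PinnedBatch -> cuda:0 (non_blocking=True)
    batch_to(LegacyBatch(), "cuda:0")  # LegacyBatch -> cuda:0 (no TypeError)

The check keeps asynchronous transfers for objects that opt into non_blocking and degrades gracefully for plain Python batch objects that only take a device.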