From 3047bfe15ecb380e2eda3dca1402d0098e0028e8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 11 Aug 2020 05:46:17 +0200
Subject: [PATCH] fix non blocking + docs

---
 pytorch_lightning/core/hooks.py           | 9 +++++++--
 pytorch_lightning/utilities/apply_func.py | 8 +++++++-
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 1695e090f031b9..2bf9c18cf7593b 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -309,8 +309,13 @@ def transfer_batch_to_device(self, batch, device)
         Note:
             This hook should only transfer the data and not modify it, nor should it move the data to
             any other device than the one passed in as argument (unless you know what you are doing).
-            The :class:`~pytorch_lightning.trainer.trainer.Trainer` already takes care of splitting the
-            batch and determines the target devices.
+
+        Note:
+            This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support
+            for your custom batch objects, you need to define your custom
+            :class:`~torch.nn.parallel.DistributedDataParallel` or
+            :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
+            override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.
 
         See Also:
             - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py
index 59b73f0fced3c3..19673b8ecca6f5 100644
--- a/pytorch_lightning/utilities/apply_func.py
+++ b/pytorch_lightning/utilities/apply_func.py
@@ -1,4 +1,5 @@
 import importlib
+import inspect
 from abc import ABC
 from collections.abc import Mapping, Sequence
 from copy import copy
@@ -104,6 +105,11 @@ def batch_to(data):
                 setattr(device_data, field, device_field)
             return device_data
 
-        return data.to(device, non_blocking=True)
+        kwargs = dict()
+        signature = inspect.signature(data.to)
+        if "non_blocking" in signature.parameters:
+            kwargs.update(non_blocking=True)
+
+        return data.to(device, **kwargs)
 
     return apply_to_collection(batch, dtype=(TransferableDataType, Batch), function=batch_to)
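
The minimal sketch below (not part of the patch itself) illustrates the behaviour the signature check in apply_func.py enables. A custom batch object may implement ``.to(device)`` without a ``non_blocking`` parameter; calling ``data.to(device, non_blocking=True)`` on such an object raises a TypeError, while objects whose ``to`` does accept the keyword should keep receiving it. The ``PlainBatch`` and ``PinnedBatch`` classes, the ``batch_to`` helper, and the plain-string device are hypothetical and used only for illustration.

import inspect


class PlainBatch:
    """A custom batch whose ``to`` does not accept ``non_blocking``."""

    def to(self, device):
        print(f"PlainBatch moved to {device}")
        return self


class PinnedBatch:
    """A custom batch whose ``to`` mirrors the ``non_blocking`` keyword of ``torch.Tensor.to``."""

    def to(self, device, non_blocking=False):
        print(f"PinnedBatch moved to {device}, non_blocking={non_blocking}")
        return self


def batch_to(data, device):
    # Same idea as the patched utility: inspect the ``to`` signature first and
    # only forward ``non_blocking`` when the target method accepts it.
    kwargs = dict()
    if "non_blocking" in inspect.signature(data.to).parameters:
        kwargs.update(non_blocking=True)
    return data.to(device, **kwargs)


batch_to(PlainBatch(), "cuda:0")   # no TypeError: non_blocking is not passed
batch_to(PinnedBatch(), "cuda:0")  # non_blocking=True is forwarded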