From 69d241c82e10cf40e5787fb39bb808687d693b57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 12 Aug 2020 01:28:37 +0200 Subject: [PATCH] Do not pass non_blocking=True if it does not support this argument (#2910) * add docs * non blocking only on tensor * changelog * add test case * add test comment * update changelog changelog chlog --- CHANGELOG.md | 2 ++ pytorch_lightning/core/hooks.py | 9 +++++++-- pytorch_lightning/utilities/apply_func.py | 3 ++- tests/models/test_gpu.py | 22 ++++++++++++++++++++++ 4 files changed, 33 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c05011edd1b69..8232f544c8b9d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -123,6 +123,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed save apex scaler states ([#2828](https://github.com/PyTorchLightning/pytorch-lightning/pull/2828)) +- Fixed passing `non_blocking=True` when transferring a batch object that does not support it ([#2910](https://github.com/PyTorchLightning/pytorch-lightning/pull/2910)) + ## [0.8.5] - 2020-07-09 ### Added diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 1695e090f031b..2bf9c18cf7593 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -309,8 +309,13 @@ def transfer_batch_to_device(self, batch, device) Note: This hook should only transfer the data and not modify it, nor should it move the data to any other device than the one passed in as argument (unless you know what you are doing). - The :class:`~pytorch_lightning.trainer.trainer.Trainer` already takes care of splitting the - batch and determines the target devices. + + Note: + This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support + for your custom batch objects, you need to define your custom + :class:`~torch.nn.parallel.DistributedDataParallel` or + :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and + override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`. See Also: - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device` diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 59b73f0fced3c..99b635d0db8a4 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -104,6 +104,7 @@ def batch_to(data): setattr(device_data, field, device_field) return device_data - return data.to(device, non_blocking=True) + kwargs = dict(non_blocking=True) if isinstance(data, torch.Tensor) else {} + return data.to(device, **kwargs) return apply_to_collection(batch, dtype=(TransferableDataType, Batch), function=batch_to) diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index 7497a53083612..509de4c07563a 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -1,4 +1,5 @@ from collections import namedtuple +from unittest.mock import patch import pytest import torch @@ -384,3 +385,24 @@ def to(self, *args, **kwargs): assert batch.text.type() == 'torch.cuda.LongTensor' assert batch.label.type() == 'torch.cuda.LongTensor' + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +def test_non_blocking(): + """ Tests that non_blocking=True only gets passed on torch.Tensor.to, but not on other objects. """ + trainer = Trainer() + + batch = torch.zeros(2, 3) + with patch.object(batch, 'to', wraps=batch.to) as mocked: + trainer.transfer_batch_to_gpu(batch, 0) + mocked.assert_called_with(torch.device('cuda', 0), non_blocking=True) + + class BatchObject(object): + + def to(self, *args, **kwargs): + pass + + batch = BatchObject() + with patch.object(batch, 'to', wraps=batch.to) as mocked: + trainer.transfer_batch_to_gpu(batch, 0) + mocked.assert_called_with(torch.device('cuda', 0))