From 6efccda7b56cf27862328c9b7aa9e644562cf62e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 12 Aug 2020 01:28:37 +0200 Subject: [PATCH] Do not pass non_blocking=True if it does not support this argument (#2910) * add docs * non blocking only on tensor * changelog * add test case * add test comment * update changelog changelog chlog --- CHANGELOG.md | 2 ++ pytorch_lightning/core/hooks.py | 9 +++++++-- pytorch_lightning/utilities/apply_func.py | 3 ++- tests/models/test_gpu.py | 22 ++++++++++++++++++++++ 4 files changed, 33 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c05011edd1b690..8232f544c8b9da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -123,6 +123,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed save apex scaler states ([#2828](https://github.com/PyTorchLightning/pytorch-lightning/pull/2828)) +- Fixed passing `non_blocking=True` when transferring a batch object that does not support it ([#2910](https://github.com/PyTorchLightning/pytorch-lightning/pull/2910)) + ## [0.8.5] - 2020-07-09 ### Added diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 1695e090f031b9..2bf9c18cf7593b 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -309,8 +309,13 @@ def transfer_batch_to_device(self, batch, device) Note: This hook should only transfer the data and not modify it, nor should it move the data to any other device than the one passed in as argument (unless you know what you are doing). - The :class:`~pytorch_lightning.trainer.trainer.Trainer` already takes care of splitting the - batch and determines the target devices. + + Note: + This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support + for your custom batch objects, you need to define your custom + :class:`~torch.nn.parallel.DistributedDataParallel` or + :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and + override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`. See Also: - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device` diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 59b73f0fced3c3..99b635d0db8a42 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -104,6 +104,7 @@ def batch_to(data): setattr(device_data, field, device_field) return device_data - return data.to(device, non_blocking=True) + kwargs = dict(non_blocking=True) if isinstance(data, torch.Tensor) else {} + return data.to(device, **kwargs) return apply_to_collection(batch, dtype=(TransferableDataType, Batch), function=batch_to) diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index 7497a53083612e..509de4c07563a2 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -1,4 +1,5 @@ from collections import namedtuple +from unittest.mock import patch import pytest import torch @@ -384,3 +385,24 @@ def to(self, *args, **kwargs): assert batch.text.type() == 'torch.cuda.LongTensor' assert batch.label.type() == 'torch.cuda.LongTensor' + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +def test_non_blocking(): + """ Tests that non_blocking=True only gets passed on torch.Tensor.to, but not on other objects. """ + trainer = Trainer() + + batch = torch.zeros(2, 3) + with patch.object(batch, 'to', wraps=batch.to) as mocked: + trainer.transfer_batch_to_gpu(batch, 0) + mocked.assert_called_with(torch.device('cuda', 0), non_blocking=True) + + class BatchObject(object): + + def to(self, *args, **kwargs): + pass + + batch = BatchObject() + with patch.object(batch, 'to', wraps=batch.to) as mocked: + trainer.transfer_batch_to_gpu(batch, 0) + mocked.assert_called_with(torch.device('cuda', 0))