From 87f5e56d58071d74fe1a6a975b6f1d4589e647a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 11 Aug 2020 05:56:51 +0200 Subject: [PATCH 1/6] add docs --- pytorch_lightning/core/hooks.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 1695e090f031b..2bf9c18cf7593 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -309,8 +309,13 @@ def transfer_batch_to_device(self, batch, device) Note: This hook should only transfer the data and not modify it, nor should it move the data to any other device than the one passed in as argument (unless you know what you are doing). - The :class:`~pytorch_lightning.trainer.trainer.Trainer` already takes care of splitting the - batch and determines the target devices. + + Note: + This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support + for your custom batch objects, you need to define your custom + :class:`~torch.nn.parallel.DistributedDataParallel` or + :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and + override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`. See Also: - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device` From 82a3b641cafe9d7e5907cf45b99f891620f62696 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 11 Aug 2020 06:00:18 +0200 Subject: [PATCH 2/6] non blocking only on tensor --- pytorch_lightning/utilities/apply_func.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 59b73f0fced3c..99b635d0db8a4 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -104,6 +104,7 @@ def batch_to(data): setattr(device_data, field, device_field) return device_data - return data.to(device, non_blocking=True) + kwargs = dict(non_blocking=True) if isinstance(data, torch.Tensor) else {} + return data.to(device, **kwargs) return apply_to_collection(batch, dtype=(TransferableDataType, Batch), function=batch_to) From 9eeb4eb5d1d9f08f8d0559d9eb9e010f6e00fcca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 11 Aug 2020 06:02:04 +0200 Subject: [PATCH 3/6] changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c05011edd1b69..b2801621d7d0c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -123,6 +123,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed save apex scaler states ([#2828](https://github.com/PyTorchLightning/pytorch-lightning/pull/2828)) +- Fixed passing `non_blocking=True` when transferring a batch object that does not support it. ([# ](https://github.com/PyTorchLightning/pytorch-lightning/pull/ )) + ## [0.8.5] - 2020-07-09 ### Added From 4b29ea71cbeed05b6ff614e6ee346784f55f0278 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 11 Aug 2020 10:51:05 +0200 Subject: [PATCH 4/6] add test case --- tests/models/test_gpu.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index 7497a53083612..9da2313ea717c 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -1,4 +1,5 @@ from collections import namedtuple +from unittest.mock import patch import pytest import torch @@ -384,3 +385,23 @@ def to(self, *args, **kwargs): assert batch.text.type() == 'torch.cuda.LongTensor' assert batch.label.type() == 'torch.cuda.LongTensor' + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +def test_non_blocking(): + trainer = Trainer() + + batch = torch.zeros(2, 3) + with patch.object(batch, 'to', wraps=batch.to) as mocked: + trainer.transfer_batch_to_gpu(batch, 0) + mocked.assert_called_with(torch.device('cuda', 0), non_blocking=True) + + class BatchObject(object): + + def to(self, *args, **kwargs): + pass + + batch = BatchObject() + with patch.object(batch, 'to', wraps=batch.to) as mocked: + trainer.transfer_batch_to_gpu(batch, 0) + mocked.assert_called_with(torch.device('cuda', 0)) From 17d370011c7e9141bc9acbd075d4951543a4fd82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 11 Aug 2020 10:52:50 +0200 Subject: [PATCH 5/6] add test comment --- tests/models/test_gpu.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index 9da2313ea717c..509de4c07563a 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -389,6 +389,7 @@ def to(self, *args, **kwargs): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") def test_non_blocking(): + """ Tests that non_blocking=True only gets passed on torch.Tensor.to, but not on other objects. """ trainer = Trainer() batch = torch.zeros(2, 3) From 44cea11623de1a8a2b43fd26f3618d31582ecdf8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 11 Aug 2020 10:54:36 +0200 Subject: [PATCH 6/6] update changelog changelog chlog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b2801621d7d0c..8232f544c8b9d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -123,7 +123,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed save apex scaler states ([#2828](https://github.com/PyTorchLightning/pytorch-lightning/pull/2828)) -- Fixed passing `non_blocking=True` when transferring a batch object that does not support it. ([# ](https://github.com/PyTorchLightning/pytorch-lightning/pull/ )) +- Fixed passing `non_blocking=True` when transferring a batch object that does not support it ([#2910](https://github.com/PyTorchLightning/pytorch-lightning/pull/2910)) ## [0.8.5] - 2020-07-09